/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facerec;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.opencv.OpenCVImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import cn.smartjavaai.common.entity.DetectionInfo;
import cn.smartjavaai.common.entity.DetectionRectangle;
import cn.smartjavaai.common.entity.DetectionResponse;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.common.pool.PredictorFactory;
import cn.smartjavaai.common.utils.FileUtils;
import cn.smartjavaai.common.utils.ImageUtils;
import cn.smartjavaai.face.config.FaceExtractConfig;
import cn.smartjavaai.face.config.FaceModelConfig;
import cn.smartjavaai.face.enums.FaceModelEnum;
import cn.smartjavaai.face.exception.FaceException;
import cn.smartjavaai.face.factory.FaceModelFactory;
import cn.smartjavaai.face.model.facerec.AbstractFaceModel;
import cn.smartjavaai.face.model.facerec.FaceModel;
import cn.smartjavaai.face.translator.FaceFeatureTranslator;
import cn.smartjavaai.face.utils.FaceAlignUtils;
import cn.smartjavaai.face.utils.FaceUtils;
import cn.smartjavaai.face.utils.OpenCVUtils;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.imageio.ImageIO;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.pool2.ObjectPool;
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.opencv.core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Face feature (embedding) extraction model backed by DJL.
 *
 * <p>{@link #loadModel(FaceModelConfig)} loads a PyTorch "face_feature" model, either from a local
 * path or from the DJL test-model URL. Inference goes through a commons-pool2 pool of
 * {@link Predictor}s. Face cropping/alignment is delegated to the detection model configured in
 * {@link FaceExtractConfig}.
 */
public class FeatureExtractionModel
extends AbstractFaceModel
implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(FeatureExtractionModel.class);
    private ObjectPool<Predictor<Image, float[]>> predictorPool;
    private ZooModel<Image, float[]> model;
    private FaceModelConfig config;

    /**
     * Values joined with "," and passed to the translator as its {@code normalize} argument.
     * NOTE(review): presumably three per-channel means followed by three std values
     * (0.5019608 = 128/255) - confirm against FaceFeatureTranslator.
     * Wrapped unmodifiable so callers cannot alter the shared normalization constants.
     */
    public static final List<Float> mean = Collections.unmodifiableList(Arrays.asList(
            Float.valueOf(0.5f), Float.valueOf(0.5f), Float.valueOf(0.5f),
            Float.valueOf(0.5019608f), Float.valueOf(0.5019608f), Float.valueOf(0.5019608f)));

    /**
     * Loads the feature model and creates the predictor pool.
     *
     * <p>If {@code config} carries no extract config, a default one using the
     * ULTRA_LIGHT_FAST_GENERIC_FACE detection model is created and stored on {@code config}.
     *
     * @param config model configuration; must not be null
     * @throws FaceException if config is null, if an extract config is present without a detection
     *     model, or if the model cannot be loaded
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"}) // FaceFeatureTranslator / PredictorFactory are used raw, as before
    public void loadModel(FaceModelConfig config) {
        if (Objects.isNull(config)) {
            // "config is null"
            throw new FaceException("config\u4e3anull");
        }
        if (Objects.isNull(config.getExtractConfig())) {
            config.setExtractConfig(this.getDefaultConfig());
        } else if (Objects.isNull(config.getExtractConfig().getDetectModel())) {
            // "please set a face detection model"
            throw new FaceException("\u8bf7\u8bbe\u7f6e\u4eba\u8138\u68c0\u6d4b\u6a21\u578b");
        }
        // A null device lets DJL choose its default device.
        Device device = null;
        if (!Objects.isNull(config.getDevice())) {
            device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu();
        }
        this.config = config;
        String normalize = mean.stream().map(Object::toString).collect(Collectors.joining(","));
        // A local model path takes precedence; otherwise the model is downloaded from the DJL URL.
        boolean hasLocalModel = StringUtils.isNotBlank(config.getModelPath());
        Criteria faceFeatureCriteria = Criteria.builder()
                .setTypes(Image.class, float[].class)
                .optModelName("face_feature")
                .optModelUrls(hasLocalModel ? null : "https://resources.djl.ai/test-models/pytorch/face_feature.zip")
                .optModelPath(hasLocalModel ? Paths.get(config.getModelPath()) : null)
                .optTranslator((Translator) new FaceFeatureTranslator())
                .optArgument("normalize", normalize)
                .optDevice(device)
                .optEngine("PyTorch")
                .optProgress((Progress) new ProgressBar())
                .build();
        try {
            this.model = faceFeatureCriteria.loadModel();
            this.predictorPool = new GenericObjectPool((PooledObjectFactory) new PredictorFactory(this.model));
            // "current device: "
            log.info("\u5f53\u524d\u8bbe\u5907: " + this.model.getNDManager().getDevice());
        }
        catch (MalformedModelException | ModelNotFoundException | IOException e) {
            // "model loading failed"
            throw new FaceException("\u6a21\u578b\u52a0\u8f7d\u5931\u8d25", e);
        }
    }

    /**
     * Runs the feature model on one image using a predictor borrowed from the pool.
     *
     * @param image image to embed (whole image, crop or aligned face)
     * @return the feature vector produced by the translator
     * @throws FaceException wrapping any borrow or inference failure
     */
    private float[] featureExtraction(Image image) {
        // Read the field once so the borrow and the return use the same pool instance.
        ObjectPool<Predictor<Image, float[]>> pool = this.predictorPool;
        Predictor<Image, float[]> predictor = null;
        try {
            predictor = pool.borrowObject();
            return predictor.predict(image);
        }
        catch (Exception e) {
            // "feature extraction error"
            throw new FaceException("\u7279\u5f81\u63d0\u53d6\u9519\u8bef", e);
        }
        finally {
            if (predictor != null) {
                try {
                    pool.returnObject(predictor);
                }
                catch (Exception e) {
                    // "failed to return Predictor"; close it so its native resources are not leaked.
                    log.warn("\u5f52\u8fd8Predictor\u5931\u8d25", (Throwable)e);
                    try {
                        predictor.close();
                    }
                    catch (Exception ex) {
                        // "failed to close Predictor"
                        log.error("\u5173\u95edPredictor\u5931\u8d25", (Throwable)ex);
                    }
                }
            }
        }
    }

    /**
     * Computes a similarity score from the cosine similarity of two vectors, mapped from [-1, 1]
     * to [0, 1] via {@code (cos + 1) / 2}.
     *
     * <p>A zero-norm vector yields NaN (division by zero), as before.
     *
     * @throws NullPointerException if either feature is null
     * @throws FaceException if the vectors differ in length
     */
    @Override
    public float calculSimilar(float[] feature1, float[] feature2) {
        Objects.requireNonNull(feature1, "feature1");
        Objects.requireNonNull(feature2, "feature2");
        if (feature1.length != feature2.length) {
            // "feature vector lengths differ"
            throw new FaceException("\u7279\u5f81\u5411\u91cf\u957f\u5ea6\u4e0d\u4e00\u81f4");
        }
        float dot = 0.0f;
        float norm1Sq = 0.0f;
        float norm2Sq = 0.0f;
        for (int i = 0; i < feature1.length; ++i) {
            dot += feature1[i] * feature2[i];
            norm1Sq += feature1[i] * feature1[i];
            norm2Sq += feature2[i] * feature2[i];
        }
        return (float)(((double)dot / Math.sqrt(norm1Sq) / Math.sqrt(norm2Sq) + 1.0) / 2.0);
    }

    /**
     * Compares the top faces of two image files.
     *
     * @throws FaceException if either file does not exist or extraction fails
     */
    @Override
    public float featureComparison(String imagePath1, String imagePath2) {
        // Validate both paths before running any inference.
        if (!FileUtils.isFileExists(imagePath1) || !FileUtils.isFileExists(imagePath2)) {
            // "image file does not exist"
            throw new FaceException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        float[] feature1 = this.extractTopFaceFeature(imagePath1);
        float[] feature2 = this.extractTopFaceFeature(imagePath2);
        return this.calculSimilar(feature1, feature2);
    }

    /**
     * Compares the top faces of two in-memory images.
     *
     * @throws FaceException if either image is rejected by {@code ImageUtils.isImageValid}
     */
    @Override
    public float featureComparison(BufferedImage sourceImage1, BufferedImage sourceImag2) {
        if (!ImageUtils.isImageValid(sourceImage1) || !ImageUtils.isImageValid(sourceImag2)) {
            // "invalid image"
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        float[] feature1 = this.extractTopFaceFeature(sourceImage1);
        float[] feature2 = this.extractTopFaceFeature(sourceImag2);
        return this.calculSimilar(feature1, feature2);
    }

    /**
     * Compares the top faces of two encoded images.
     *
     * @throws FaceException if either byte array is null or cannot be decoded
     */
    @Override
    public float featureComparison(byte[] imageData1, byte[] imageData2) {
        if (Objects.isNull(imageData1) || Objects.isNull(imageData2)) {
            // "invalid image"
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        float[] feature1 = this.extractTopFaceFeature(imageData1);
        float[] feature2 = this.extractTopFaceFeature(imageData2);
        return this.calculSimilar(feature1, feature2);
    }

    /**
     * Builds the default extract config, using an ULTRA_LIGHT_FAST_GENERIC_FACE detection model
     * obtained from {@link FaceModelFactory}.
     */
    private FaceExtractConfig getDefaultConfig() {
        FaceExtractConfig config = new FaceExtractConfig();
        FaceModelConfig detectModelConfig = new FaceModelConfig();
        detectModelConfig.setModelEnum(FaceModelEnum.ULTRA_LIGHT_FAST_GENERIC_FACE);
        // "creating default detection model: ..."
        log.debug("\u521b\u5efa\u9ed8\u8ba4\u68c0\u6d4b\u6a21\u578b\uff1aULTRA_LIGHT_FAST_GENERIC_FACE");
        FaceModel detectModel = FaceModelFactory.getInstance().getModel(detectModelConfig);
        // "detection model created"
        log.debug("\u521b\u5efa\u68c0\u6d4b\u6a21\u578b\u5b8c\u6bd5");
        config.setDetectModel(detectModel);
        return config;
    }

    /** Returns true when the detection response contains at least one detection. */
    private static boolean hasDetectedFaces(DetectionResponse response) {
        return Objects.nonNull(response)
                && Objects.nonNull(response.getDetectionInfoList())
                && !response.getDetectionInfoList().isEmpty();
    }

    /**
     * Extracts the feature of one detected face.
     *
     * <p>When alignment is enabled, the whole image is warped with an affine matrix estimated from
     * the face key points and {@code FaceUtils.faceTemplate512x512}. Otherwise the detection
     * rectangle is cropped. The NDManager used for alignment is closed before returning.
     */
    private float[] featureOfDetectedFace(Image djlImage, DetectionInfo detectionInfo) {
        if (this.config.getExtractConfig().isAlign()) {
            try (NDManager manager = NDManager.newBaseManager()) {
                double[][] pointsArray = FaceUtils.facePoints(detectionInfo.getFaceInfo().getKeyPoints());
                NDArray srcPoints = manager.create(pointsArray);
                NDArray dstPoints = FaceUtils.faceTemplate512x512(manager);
                Mat affineMatrix = OpenCVUtils.toOpenCVMat(manager, srcPoints, dstPoints);
                Mat alignedMat = FaceAlignUtils.warpAffine((Mat)djlImage.getWrappedImage(), affineMatrix);
                Image alignedImg = OpenCVImageFactory.getInstance().fromImage((Object)alignedMat);
                return this.featureExtraction(alignedImg);
            }
        }
        DetectionRectangle rectangle = detectionInfo.getDetectionRectangle();
        Image subImage = djlImage.getSubImage(rectangle.getX(), rectangle.getY(), rectangle.getWidth(), rectangle.getHeight());
        return this.featureExtraction(subImage);
    }

    /** Reads an image file, or throws FaceException if it is missing or unreadable. */
    private static BufferedImage loadImageFromPath(String imagePath) {
        if (!FileUtils.isFileExists(imagePath)) {
            // "image file does not exist"
            throw new FaceException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        try {
            // May return null for unsupported formats; the BufferedImage overloads reject null.
            return ImageIO.read(Paths.get(imagePath).toAbsolutePath().toFile());
        }
        catch (IOException e) {
            // "invalid image path"
            throw new FaceException("\u65e0\u6548\u56fe\u7247\u8def\u5f84", e);
        }
    }

    /** Decodes encoded image bytes, or throws FaceException if null or unreadable. */
    private static BufferedImage loadImageFromBytes(byte[] imageData) {
        if (Objects.isNull(imageData)) {
            // "invalid image"
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        try {
            // May return null for undecodable data; the BufferedImage overloads reject null.
            return ImageIO.read(new ByteArrayInputStream(imageData));
        }
        catch (IOException e) {
            // "bad image"
            throw new FaceException("\u9519\u8bef\u7684\u56fe\u50cf", e);
        }
    }

    /**
     * Detects all faces in the image and extracts one feature per face.
     *
     * <p>Detection always runs here, regardless of the crop-face setting.
     *
     * @throws FaceException if the image is null, no face is detected, or extraction fails
     */
    @Override
    public List<float[]> extractFeatures(BufferedImage image) {
        if (Objects.isNull(image)) {
            // "invalid image"
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        DetectionResponse detectedResult = this.config.getExtractConfig().getDetectModel().detect(image);
        if (!hasDetectedFaces(detectedResult)) {
            // "no face detected"
            throw new FaceException("\u672a\u68c0\u6d4b\u5230\u4eba\u8138");
        }
        Image djlImage = ImageFactory.getInstance().fromImage((Object)OpenCVUtils.image2Mat(image));
        List<float[]> featureList = new ArrayList<float[]>();
        for (DetectionInfo detectionInfo : detectedResult.getDetectionInfoList()) {
            float[] features = this.featureOfDetectedFace(djlImage, detectionInfo);
            if (Objects.nonNull(features)) {
                featureList.add(features);
            }
        }
        return featureList;
    }

    /** Decodes {@code imageData} and delegates to {@link #extractFeatures(BufferedImage)}. */
    @Override
    public List<float[]> extractFeatures(byte[] imageData) {
        return this.extractFeatures(loadImageFromBytes(imageData));
    }

    /** Reads {@code imagePath} and delegates to {@link #extractFeatures(BufferedImage)}. */
    @Override
    public List<float[]> extractFeatures(String imagePath) {
        return this.extractFeatures(loadImageFromPath(imagePath));
    }

    /**
     * Extracts the feature of the first detected face.
     *
     * <p>When crop-face is disabled, the whole image is embedded without detection.
     *
     * @throws FaceException if the image is null, no face is detected, or extraction fails
     */
    @Override
    public float[] extractTopFaceFeature(BufferedImage image) {
        if (Objects.isNull(image)) {
            // "invalid image"
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        Image djlImage = ImageFactory.getInstance().fromImage((Object)OpenCVUtils.image2Mat(image));
        if (!this.config.getExtractConfig().isCropFace()) {
            return this.featureExtraction(djlImage);
        }
        DetectionResponse detectedResult = this.config.getExtractConfig().getDetectModel().detect(image);
        if (!hasDetectedFaces(detectedResult)) {
            // "no face detected"
            throw new FaceException("\u672a\u68c0\u6d4b\u5230\u4eba\u8138");
        }
        // Takes the first detection as the "top" face. Whether the list is ranked depends on the
        // detection model (unverified).
        DetectionInfo detectionInfo = (DetectionInfo)detectedResult.getDetectionInfoList().get(0);
        return this.featureOfDetectedFace(djlImage, detectionInfo);
    }

    /** Reads {@code imagePath} and delegates to {@link #extractTopFaceFeature(BufferedImage)}. */
    @Override
    public float[] extractTopFaceFeature(String imagePath) {
        return this.extractTopFaceFeature(loadImageFromPath(imagePath));
    }

    /** Decodes {@code imageData} and delegates to {@link #extractTopFaceFeature(BufferedImage)}. */
    @Override
    public float[] extractTopFaceFeature(byte[] imageData) {
        return this.extractTopFaceFeature(loadImageFromBytes(imageData));
    }

    /**
     * Closes the predictor pool and then the loaded model. Safe to call more than once.
     * The detection model is obtained from FaceModelFactory and is not closed here.
     */
    @Override
    public void close() {
        if (this.predictorPool != null) {
            this.predictorPool.close();
            this.predictorPool = null;
        }
        if (this.model != null) {
            this.model.close();
            this.model = null;
        }
    }
}

