/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facerec;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import cn.smartjavaai.common.entity.DetectionResponse;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.common.pool.PredictorFactory;
import cn.smartjavaai.common.utils.FileUtils;
import cn.smartjavaai.common.utils.ImageUtils;
import cn.smartjavaai.face.config.FaceModelConfig;
import cn.smartjavaai.face.exception.FaceException;
import cn.smartjavaai.face.model.facerec.AbstractFaceModel;
import cn.smartjavaai.face.translator.FaceDetectionTranslator;
import cn.smartjavaai.face.utils.FaceUtils;
import cn.smartjavaai.face.utils.OpenCVUtils;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import javax.imageio.ImageIO;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.pool2.ObjectPool;
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RetinaFaceModel
extends AbstractFaceModel
implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(RetinaFaceModel.class);
    private ObjectPool<Predictor<Image, DetectedObjects>> predictorPool;
    private ZooModel<Image, DetectedObjects> model;
    public static final int[][] scales = new int[][]{{16, 32}, {64, 128}, {256, 512}};
    public static final int[] steps = new int[]{8, 16, 32};
    public static final double[] variance = new double[]{0.1f, 0.2f};

    @Override
    public void loadModel(FaceModelConfig config) {
        Device device = null;
        if (!Objects.isNull(config.getDevice())) {
            device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu();
        }
        FaceDetectionTranslator translator = new FaceDetectionTranslator(config.getConfidenceThreshold(), config.getNmsThresh(), variance, 5000, scales, steps);
        Criteria criteria = Criteria.builder().setTypes(Image.class, DetectedObjects.class).optModelUrls(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? null : "https://resources.djl.ai/test-models/pytorch/retinaface.zip").optModelPath(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? Paths.get(config.getModelPath(), new String[0]) : null).optModelName("retinaface").optTranslator((Translator)translator).optDevice(device).optProgress((Progress)new ProgressBar()).optEngine("PyTorch").build();
        try {
            this.model = criteria.loadModel();
            this.predictorPool = new GenericObjectPool((PooledObjectFactory)new PredictorFactory(this.model));
            log.info("\u5f53\u524d\u8bbe\u5907: " + this.model.getNDManager().getDevice());
        }
        catch (MalformedModelException | ModelNotFoundException | IOException e) {
            throw new FaceException("\u6a21\u578b\u52a0\u8f7d\u5931\u8d25", e);
        }
    }

    @Override
    public DetectionResponse detect(String imagePath) {
        if (!FileUtils.isFileExists((String)imagePath)) {
            throw new FaceException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        Image img = null;
        try {
            img = ImageFactory.getInstance().fromFile(Paths.get(imagePath, new String[0]));
        }
        catch (IOException e) {
            throw new FaceException("\u65e0\u6548\u7684\u56fe\u7247", e);
        }
        DetectedObjects detection = this.detect(img);
        return FaceUtils.convertToDetectionResponse(detection, img);
    }

    @Override
    public DetectionResponse detect(InputStream imageInputStream) {
        if (Objects.isNull(imageInputStream)) {
            throw new FaceException("\u56fe\u50cf\u8f93\u5165\u6d41\u65e0\u6548");
        }
        try {
            Image img = ImageFactory.getInstance().fromInputStream(imageInputStream);
            DetectedObjects detection = this.detect(img);
            return FaceUtils.convertToDetectionResponse(detection, img);
        }
        catch (IOException e) {
            throw new FaceException("\u65e0\u6548\u56fe\u7247\u8f93\u5165\u6d41", e);
        }
    }

    @Override
    public DetectionResponse detect(BufferedImage image) {
        if (!ImageUtils.isImageValid((BufferedImage)image)) {
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        Image img = ImageFactory.getInstance().fromImage((Object)OpenCVUtils.image2Mat(image));
        DetectedObjects detection = this.detect(img);
        return FaceUtils.convertToDetectionResponse(detection, img);
    }

    @Override
    public DetectionResponse detect(byte[] imageData) {
        if (Objects.isNull(imageData)) {
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        try {
            return this.detect(ImageIO.read(new ByteArrayInputStream(imageData)));
        }
        catch (IOException e) {
            throw new FaceException("\u9519\u8bef\u7684\u56fe\u50cf", e);
        }
    }

    @Override
    public void detectAndDraw(String imagePath, String outputPath) {
        if (!FileUtils.isFileExists((String)imagePath)) {
            throw new FaceException("\u56fe\u50cf\u6587\u4ef6\u4e0d\u5b58\u5728");
        }
        try {
            Image img = ImageFactory.getInstance().fromFile(Paths.get(imagePath, new String[0]));
            DetectedObjects detectedObjects = this.detect(img);
            if (Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0) {
                throw new FaceException("\u672a\u8bc6\u522b\u5230\u4eba\u8138");
            }
            img.drawBoundingBoxes(detectedObjects);
            Path output = Paths.get(outputPath, new String[0]);
            log.info("Saving to {}", (Object)output.toAbsolutePath().toString());
            img.save(Files.newOutputStream(output, new OpenOption[0]), "png");
        }
        catch (IOException e) {
            throw new FaceException(e);
        }
    }

    @Override
    public BufferedImage detectAndDraw(BufferedImage sourceImage) {
        if (!ImageUtils.isImageValid((BufferedImage)sourceImage)) {
            throw new FaceException("\u56fe\u50cf\u65e0\u6548");
        }
        Image img = ImageFactory.getInstance().fromImage((Object)OpenCVUtils.image2Mat(sourceImage));
        DetectedObjects detectedObjects = this.detect(img);
        if (Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0) {
            throw new FaceException("\u672a\u8bc6\u522b\u5230\u4eba\u8138");
        }
        img.drawBoundingBoxes(detectedObjects);
        try {
            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            img.save((OutputStream)outputStream, "png");
            byte[] imageBytes = outputStream.toByteArray();
            return ImageIO.read(new ByteArrayInputStream(imageBytes));
        }
        catch (IOException e) {
            throw new FaceException("\u5bfc\u51fa\u56fe\u7247\u5931\u8d25", e);
        }
    }

    private DetectedObjects detect(Image image) {
        Predictor predictor = null;
        try {
            predictor = (Predictor)this.predictorPool.borrowObject();
            DetectedObjects detectedObjects = (DetectedObjects)predictor.predict((Object)image);
            return detectedObjects;
        }
        catch (Exception e) {
            throw new FaceException("\u76ee\u6807\u68c0\u6d4b\u9519\u8bef", e);
        }
        finally {
            if (predictor != null) {
                try {
                    this.predictorPool.returnObject((Object)predictor);
                }
                catch (Exception e) {
                    log.warn("\u5f52\u8fd8Predictor\u5931\u8d25", (Throwable)e);
                    try {
                        predictor.close();
                    }
                    catch (Exception ex) {
                        log.error("\u5173\u95edPredictor\u5931\u8d25", (Throwable)ex);
                    }
                }
            }
        }
    }

    @Override
    public void close() {
        if (this.predictorPool != null) {
            this.predictorPool.close();
        }
    }
}

