/*
 * Decompiled with CFR 0.152.
 */
package de.cem.tensorflow.utility;

import de.cem.tensorflow.models.DetectedObject;
import de.cem.tensorflow.models.GraphInputModel;
import de.cem.tensorflow.models.RawDetectedObject;
import de.cem.tensorflow.utility.ImageUtils;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TUint8;

/**
 * Runs a TensorFlow object-detection SavedModel on a single {@link BufferedImage} and converts the
 * raw output tensors into pixel-space {@link DetectedObject}s.
 */
public class TensorUtils {
    /** Batch dimension of the input tensor: exactly one image is fed per inference call. */
    private static final long BATCH_SIZE = 1L;
    /** Channels per pixel in the input tensor (RGB after {@link #bgr2rgb(byte[])}). */
    private static final long CHANNELS = 3L;

    private final SavedModelBundle model;
    private final GraphInputModel graphModel;
    private final BufferedImage image;
    /** Detections whose score is not strictly greater than this value are discarded. */
    private final Float hitThreshold;

    /**
     * Creates a detector bound to one model and one image.
     *
     * @param savedModelBundle loaded model whose session is used for inference
     * @param graphModel       names of the graph's input tensor and detection output tensors
     * @param bufferedImage    image to run detection on
     * @param hitThreshold     exclusive minimum score for a detection to be kept; it is unboxed
     *                         while filtering, so a {@code null} value fails once filtering runs
     */
    public TensorUtils(SavedModelBundle savedModelBundle, GraphInputModel graphModel,
            BufferedImage bufferedImage, Float hitThreshold) {
        this.model = savedModelBundle;
        this.graphModel = graphModel;
        this.image = bufferedImage;
        this.hitThreshold = hitThreshold;
    }

    /**
     * Runs detection on the image passed to the constructor.
     *
     * <p>Both the input tensor and every output tensor are closed before this method returns,
     * even if extraction or sorting throws.
     *
     * @return detections above the threshold, in the order produced by {@code ImageUtils.sortList}
     */
    public List<DetectedObject> detectFromBufferedImage() {
        try (Tensor input = makeImageTensor(image)) {
            List<Tensor> output = getObjectDetectionOutput(input);
            try {
                // extractTensors copies everything it needs into DetectedObject ints, so the
                // output tensors can be released right after.
                return ImageUtils.sortList(extractTensors(output));
            } finally {
                closeAll(output);
            }
        }
    }

    /** Closes every tensor in {@code tensors}. */
    private static void closeAll(List<Tensor> tensors) {
        for (Tensor tensor : tensors) {
            tensor.close();
        }
    }

    /**
     * Feeds {@code input} to the model and fetches the detection outputs.
     *
     * <p>The returned list holds, in fetch order: detection boxes, detection classes, detection
     * scores and the number of detections. The caller owns the returned tensors and must close
     * them.
     *
     * @param input image tensor to feed under {@code graphModel.inputTensorName()}
     * @return the fetched output tensors
     */
    public List<Tensor> getObjectDetectionOutput(Tensor input) {
        return model.session()
                .runner()
                .feed(graphModel.inputTensorName(), input)
                .fetch(graphModel.outputDetectionBoxesName())
                .fetch(graphModel.outputDetectionClassesName())
                .fetch(graphModel.outputDetectionScoresName())
                .fetch(graphModel.outputNumDetectsName())
                .run();
    }

    /**
     * Converts the four output tensors of {@link #getObjectDetectionOutput(Tensor)} into
     * detections.
     *
     * @param tensorList boxes, classes, scores and detection count, in that order, all
     *                   {@link TFloat32}
     * @return the detections above the threshold, or an empty list if the list does not contain
     *         exactly four tensors
     */
    public List<DetectedObject> extractTensors(List<Tensor> tensorList) {
        if (tensorList.size() != 4) {
            System.out.println("Something went wrong, tensors are missing!");
            return List.of();
        }
        TFloat32 detectionBoxes = (TFloat32) tensorList.get(0);
        TFloat32 detectionClasses = (TFloat32) tensorList.get(1);
        TFloat32 detectionScores = (TFloat32) tensorList.get(2);
        TFloat32 numDetections = (TFloat32) tensorList.get(3);
        // The detection count is stored as a float at index 0.
        int numDetects = (int) numDetections.getFloat(0L);
        return calculateBoundingBox(numDetects, detectionBoxes, detectionScores, detectionClasses);
    }

    /**
     * Filters the first {@code numDetects} detections by score and converts them to pixel space.
     *
     * @return a mutable list of detections, or an immutable empty list when
     *         {@code numDetects <= 0}
     */
    public List<DetectedObject> calculateBoundingBox(int numDetects, TFloat32 detectionBoxes,
            TFloat32 detectionScores, TFloat32 detectionClasses) {
        if (numDetects <= 0) {
            return List.of();
        }
        List<RawDetectedObject> raw =
                getBoxArray(numDetects, detectionBoxes, detectionScores, detectionClasses);
        ArrayList<DetectedObject> boxArray = new ArrayList<>(raw.size());
        for (RawDetectedObject rawDetectedObject : raw) {
            boxArray.add(getDetectedObject(rawDetectedObject));
        }
        return boxArray;
    }

    /**
     * Scales a raw detection box to this image's pixel coordinates.
     *
     * <p>Box indices are read as 0 = minY, 1 = minX, 2 = maxY, 3 = maxX. Each is multiplied by
     * the image height or width, so the model is presumably emitting fractions of the image
     * size (confirm against the model's signature). The score is returned as a rounded
     * percentage.
     */
    public DetectedObject getDetectedObject(RawDetectedObject rawDetectedObject) {
        FloatNdArray box = rawDetectedObject.detectionBox();
        float width = (float) image.getWidth();
        float height = (float) image.getHeight();
        int minY = Math.round(box.getFloat(0L) * height);
        int minX = Math.round(box.getFloat(1L) * width);
        int maxY = Math.round(box.getFloat(2L) * height);
        int maxX = Math.round(box.getFloat(3L) * width);
        return new DetectedObject(minX, minY, maxX, maxY,
                (int) rawDetectedObject.detectionClass(),
                Math.round(rawDetectedObject.detectionScore() * 100.0f));
    }

    /**
     * Collects the first {@code numDetects} detections whose score is strictly greater than the
     * hit threshold.
     *
     * <p>Every tensor is indexed as {@code [0, n]}, that is batch 0, detection {@code n}. Each
     * box is taken with {@code get(0, n)}. NOTE(review): that is presumably a view into the boxes
     * tensor, so it should not be used after that tensor is closed; confirm with the ndarray API.
     */
    public ArrayList<RawDetectedObject> getBoxArray(int numDetects, TFloat32 detectionBoxes,
            TFloat32 detectionScores, TFloat32 detectionClasses) {
        ArrayList<RawDetectedObject> boxArray = new ArrayList<>();
        for (int n = 0; n < numDetects; n++) {
            float detectionScore = detectionScores.getFloat(0L, n);
            // Only build the raw object for detections that are actually kept.
            if (detectionScore > hitThreshold) {
                boxArray.add(new RawDetectedObject(
                        detectionBoxes.get(0L, n), detectionScore, detectionClasses.getFloat(0L, n)));
            }
        }
        return boxArray;
    }

    /**
     * Builds a {@code [1, height, width, 3]} uint8 RGB tensor from {@code img}.
     *
     * <p>Images that are not already {@code TYPE_3BYTE_BGR} are first redrawn into that layout.
     * The pixel bytes come from {@link BufferedImage#getData()}, which returns a copy, so the
     * in-place BGR-to-RGB swap never modifies the caller's image.
     */
    public Tensor makeImageTensor(BufferedImage img) {
        BufferedImage bgr = img.getType() == BufferedImage.TYPE_3BYTE_BGR
                ? img
                : convertBufferedImage(img);
        byte[] data = ((DataBufferByte) bgr.getData().getDataBuffer()).getData();
        bgr2rgb(data);
        long[] shape = new long[] {BATCH_SIZE, bgr.getHeight(), bgr.getWidth(), CHANNELS};
        return Tensor.of(TUint8.class, Shape.of(shape), DataBuffers.of(data));
    }

    /**
     * Redraws {@code original} into a new {@code TYPE_3BYTE_BGR} image of the same size.
     *
     * @return a new image; {@code original} is not modified
     */
    public static BufferedImage convertBufferedImage(BufferedImage original) {
        BufferedImage newRGB = new BufferedImage(
                original.getWidth(), original.getHeight(), BufferedImage.TYPE_3BYTE_BGR);
        Graphics2D graphics = newRGB.createGraphics();
        try {
            graphics.drawImage(original, 0, 0, original.getWidth(), original.getHeight(), null);
        } finally {
            // Release the graphics context's resources.
            graphics.dispose();
        }
        return newRGB;
    }

    /**
     * Swaps the first and third byte of every complete 3-byte pixel in place, converting
     * BGR to RGB (or the reverse).
     *
     * <p>A trailing partial pixel (when {@code data.length % 3 != 0}) is left untouched.
     */
    public void bgr2rgb(byte[] data) {
        for (int i = 0; i + 2 < data.length; i += 3) {
            byte tmp = data[i];
            data[i] = data[i + 2];
            data[i + 2] = tmp;
        }
    }
}

