/*
 * Decompiled with CFR 0.152.
 */
package com.feedzai.openml.python;

import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.AbstractValueSchema;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.model.ClassificationMLModel;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.feedzai.openml.python.jep.instance.JepInstance;
import com.feedzai.openml.util.data.ClassificationDatasetSchemaUtil;
import com.feedzai.openml.util.data.encoding.EncodingHelper;
import com.google.common.collect.ImmutableList;
import java.io.Serializable;
import java.nio.file.Path;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.IntStream;
import jep.JepException;
import jep.NDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ClassificationPythonModel
implements ClassificationMLModel {
    private static final Logger logger = LoggerFactory.getLogger(ClassificationPythonModel.class);
    public static final String DEFAULT_CLASSIFY_FUNCTION_NAME = "classify";
    public static final String DEFAULT_GETCLASSDISTRIBUTION_FUNCTION_NAME = "getClassDistribution";
    private final Function<Serializable, Integer> classToIndexConverter;
    private final JepInstance jepInstance;
    private final int[] predictiveFieldIndexes;
    private final DatasetSchema schema;
    private final String id;
    private final String classifyFunctionName;
    private final String getClassDistributionFunctionName;

    public ClassificationPythonModel(JepInstance jepInstance, DatasetSchema schema, String id, String classifyFunctionName, String getClassDistributionFunctionName) {
        int targetIndex = (Integer)schema.getTargetIndex().orElseThrow(() -> new IllegalArgumentException("Python classification models do not support datasets without schema."));
        this.jepInstance = jepInstance;
        this.schema = schema;
        this.predictiveFieldIndexes = IntStream.range(0, schema.getFieldSchemas().size()).filter(index -> index != targetIndex).toArray();
        this.id = id;
        this.classifyFunctionName = classifyFunctionName;
        this.getClassDistributionFunctionName = getClassDistributionFunctionName;
        this.classToIndexConverter = this.getClassToIndexConverter(schema);
    }

    public ClassificationPythonModel(JepInstance jepInstance, DatasetSchema schema, String id) {
        this(jepInstance, schema, id, DEFAULT_CLASSIFY_FUNCTION_NAME, DEFAULT_GETCLASSDISTRIBUTION_FUNCTION_NAME);
    }

    public boolean save(Path dir, String name) {
        return false;
    }

    public DatasetSchema getSchema() {
        return this.schema;
    }

    public double[] getClassDistribution(Instance instance) {
        NDArray result = (NDArray)this.invokeFunction(instance, this.getClassDistributionFunctionName, "numpy.array(%s)");
        Object data = result.getData();
        if (data instanceof float[]) {
            float[] x = (float[])data;
            return IntStream.range(0, x.length).mapToDouble(i -> x[i]).toArray();
        }
        return (double[])data;
    }

    public int classify(Instance instance) {
        int asNotNullable;
        String classValue = (String)this.invokeFunction(instance, this.classifyFunctionName, "str(%s[0])");
        try {
            asNotNullable = this.classToIndexConverter.apply((Serializable)((Object)classValue));
        }
        catch (NullPointerException e) {
            AbstractValueSchema targetVarSchema = this.schema.getTargetFieldSchema().map(FieldSchema::getValueSchema).get();
            Function<CategoricalValueSchema, String> block = targetSchema -> String.format("Unexpected class provided by model: %s. Expected values: %s", classValue, targetSchema.getNominalValues());
            String msg = (String)ClassificationDatasetSchemaUtil.withCategoricalValueSchema((AbstractValueSchema)targetVarSchema, block).orElseThrow(() -> new RuntimeException("The target variable is not a categorical value: " + targetVarSchema));
            logger.error(msg, (Throwable)e);
            throw e;
        }
        return asNotNullable;
    }

    public void validate(JepInstance jepInstance, String id) throws ModelLoadingException {
        try {
            jepInstance.submitEvaluation(jep -> {
                ImmutableList functionNames = ImmutableList.of((Object)this.classifyFunctionName, (Object)this.getClassDistributionFunctionName);
                for (String functionName : functionNames) {
                    if (((Boolean)jep.getValue(String.format("callable(getattr(%s, \"%s\", None))", id, functionName))).booleanValue()) continue;
                    throw new JepException(String.format("Model does not implement %s function.", functionName));
                }
                return null;
            }).get();
        }
        catch (InterruptedException | ExecutionException e) {
            logger.error(e.getMessage(), (Throwable)e);
            throw new ModelLoadingException(e.getMessage(), (Throwable)e);
        }
    }

    public void close() {
        this.jepInstance.stop();
    }

    private Function<Serializable, Integer> getClassToIndexConverter(DatasetSchema schema) {
        AbstractValueSchema targetVariableSchema = schema.getTargetFieldSchema().map(FieldSchema::getValueSchema).get();
        if (!(targetVariableSchema instanceof CategoricalValueSchema)) {
            logger.error("Provided schema's target field is not categorical: {}", (Object)schema);
            throw new IllegalArgumentException("Classification models require Categorical target fields. Got " + targetVariableSchema);
        }
        return EncodingHelper.classToIndexConverter((CategoricalValueSchema)((CategoricalValueSchema)targetVariableSchema));
    }

    private <T> T invokeFunction(Instance instance, String classificationFunction, String pythonResultWrapping) {
        int numberPredictiveFields = this.predictiveFieldIndexes.length;
        double[] data = new double[numberPredictiveFields];
        for (int index = 0; index < numberPredictiveFields; ++index) {
            data[index] = instance.getValue(this.predictiveFieldIndexes[index]);
        }
        NDArray encodedInstance = new NDArray((Object)data, new int[]{1, numberPredictiveFields});
        try {
            return this.jepInstance.submitEvaluation(jep -> {
                String defineFunction = String.format("classification_function = %s.%s", this.id, classificationFunction);
                jep.eval(defineFunction);
                jep.set("encodedInstance", (Object)encodedInstance);
                String callFunction = String.format(pythonResultWrapping, "classification_function(encodedInstance)");
                return jep.getValue(callFunction);
            }).get();
        }
        catch (Exception e) {
            logger.warn("Error during instance evaluation.");
            throw new RuntimeException("Error during instance evaluation.", e);
        }
    }
}

