/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.python.transforms;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Java-side wrapper that applies the Python callable
 * {@code apache_beam.ml.inference.base.RunInference.from_callable} through a
 * {@link PythonExternalTransform}.
 *
 * <p>Instances are created with {@link #of} (output elements are {@link Row}s) or
 * {@link #ofKVs} (output elements are {@code KV<KeyT, Row>}), and then refined with the
 * {@code with*} methods. {@link #withKwarg} and {@link #withExpansionService} return a new
 * instance; {@link #withExtraPackages} mutates the instance it is called on.
 *
 * @param <OutputT> element type of the output collection
 */
public class RunInference<OutputT>
extends PTransform<PCollection<?>, PCollection<OutputT>> {
    private static final Logger LOG = LoggerFactory.getLogger(RunInference.class);

    /** Fully qualified name of the Python callable handed to {@link PythonExternalTransform}. */
    private static final String PYTHON_TRANSFORM = "apache_beam.ml.inference.base.RunInference.from_callable";

    private final String modelLoader;
    private final Schema schema;
    private final Map<String, Object> kwargs;
    private final String expansionService;
    private final @Nullable Coder<?> keyCoder;
    private final List<String> extraPackages;

    /**
     * Creates a transform whose output rows have an {@code example} field and an
     * {@code inference} field of the given types.
     *
     * @param modelLoader source of the Python callable passed as {@code model_handler_provider}
     * @param exampleType field type of the {@code example} field
     * @param inferenceType field type of the {@code inference} field
     * @return a transform producing {@link Row}s
     */
    public static RunInference<Row> of(String modelLoader, Schema.FieldType exampleType, Schema.FieldType inferenceType) {
        return RunInference.of(modelLoader, RunInference.exampleInferenceSchema(exampleType, inferenceType));
    }

    /**
     * Keyed variant of {@link #of(String, Schema.FieldType, Schema.FieldType)}.
     *
     * @param modelLoader source of the Python callable passed as {@code model_handler_provider}
     * @param exampleType field type of the {@code example} field
     * @param inferenceType field type of the {@code inference} field
     * @param keyCoder coder for the keys of the output {@link KV}s
     * @param <KeyT> key type
     * @return a transform producing {@code KV<KeyT, Row>} elements
     */
    public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(String modelLoader, Schema.FieldType exampleType, Schema.FieldType inferenceType, Coder<KeyT> keyCoder) {
        return RunInference.ofKVs(modelLoader, RunInference.exampleInferenceSchema(exampleType, inferenceType), keyCoder);
    }

    /**
     * Creates a transform whose output rows use the given schema.
     *
     * @param modelLoader source of the Python callable passed as {@code model_handler_provider}
     * @param schema schema used to build the output {@link RowCoder}
     * @return a transform producing {@link Row}s
     */
    public static RunInference<Row> of(String modelLoader, Schema schema) {
        return new RunInference<>(modelLoader, schema, ImmutableMap.<String, Object>of(), null, "");
    }

    /**
     * Keyed variant of {@link #of(String, Schema)}.
     *
     * @param modelLoader source of the Python callable passed as {@code model_handler_provider}
     * @param schema schema used to build the coder for the output values
     * @param keyCoder coder for the keys of the output {@link KV}s
     * @param <KeyT> key type
     * @return a transform producing {@code KV<KeyT, Row>} elements
     */
    public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(String modelLoader, Schema schema, Coder<KeyT> keyCoder) {
        return new RunInference<>(modelLoader, schema, ImmutableMap.<String, Object>of(), keyCoder, "");
    }

    /**
     * Returns a copy of this transform with one additional keyword argument for the Python
     * transform. Extra packages already registered on this instance are carried over.
     *
     * <p>The merged map is built with {@link ImmutableMap.Builder}, which rejects a key that is
     * already present.
     *
     * @param key keyword argument name
     * @param arg keyword argument value
     * @return a new transform containing the existing kwargs plus {@code key}
     */
    public RunInference<OutputT> withKwarg(String key, Object arg) {
        Map<String, Object> merged = ImmutableMap.<String, Object>builder().putAll(this.kwargs).put(key, arg).build();
        return new RunInference<>(this.modelLoader, this.schema, merged, this.keyCoder, this.expansionService, this.extraPackages);
    }

    /**
     * Registers extra packages to forward to the {@link PythonExternalTransform}.
     *
     * <p>Unlike the other {@code with*} methods this one mutates the current instance and
     * returns nothing, so it cannot be chained. It may be called at most once with a non-empty
     * list.
     *
     * @param extraPackages package names to forward
     * @throws IllegalArgumentException if extra packages were already registered
     */
    public void withExtraPackages(List<String> extraPackages) {
        if (!this.extraPackages.isEmpty()) {
            throw new IllegalArgumentException("Extra packages were already specified");
        }
        this.extraPackages.addAll(extraPackages);
    }

    /**
     * Returns a copy of this transform that uses the given expansion service. Extra packages
     * already registered on this instance are carried over.
     *
     * @param expansionService expansion service passed to {@link PythonExternalTransform#from};
     *     when non-empty, {@link #expand} does not infer extra packages from the model loader
     * @return a new transform bound to {@code expansionService}
     */
    public RunInference<OutputT> withExpansionService(String expansionService) {
        return new RunInference<>(this.modelLoader, this.schema, this.kwargs, this.keyCoder, expansionService, this.extraPackages);
    }

    /** Creates an instance with no extra packages. */
    private RunInference(String modelLoader, Schema schema, Map<String, Object> kwargs, @Nullable Coder<?> keyCoder, String expansionService) {
        this(modelLoader, schema, kwargs, keyCoder, expansionService, new ArrayList<String>());
    }

    /**
     * Creates an instance that starts with a private copy of {@code extraPackages}, so that
     * instances derived from one another never share the mutable list.
     */
    private RunInference(String modelLoader, Schema schema, Map<String, Object> kwargs, @Nullable Coder<?> keyCoder, String expansionService, List<String> extraPackages) {
        this.modelLoader = modelLoader;
        this.schema = schema;
        this.kwargs = kwargs;
        this.keyCoder = keyCoder;
        this.expansionService = expansionService;
        this.extraPackages = new ArrayList<>(extraPackages);
    }

    /** Builds the two-field ({@code example}, {@code inference}) schema used by the field-type factories. */
    private static Schema exampleInferenceSchema(Schema.FieldType exampleType, Schema.FieldType inferenceType) {
        return Schema.of(Schema.Field.of("example", exampleType), Schema.Field.of("inference", inferenceType));
    }

    /**
     * Guesses package names from the model loader string: a loader mentioning {@code sklearn}
     * yields {@code scikit-learn} and {@code pandas}; otherwise one mentioning {@code pytorch}
     * yields {@code torch}. The match is case-insensitive.
     *
     * @return the inferred packages, possibly empty
     */
    private List<String> inferExtraPackagesFromModelHandler() {
        List<String> inferred = new ArrayList<>();
        // Locale.ROOT keeps the match independent of the JVM's default locale.
        String loader = this.modelLoader.toLowerCase(Locale.ROOT);
        if (loader.contains("sklearn")) {
            inferred.add("scikit-learn");
            inferred.add("pandas");
        } else if (loader.contains("pytorch")) {
            inferred.add("torch");
        }
        if (!inferred.isEmpty()) {
            LOG.info("Automatically inferred dependencies {} from the provided model handler.", inferred);
        }
        return inferred;
    }

    /**
     * Applies the Python {@code RunInference} transform to {@code input}.
     *
     * <p>The output coder is a {@link RowCoder} for the configured schema, wrapped in a
     * {@link KvCoder} when a key coder was supplied. When no expansion service and no extra
     * packages were configured, packages are inferred from the model loader string. The
     * inferred list is used for this expansion only and is not stored on the instance.
     *
     * @param input the collection to run inference on
     * @return the output of the external Python transform
     */
    @Override
    public PCollection<OutputT> expand(PCollection<?> input) {
        Coder<?> outputCoder;
        if (this.keyCoder == null) {
            outputCoder = RowCoder.of(this.schema);
        } else {
            outputCoder = KvCoder.of(this.keyCoder, RowCoder.of(this.schema));
        }

        // Resolve the effective package list locally; expansion must not alter this transform.
        List<String> packages = this.extraPackages;
        if (this.expansionService.isEmpty() && packages.isEmpty()) {
            packages = this.inferExtraPackagesFromModelHandler();
        }

        // The external transform is declared as producing Rows; the actual element type is
        // whatever outputCoder describes, which the factory methods tie to OutputT.
        @SuppressWarnings("unchecked")
        PCollection<OutputT> output = (PCollection<OutputT>) input.apply(
            PythonExternalTransform.<PCollection<?>, PCollection<Row>>from(PYTHON_TRANSFORM, this.expansionService)
                .withKwarg("model_handler_provider", PythonCallableSource.of(this.modelLoader))
                .withOutputCoder(outputCoder)
                .withExtraPackages(packages)
                .withKwargs(this.kwargs));
        return output;
    }
}

