/*
 * Decompiled with CFR 0.152.
 */
package de.cem.tensorflow.utility;

import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GPUOptions;

public class ModelLoader {
    private final Map<String, SavedModelBundle> savedModelBundleList = new HashMap<String, SavedModelBundle>();
    private final String modelPath;
    private final boolean useGPU;

    public ModelLoader(String modelPath, boolean useGPU) {
        this.modelPath = modelPath;
        this.useGPU = useGPU;
    }

    public SavedModelBundle getBundle(String modelName, String modelVersion) {
        if (!this.isBundleAdded(modelName, modelVersion)) {
            this.savedModelBundleList.put(this.getModelName(modelName, modelVersion), this.loadBundle(this.getModelPath(modelName, modelVersion)));
        }
        return this.savedModelBundleList.get(this.getModelName(modelName, modelVersion));
    }

    public boolean isBundleAdded(String modelName, String modelVersion) {
        return this.savedModelBundleList.containsKey(this.getModelName(modelName, modelVersion));
    }

    public SavedModelBundle loadBundle(String modelPath) {
        return SavedModelBundle.loader((String)modelPath).withTags(new String[]{"serve"}).withConfigProto(this.getConfigProto()).load();
    }

    public ConfigProto getConfigProto() {
        if (!this.useGPU) {
            return ConfigProto.newBuilder().setAllowSoftPlacement(true).putDeviceCount("GPU", 0).build();
        }
        return ConfigProto.newBuilder().setAllowSoftPlacement(true).setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true).build()).build();
    }

    public String getModelName(String modelVersion, String modelName) {
        return modelName + "_" + modelVersion;
    }

    public String getModelPath(String modelName, String modelVersion) {
        return this.modelPath + "\\" + modelName + "_" + modelVersion + "\\";
    }
}

