/*
 * Decompiled with CFR 0.152.
 */
package org.numenta.nupic.encoders;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.numenta.nupic.encoders.DecodeResult;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.encoders.Encoding;
import org.numenta.nupic.encoders.RangeList;
import org.numenta.nupic.encoders.ScalarEncoder;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CategoryEncoder
extends Encoder<String> {
    private static final long serialVersionUID = 1L;
    private static final Logger LOG = LoggerFactory.getLogger(CategoryEncoder.class);
    protected int ncategories;
    protected TObjectIntMap<String> categoryToIndex = new TObjectIntHashMap();
    protected TIntObjectMap<String> indexToCategory = new TIntObjectHashMap();
    protected List<String> categoryList;
    protected int width;
    private ScalarEncoder scalarEncoder;

    private CategoryEncoder() {
    }

    public static Encoder.Builder<Builder, CategoryEncoder> builder() {
        return new Builder();
    }

    public void init() {
        block5: {
            this.ncategories = this.categoryList == null ? 0 : this.categoryList.size() + 1;
            this.minVal = 0.0;
            this.maxVal = this.ncategories - 1;
            try {
                this.scalarEncoder = ((ScalarEncoder.Builder)((ScalarEncoder.Builder)((ScalarEncoder.Builder)((ScalarEncoder.Builder)((ScalarEncoder.Builder)((ScalarEncoder.Builder)ScalarEncoder.builder().n(this.n).w(this.w)).radius(this.radius)).minVal(this.minVal)).maxVal(this.maxVal)).periodic(this.periodic)).forced(this.forced)).build();
            }
            catch (Exception e) {
                String msg = null;
                int idx = -1;
                msg = e.getMessage();
                idx = msg.indexOf("ScalarEncoder");
                if (idx == -1) break block5;
                msg = msg.substring(0, idx).concat("CategoryEncoder");
                throw new IllegalStateException(msg);
            }
        }
        this.indexToCategory.put(0, (Object)"<UNKNOWN>");
        if (this.categoryList != null && !this.categoryList.isEmpty()) {
            int len = this.categoryList.size();
            for (int i = 0; i < len; ++i) {
                this.categoryToIndex.put((Object)this.categoryList.get(i), i + 1);
                this.indexToCategory.put(i + 1, (Object)this.categoryList.get(i));
            }
        }
        this.width = this.n = this.w * this.ncategories;
        this.scalarEncoder.n = this.n;
        if (this.getWidth() != this.width) {
            throw new IllegalStateException("Width != w (num bits to represent output item) * #categories");
        }
        this.description.add(new Tuple(this.name, 0));
    }

    @Override
    public <T> TDoubleList getScalars(T d) {
        return new TDoubleArrayList(new double[]{this.categoryToIndex.get(d)});
    }

    @Override
    public int[] getBucketIndices(String input) {
        if (input == null) {
            return null;
        }
        return this.scalarEncoder.getBucketIndices(this.categoryToIndex.get((Object)input));
    }

    @Override
    public void encodeIntoArray(String input, int[] output) {
        String val = null;
        double value = 0.0;
        if (input == null) {
            val = "<missing>";
        } else {
            value = this.categoryToIndex.get((Object)input);
            value = value == (double)this.categoryToIndex.getNoEntryValue() ? 0.0 : value;
            this.scalarEncoder.encodeIntoArray(value, output);
        }
        LOG.trace("input: {}, val: {}, value: {}, output: {}", new Object[]{input, val, value, Arrays.toString(output)});
    }

    @Override
    public DecodeResult decode(int[] encoded, String parentFieldName) {
        DecodeResult result = this.scalarEncoder.decode(encoded, parentFieldName);
        if (result.getFields().size() == 0) {
            return result;
        }
        if (result.getFields().size() != 1) {
            throw new IllegalStateException("Expecting only one field");
        }
        Map<String, RangeList> fieldRanges = result.getFields();
        ArrayList<MinMax> outRanges = new ArrayList<MinMax>();
        StringBuilder desc = new StringBuilder();
        for (String descripStr : fieldRanges.keySet()) {
            int minV;
            MinMax minMax = fieldRanges.get(descripStr).getRange(0);
            int maxV = (int)Math.round(minMax.max());
            outRanges.add(new MinMax(minV, maxV));
            for (minV = (int)Math.round(minMax.min()); minV <= maxV; ++minV) {
                if (desc.length() > 0) {
                    desc.append(", ");
                }
                desc.append((String)this.indexToCategory.get(minV));
            }
        }
        String fieldName = !parentFieldName.isEmpty() ? String.format("%s.%s", parentFieldName, this.name) : this.name;
        HashMap<String, RangeList> retVal = new HashMap<String, RangeList>();
        retVal.put(fieldName, new RangeList((List<MinMax>)outRanges, desc.toString()));
        return new DecodeResult((Map<String, RangeList>)retVal, Arrays.asList(fieldName));
    }

    @Override
    public TDoubleList closenessScores(TDoubleList expValues, TDoubleList actValues, boolean fractional) {
        double actValue;
        double closeness;
        double expValue = expValues.get(0);
        double d = closeness = expValue == (actValue = actValues.get(0)) ? 1.0 : 0.0;
        if (!fractional) {
            closeness = 1.0 - closeness;
        }
        return new TDoubleArrayList(new double[]{closeness});
    }

    @Override
    public <T> List<T> getBucketValues(Class<T> t) {
        if (this.bucketValues == null) {
            SparseObjectMatrix<int[]> topDownMapping = this.scalarEncoder.getTopDownMapping();
            int numBuckets = topDownMapping.getMaxIndex() + 1;
            this.bucketValues = new ArrayList();
            int i = 0;
            while (i < numBuckets) {
                this.bucketValues.add((String)this.getBucketInfo(new int[]{i++}).get(0).getValue());
            }
        }
        return this.bucketValues;
    }

    @Override
    public List<Encoding> getBucketInfo(int[] buckets) {
        List<Encoding> bucketInfo = this.scalarEncoder.getBucketInfo(buckets);
        int categoryIndex = (int)Math.round((Double)bucketInfo.get(0).getValue());
        String category = (String)this.indexToCategory.get(categoryIndex);
        bucketInfo.set(0, new Encoding((Object)category, categoryIndex, bucketInfo.get(0).getEncoding()));
        return bucketInfo;
    }

    @Override
    public List<Encoding> topDownCompute(int[] encoded) {
        SparseObjectMatrix<int[]> topDownMapping = this.scalarEncoder.getTopDownMapping();
        int category = ArrayUtils.argmax(this.rightVecProd(topDownMapping, encoded));
        return this.getBucketInfo(new int[]{category});
    }

    public List<String> getCategoryList() {
        return this.categoryList;
    }

    public void setCategoryList(List<String> categoryList) {
        this.categoryList = categoryList;
    }

    @Override
    public int getWidth() {
        return this.getN();
    }

    @Override
    public boolean isDelta() {
        return false;
    }

    public static class Builder
    extends Encoder.Builder<Builder, CategoryEncoder> {
        private List<String> categoryList;

        private Builder() {
        }

        @Override
        public CategoryEncoder build() {
            this.encoder = new CategoryEncoder();
            super.build();
            if (this.categoryList == null) {
                throw new IllegalStateException("Category List cannot be null");
            }
            ((CategoryEncoder)this.encoder).setCategoryList(this.categoryList);
            ((CategoryEncoder)this.encoder).init();
            return (CategoryEncoder)this.encoder;
        }

        public Builder categoryList(List<String> categoryList) {
            this.categoryList = categoryList;
            return this;
        }
    }
}

