/*
 * Decompiled with CFR 0.152.
 */
package org.numenta.nupic.encoders;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.encoders.DecodeResult;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.encoders.Encoding;
import org.numenta.nupic.encoders.RangeList;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SDRCategoryEncoder
extends Encoder<String> {
    private static final long serialVersionUID = 1L;
    private static final Logger LOG = LoggerFactory.getLogger(SDRCategoryEncoder.class);
    // Source of the random bit positions used when a new category SDR is generated;
    // init() seeds it unless the encoder seed is -1.
    private Random random;
    // Number of shared ON bits that decode() requires a stored SDR to exceed; computed in init().
    private int thresholdOverlap;
    // Insertion-ordered map from category name to its SDR; a category's position in
    // this map is used as its index.
    private final SDRByCategoryMap sdrByCategory = new SDRByCategoryMap();

    /**
     * Creates a new {@link Builder} for configuring and constructing an encoder.
     *
     * @return a fresh builder instance
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Private constructor: instances are created by {@link Builder#build()}, which
     * calls {@code init(...)} on the new object immediately afterwards.
     */
    private SDRCategoryEncoder() {
    }

    /**
     * Configures this encoder: stores the dimensions, prepares the random generator,
     * validates the dimensions (unless {@code forced}), derives the decode overlap
     * threshold and registers the initial categories.
     *
     * @param n            total number of bits in the output
     * @param w            number of ON bits in each category SDR
     * @param categoryList categories to register up front; when {@code null} or empty,
     *                     learning stays enabled so categories are added as they are seen
     * @param name         encoder name, also recorded in the description
     * @param encoderSeed  seed for the random generator; {@code -1} leaves it unseeded
     * @param forced       when {@code true}, skips the sanity checks on {@code n} and {@code w}
     * @throws IllegalArgumentException if the checks fail, or if a category is listed twice
     */
    private void init(int n, int w, List<String> categoryList, String name, int encoderSeed, boolean forced) {
        this.n = n;
        this.w = w;
        this.encLearningEnabled = true;
        // new Random(seed) is documented as equivalent to new Random() followed by setSeed(seed).
        this.random = encoderSeed == -1 ? new Random() : new Random(encoderSeed);

        final boolean validate = !forced;
        if (validate && n / w < 2) {
            throw new IllegalArgumentException(String.format("Number of ON bits in SDR (%d) must be much smaller than the output width (%d)", w, n));
        }
        if (validate && w < 21) {
            throw new IllegalArgumentException(String.format("Number of bits in the SDR (%d) must be greater than 2, and should be >= 21, pass forced=True to init() to override this check", w));
        }

        // Threshold is halfway between the expected overlap and a full match,
        // but never lower than w - 3.
        double density = (double)this.w / (double)this.n;
        double expectedOverlap = (double)w * density;
        int midpoint = (int)(expectedOverlap + (double)this.w) / 2;
        this.thresholdOverlap = Math.max(midpoint, this.w - 3);

        this.description.add(new Tuple(name, 0));
        this.name = name;

        // The unknown category is always registered first, so it occupies index 0.
        this.addCategory("<UNKNOWN>");
        final boolean learnOnTheFly = categoryList == null || categoryList.size() == 0;
        this.setLearningEnabled(learnOnTheFly);
        if (!learnOnTheFly) {
            for (String presetCategory : categoryList) {
                this.addCategory(presetCategory);
            }
        }
    }

    /**
     * Returns the width of this encoder's output, which is the value of {@code getN()}.
     */
    @Override
    public int getWidth() {
        return this.getN();
    }

    /**
     * Always returns {@code false} for this encoder.
     */
    @Override
    public boolean isDelta() {
        return false;
    }

    /**
     * Writes the SDR for {@code input} into {@code output}.
     *
     * A {@code null} or empty input zero-fills {@code output}. Any other input is
     * resolved to a bucket index via {@code getBucketIndices} and the SDR stored at
     * that index is copied to the start of {@code output}.
     *
     * @param input  category to encode; may be {@code null} or empty
     * @param output destination array, modified in place
     */
    @Override
    public void encodeIntoArray(String input, int[] output) {
        int index;
        if (input == null || input.isEmpty()) {
            Arrays.fill(output, 0);
            index = 0;
        } else {
            index = this.getBucketIndices(input)[0];
            int[] categoryEncoding = this.sdrByCategory.getSdr(index);
            System.arraycopy(categoryEncoding, 0, output, 0, categoryEncoding.length);
        }
        // Guarded: building these messages stringifies the whole output and runs a
        // full decode(), which must not happen on every encode when TRACE is off.
        if (LOG.isTraceEnabled()) {
            LOG.trace("input:{}, index:{}, output:{}", input, index, ArrayUtils.intArrayToString(output));
            LOG.trace("decoded:{}", this.decodedToStr(this.decode(output, "")));
        }
    }

    /**
     * Returns the field types reported for this encoder's decoder output:
     * {@code LIST} and {@code STRING}.
     *
     * @return a new mutable set containing both types
     */
    @Override
    public Set<FieldMetaType> getDecoderOutputFieldTypes() {
        Set<FieldMetaType> outputTypes = new HashSet<FieldMetaType>();
        outputTypes.add(FieldMetaType.LIST);
        outputTypes.add(FieldMetaType.STRING);
        return outputTypes;
    }

    /**
     * Returns the bucket index for {@code input} as a single-element array. The index
     * is the first scalar produced by {@code getScalars}, truncated to {@code int}.
     *
     * @param input category to look up
     * @return one-element array holding the category's index
     */
    @Override
    public int[] getBucketIndices(String input) {
        TDoubleList scalars = this.getScalars(input);
        int bucket = (int)scalars.get(0);
        return new int[]{bucket};
    }

    /**
     * Returns a single-element list holding the index of the given category.
     *
     * <ul>
     *   <li>{@code null} or empty input yields index 0.</li>
     *   <li>A known category yields its position in the category map.</li>
     *   <li>An unknown category is registered and given the next index when learning
     *       is enabled; otherwise it yields index 0.</li>
     * </ul>
     *
     * @param input the category; cast to {@code String}
     * @return list containing exactly one value, the category index
     */
    @Override
    public <S> TDoubleList getScalars(S input) {
        String category = (String)input;
        TDoubleArrayList scalars = new TDoubleArrayList();
        if (category == null || category.isEmpty()) {
            scalars.add(0.0);
            return scalars;
        }

        int bucket;
        if (this.sdrByCategory.containsKey(input)) {
            bucket = this.sdrByCategory.getIndexByCategory(category);
        } else if (this.isEncoderLearningEnabled()) {
            // The new category is appended, so its index is the size before insertion.
            bucket = this.sdrByCategory.size();
            this.addCategory(category);
        } else {
            bucket = 0;
        }
        scalars.add((double)bucket);
        return scalars;
    }

    /**
     * Convenience overload that delegates to {@code decode(int[], String)} with a
     * {@code null} parent field name.
     *
     * @param encoded the encoded bit array
     * @return the result of the two-argument decode
     */
    public DecodeResult decode(int[] encoded) {
        return this.decode(encoded, null);
    }

    /**
     * Decodes a bit array into the categories whose stored SDRs share more than
     * {@code thresholdOverlap} ON bits with it.
     *
     * The result holds one field. Its description is the matching category names
     * joined by single spaces, and its ranges contain one {@code [index, index]}
     * entry per matching category.
     *
     * @param encoded         bit array to decode; every element is asserted to be {@code <= 1}
     * @param parentFieldName optional prefix; when non-empty the field is named
     *                        {@code parentFieldName.name}, otherwise just this encoder's name
     * @return decode result containing the single field described above
     */
    @Override
    public DecodeResult decode(int[] encoded, String parentFieldName) {
        assert (ArrayUtils.all(encoded, new Condition.Adapter<Integer>(){

            @Override
            public boolean eval(int i) {
                return i <= 1;
            }
        }));

        // Count, per category, the positions that are ON in both the stored SDR and the input.
        // Iterating values() directly avoids an index-based map walk for every category.
        int[] overlap = new int[this.sdrByCategory.size()];
        int categoryIndex = 0;
        for (int[] sdr : this.sdrByCategory.values()) {
            for (int j = 0; j < sdr.length; ++j) {
                if (encoded[j] == 1 && sdr[j] == encoded[j]) {
                    overlap[categoryIndex] = overlap[categoryIndex] + 1;
                }
            }
            ++categoryIndex;
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Overlaps for decoding:");
            int inx = 0;
            for (String category : this.sdrByCategory.keySet()) {
                LOG.trace(overlap[inx] + " " + category);
                ++inx;
            }
        }

        int[] matchingCategories = ArrayUtils.where(overlap, new Condition.Adapter<Integer>(){

            @Override
            public boolean eval(int overlaps) {
                return overlaps > SDRCategoryEncoder.this.thresholdOverlap;
            }
        });

        StringBuilder resultString = new StringBuilder();
        List<MinMax> resultRanges = new ArrayList<MinMax>();
        for (int index : matchingCategories) {
            if (resultString.length() != 0) {
                resultString.append(" ");
            }
            resultString.append(this.sdrByCategory.getCategory(index));
            resultRanges.add(new MinMax(index, index));
        }

        String fieldName = parentFieldName == null || parentFieldName.isEmpty() ? this.getName() : String.format("%s.%s", parentFieldName, this.getName());
        Map<String, RangeList> fieldsDict = new HashMap<String, RangeList>();
        fieldsDict.put(fieldName, new RangeList(resultRanges, resultString.toString()));
        return new DecodeResult(fieldsDict, Arrays.asList(fieldName));
    }

    /**
     * Finds the category whose row in the top-down mapping scores highest against
     * {@code encoded} and returns its encoding details.
     *
     * @param encoded the encoded bit array
     * @return a one-element list for the best-matching category, or an empty list
     *         when no categories are registered
     */
    @Override
    public List<Encoding> topDownCompute(int[] encoded) {
        if (!this.sdrByCategory.isEmpty()) {
            int bestCategory = ArrayUtils.argmax(this.rightVecProd(this.getTopDownMapping(), encoded));
            return this.getEncoderResultsByIndex(this.getTopDownMapping(), bestCategory);
        }
        return new ArrayList<Encoding>();
    }

    /**
     * Returns the encoding details for the category at index {@code buckets[0]}.
     *
     * @param buckets bucket indices; only the first element is used
     * @return a one-element list for that category, or an empty list when no
     *         categories are registered
     */
    @Override
    public List<Encoding> getBucketInfo(int[] buckets) {
        if (this.sdrByCategory.isEmpty()) {
            return new ArrayList<Encoding>();
        }
        final int bucket = buckets[0];
        SparseObjectMatrix<int[]> mapping = this.getTopDownMapping();
        return this.getEncoderResultsByIndex(mapping, bucket);
    }

    /**
     * Returns the mapping from category index to that category's encoded bit array,
     * building and caching it on first use. The cache is cleared whenever a category
     * is added.
     *
     * @return matrix whose entry {@code i} is a copy of the encoding of category {@code i}
     */
    public SparseObjectMatrix<int[]> getTopDownMapping() {
        if (this.topDownMapping == null) {
            this.topDownMapping = new SparseObjectMatrix<int[]>(new int[]{this.sdrByCategory.size()});
            int[] outputSpace = new int[this.getN()];
            int inx = 0;
            for (String category : this.sdrByCategory.keySet()) {
                this.encodeIntoArray(category, outputSpace);
                // outputSpace is reused on the next iteration, so store a copy.
                this.topDownMapping.set(inx, Arrays.copyOf(outputSpace, outputSpace.length));
                ++inx;
            }
        }
        return this.topDownMapping;
    }

    /**
     * Returns the registered category names in index order.
     *
     * @param returnType requested element type; not consulted, the elements are
     *                   always {@code String} category names
     * @return a new list containing every category name
     */
    @Override
    @SuppressWarnings("unchecked")
    public <S> List<S> getBucketValues(Class<S> returnType) {
        // Unchecked by necessity: the signature promises List<S> but the keys are Strings.
        return (List<S>)new ArrayList<String>(this.sdrByCategory.keySet());
    }

    /**
     * Returns a read-only view of the stored SDRs in category order. The collection
     * cannot be modified through this view, but the arrays it contains are not copied.
     *
     * @return unmodifiable collection of the category SDRs
     */
    public Collection<int[]> getSDRs() {
        Collection<int[]> storedSdrs = this.sdrByCategory.values();
        return Collections.unmodifiableCollection(storedSdrs);
    }

    /**
     * Builds a one-element result list describing the category at {@code categoryIndex}.
     *
     * @param topDownMapping matrix from which the category's encoding is read
     * @param categoryIndex  index of the category
     * @return mutable list holding a single {@code Encoding} of the category name,
     *         its index and its encoding
     */
    private List<Encoding> getEncoderResultsByIndex(SparseObjectMatrix<int[]> topDownMapping, int categoryIndex) {
        String categoryName = this.sdrByCategory.getCategory(categoryIndex);
        int[] bits = topDownMapping.getObject(categoryIndex);
        Encoding details = new Encoding((Object)categoryName, categoryIndex, bits);

        ArrayList<Encoding> wrapped = new ArrayList<Encoding>();
        wrapped.add(details);
        return wrapped;
    }

    /**
     * Registers a new category with a freshly generated SDR and clears the cached
     * top-down mapping so it is rebuilt on next use.
     *
     * @param category name of the category to add
     * @throws IllegalArgumentException if the category is already registered
     */
    private void addCategory(String category) {
        boolean isNew = !this.sdrByCategory.containsKey(category);
        if (isNew) {
            int[] representation = this.newRep();
            this.sdrByCategory.put(category, representation);
            this.topDownMapping = null;
            return;
        }
        throw new IllegalArgumentException(String.format("Attempt to add encoder category '%s' that already exists", category));
    }

    /**
     * Draws {@code sampleLength} distinct integers uniformly from
     * {@code [0, populationSize)} and returns them in ascending order.
     *
     * @param populationSize exclusive upper bound of the values to draw from
     * @param sampleLength   number of distinct values required
     * @return sorted array of distinct values; empty when {@code sampleLength <= 0}
     * @throws IllegalArgumentException if more distinct values are requested than the
     *                                  population contains
     */
    private int[] getSortedSample(int populationSize, int sampleLength) {
        // Without this guard the rejection loop below can never collect enough
        // distinct values and would spin forever.
        if (sampleLength > 0 && sampleLength > populationSize) {
            throw new IllegalArgumentException(String.format("Cannot draw %d distinct values from a population of size %d", sampleLength, populationSize));
        }
        TIntHashSet resultSet = new TIntHashSet();
        while (resultSet.size() < sampleLength) {
            // Duplicates are ignored by the set, so the loop simply draws again.
            resultSet.add(this.random.nextInt(populationSize));
        }
        int[] result = resultSet.toArray();
        Arrays.sort(result);
        return result;
    }

    /**
     * Generates a random SDR of length {@code n} with {@code w} ON bits that differs
     * from every SDR already stored, retrying up to 1000 times.
     *
     * @return the new SDR as an array of 0/1 values
     * @throws RuntimeException if no unique pattern is found within the attempt limit
     */
    private int[] newRep() {
        final int attemptLimit = 1000;
        // Allocated before the first random draw; each attempt replaces it with a fresh array.
        int[] candidate = new int[this.n];
        boolean unique = true;
        int attempt = 0;
        while (attempt < attemptLimit) {
            int[] onPositions = this.getSortedSample(this.n, this.w);
            candidate = new int[this.n];
            for (int position : onPositions) {
                candidate[position] = 1;
            }

            unique = true;
            for (int[] taken : this.sdrByCategory.values()) {
                if (Arrays.equals(candidate, taken)) {
                    unique = false;
                    break;
                }
            }
            if (unique) {
                break;
            }
            ++attempt;
        }
        if (unique) {
            return candidate;
        }
        throw new RuntimeException(String.format("Error, could not find unique pattern %d after %d attempts", this.sdrByCategory.size(), attemptLimit));
    }

    /**
     * Builder for {@link SDRCategoryEncoder}. In addition to the inherited settings it
     * accepts a category list and an encoder seed. The scalar-oriented settings
     * (radius, resolution, periodic, clipInput, minVal, maxVal) are rejected.
     */
    public static final class Builder
    extends Encoder.Builder<Builder, SDRCategoryEncoder> {
        private static final String UNSUPPORTED_MESSAGE = "Not supported for this SDRCategoryEncoder";

        // Categories to register up front; defaults to an empty list.
        private List<String> categoryList = new ArrayList<String>();
        // Seed handed to the encoder; defaults to 1.
        private int encoderSeed = 1;

        /**
         * Validates the configuration and creates the encoder.
         *
         * @return the initialised encoder
         * @throws IllegalStateException if {@code n} or {@code w} is 0, or the category list is {@code null}
         */
        @Override
        public SDRCategoryEncoder build() {
            Builder.checkConfigured(this.n != 0, "\"N\" should be set");
            Builder.checkConfigured(this.w != 0, "\"W\" should be set");
            Builder.checkConfigured(this.categoryList != null, "Category List cannot be null");

            SDRCategoryEncoder encoder = new SDRCategoryEncoder();
            encoder.init(this.n, this.w, this.categoryList, this.name, this.encoderSeed, this.forced);
            return encoder;
        }

        /**
         * Sets the categories to register when the encoder is built.
         *
         * @param categoryList the category names; stored by reference
         * @return this builder
         */
        public Builder categoryList(List<String> categoryList) {
            this.categoryList = categoryList;
            return this;
        }

        /**
         * Sets the seed passed to the encoder when it is built.
         *
         * @param encoderSeed the seed value
         * @return this builder
         */
        public Builder encoderSeed(int encoderSeed) {
            this.encoderSeed = encoderSeed;
            return this;
        }

        /** Not supported; always throws {@link IllegalArgumentException}. */
        @Override
        public Builder radius(double radius) {
            throw Builder.unsupportedSetting();
        }

        /** Not supported; always throws {@link IllegalArgumentException}. */
        @Override
        public Builder resolution(double resolution) {
            throw Builder.unsupportedSetting();
        }

        /** Not supported; always throws {@link IllegalArgumentException}. */
        @Override
        public Builder periodic(boolean periodic) {
            throw Builder.unsupportedSetting();
        }

        /** Not supported; always throws {@link IllegalArgumentException}. */
        @Override
        public Builder clipInput(boolean clipInput) {
            throw Builder.unsupportedSetting();
        }

        /** Not supported; always throws {@link IllegalArgumentException}. */
        @Override
        public Builder maxVal(double maxVal) {
            throw Builder.unsupportedSetting();
        }

        /** Not supported; always throws {@link IllegalArgumentException}. */
        @Override
        public Builder minVal(double minVal) {
            throw Builder.unsupportedSetting();
        }

        // Throws IllegalStateException with the given message when the condition does not hold.
        private static void checkConfigured(boolean satisfied, String message) {
            if (!satisfied) {
                throw new IllegalStateException(message);
            }
        }

        // Creates the exception thrown by every setting this encoder rejects.
        private static IllegalArgumentException unsupportedSetting() {
            return new IllegalArgumentException(UNSUPPORTED_MESSAGE);
        }
    }

    private static final class SDRByCategoryMap
    extends LinkedHashMap<String, int[]> {
        private SDRByCategoryMap() {
        }

        public int[] getSdr(int index) {
            Map.Entry<String, int[]> entry = this.getEntry(index);
            if (entry == null) {
                return null;
            }
            return entry.getValue();
        }

        public String getCategory(int index) {
            Map.Entry<String, int[]> entry = this.getEntry(index);
            if (entry == null) {
                return null;
            }
            return entry.getKey();
        }

        public int getIndexByCategory(String category) {
            Set categories = this.keySet();
            int inx = 0;
            for (String s : categories) {
                if (s.equals(category)) {
                    return inx;
                }
                ++inx;
            }
            return 0;
        }

        private Map.Entry<String, int[]> getEntry(int i) {
            Set entries = this.entrySet();
            if (i < 0 || i > entries.size()) {
                throw new IllegalArgumentException("Index should be in following range:[0," + entries.size() + "]");
            }
            int j = 0;
            for (Map.Entry<String, int[]> entry : entries) {
                if (j++ != i) continue;
                return entry;
            }
            return null;
        }
    }
}

