/*
 * Decompiled with CFR 0.152.
 */
package org.numenta.nupic.algorithms;

import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import no.uib.cipr.matrix.sparse.FlexCompRowMatrix;
import org.numenta.nupic.algorithms.Classification;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Deque;
import org.numenta.nupic.util.Tuple;

/**
 * Classifier that maps a sparse input pattern (the indices of its active
 * bits) to a probability distribution over output buckets.
 *
 * <p>One weight matrix is kept per prediction step. Rows correspond to
 * bucket indices and columns to input bit indices. Inference sums the weights
 * of the active bits for every bucket and normalizes the sums with a softmax.
 * Learning nudges the weights of the active bits by
 * {@code alpha * (target - predicted)}.
 *
 * <p>The matrices grow on demand when a larger input bit index or bucket
 * index is seen.
 */
public class SDRClassifier implements Persistable {
    private static final long serialVersionUID = 1L;
    /** Diagnostic output is printed to stdout when this is >= 1. */
    int verbosity = 0;
    /** Multiplier applied to the prediction error when weights are updated. */
    double alpha = 0.001;
    /** Blending factor of the running average kept for numeric actual values. */
    double actValueAlpha = 0.3;
    /** Record number translated so that it is relative to the first record seen. */
    int learnIteration;
    /** Offset between record numbers and learn iterations; -1 until the first compute call. */
    int recordNumMinusLearnIteration = -1;
    /** Largest input bit index seen so far (number of matrix columns minus one). */
    int maxInputIdx = 0;
    /** Largest bucket index seen so far (number of matrix rows minus one). */
    int maxBucketIdx;
    /** Weight matrix per prediction step. */
    Map<Integer, FlexCompRowMatrix> weightMatrix = new HashMap<Integer, FlexCompRowMatrix>();
    /** Prediction steps handled by this classifier. */
    TIntList steps = new TIntArrayList();
    /** Recent (learnIteration, patternNZ) tuples, bounded to max(steps) + 1 entries. */
    Deque<Tuple> patternNZHistory;
    /** Actual value recorded for each bucket index; null where nothing was learned yet. */
    List<?> actualValues = new ArrayList<Object>();
    String g_debugPrefix = "SDRClassifier";

    /**
     * Creates a classifier for single-step prediction with alpha 0.001,
     * actValueAlpha 0.3 and verbosity 0.
     */
    public SDRClassifier() {
        this((TIntList)new TIntArrayList(new int[]{1}), 0.001, 0.3, 0);
    }

    /**
     * Creates a classifier.
     *
     * @param steps          prediction steps to learn and infer; the list is stored by reference
     * @param alpha          multiplier applied to the error during weight updates
     * @param actValueAlpha  blending factor for the running average of numeric actual values
     * @param verbosity      diagnostic output is printed when >= 1
     */
    public SDRClassifier(TIntList steps, double alpha, double actValueAlpha, int verbosity) {
        this.steps = steps;
        this.alpha = alpha;
        this.actValueAlpha = actValueAlpha;
        this.verbosity = verbosity;
        // Bucket 0 always has a slot, matching the initial maxBucketIdx of 0.
        this.actualValues.add(null);
        this.patternNZHistory = new Deque<Tuple>(ArrayUtils.max(steps.toArray()) + 1);
        for (int step : steps.toArray()) {
            this.weightMatrix.put(step, new FlexCompRowMatrix(this.maxBucketIdx + 1, this.maxInputIdx + 1));
        }
    }

    /**
     * Processes one input record: optionally infers bucket likelihoods and
     * optionally learns from the supplied classification.
     *
     * @param recordNum       record number of this input
     * @param classification  map read for the keys {@code "bucketIdx"} (Integer) and
     *                        {@code "actValue"}; learning is skipped when the map is
     *                        null or has no {@code "bucketIdx"}
     * @param patternNZ       indices of the active input bits
     * @param learn           whether to update the weights from {@code classification}
     * @param infer           whether to compute a result
     * @return the inference result, or {@code null} when {@code infer} is false
     */
    @SuppressWarnings("unchecked")
    public <T> Classification<T> compute(int recordNum, Map<String, Object> classification, int[] patternNZ, boolean learn, boolean infer) {
        Classification<T> retVal = null;
        // The field is declared List<?>; a writable view is needed to store values.
        List<Object> actualValues = (List<Object>)this.actualValues;
        if (this.recordNumMinusLearnIteration == -1) {
            this.recordNumMinusLearnIteration = recordNum - this.learnIteration;
        }
        this.learnIteration = recordNum - this.recordNumMinusLearnIteration;
        if (this.verbosity >= 1) {
            System.out.println(String.format("\n%s: compute ", this.g_debugPrefix));
            System.out.printf("recordNum: %d\n", recordNum);
            System.out.printf("learnIteration: %d\n", this.learnIteration);
            System.out.printf("patternNZ (%d): %s\n", patternNZ.length, ArrayUtils.intArrayToString(patternNZ));
            System.out.println("classificationIn: " + classification);
        }
        this.patternNZHistory.append(new Tuple(this.learnIteration, patternNZ));

        // Grow every weight matrix by one column per newly seen input index.
        int newMaxInputIdx = ArrayUtils.max(patternNZ);
        if (newMaxInputIdx > this.maxInputIdx) {
            for (int nSteps : this.steps.toArray()) {
                for (int i = this.maxInputIdx; i < newMaxInputIdx; ++i) {
                    this.weightMatrix.get(nSteps).addCol(new double[this.maxBucketIdx + 1]);
                }
            }
            this.maxInputIdx = newMaxInputIdx;
        }

        // Inference runs before learning, so the result does not reflect this record's update.
        if (infer) {
            retVal = this.infer(patternNZ, classification);
        }

        if (learn && classification != null && classification.get("bucketIdx") != null) {
            int bucketIdx = (Integer)classification.get("bucketIdx");
            Object actValue = classification.get("actValue");

            // Grow every weight matrix by one row per newly seen bucket index.
            if (bucketIdx > this.maxBucketIdx) {
                for (int nSteps : this.steps.toArray()) {
                    for (int i = this.maxBucketIdx; i < bucketIdx; ++i) {
                        this.weightMatrix.get(nSteps).addRow(new double[this.maxInputIdx + 1]);
                    }
                }
                this.maxBucketIdx = bucketIdx;
            }
            while (this.maxBucketIdx > actualValues.size() - 1) {
                actualValues.add(null);
            }

            // Numeric values are tracked as a running average; anything else is overwritten.
            Object previous = actualValues.get(bucketIdx);
            if (previous == null) {
                actualValues.set(bucketIdx, actValue);
            } else if (actValue instanceof Number && previous instanceof Number) {
                Double val = (1.0 - this.actValueAlpha) * ((Number)previous).doubleValue() + this.actValueAlpha * ((Number)actValue).doubleValue();
                actualValues.set(bucketIdx, val);
            } else if (actValue != null) {
                actualValues.set(bucketIdx, actValue);
            }

            for (Tuple t : this.patternNZHistory) {
                int iteration = (Integer)t.get(0);
                int nSteps = this.learnIteration - iteration;
                if (!this.steps.contains(nSteps)) continue;
                int[] learnPatternNZ = (int[])t.get(1);
                // The error is recomputed per history entry so it reflects weights updated earlier in this loop.
                double[] stepError = this.calculateError(classification).get(nSteps);
                FlexCompRowMatrix weights = this.weightMatrix.get(nSteps);
                for (int row = 0; row <= this.maxBucketIdx; ++row) {
                    for (int bit : learnPatternNZ) {
                        weights.add(row, bit, this.alpha * stepError[row]);
                    }
                }
            }
        }

        if (infer && this.verbosity >= 1) {
            System.out.println(" inference: combined bucket likelihoods:");
            System.out.println("   actual bucket values: " + Arrays.toString(retVal.getActualValues()));
            for (int key : retVal.stepSet()) {
                if (retVal.getActualValue(key) == null) continue;
                Object[] actual = new Object[]{retVal.getActualValue(key)};
                System.out.println(String.format("  %d steps: %s", key, this.pFormatArray(actual)));
                int bestBucketIdx = retVal.getMostProbableBucketIndex(key);
                System.out.println(String.format("   most likely bucket idx: %d, value: %s ", bestBucketIdx, retVal.getActualValue(bestBucketIdx)));
            }
        }
        return retVal;
    }

    /**
     * Builds the inference result for the given input pattern.
     *
     * <p>Buckets without a learned actual value are reported with a default:
     * {@code 0} when the first step is 0 or no classification was supplied,
     * otherwise the classification's {@code "actValue"}.
     *
     * @param patternNZ       indices of the active input bits
     * @param classification  current classification input; may be null
     * @return result holding the actual values and one likelihood array per step
     */
    @SuppressWarnings("unchecked")
    private <T> Classification<T> infer(int[] patternNZ, Map<String, Object> classification) {
        Classification<T> retVal = new Classification<T>();
        Object defaultValue;
        if (this.steps.get(0) == 0 || classification == null) {
            defaultValue = Integer.valueOf(0);
        } else {
            defaultValue = classification.get("actValue");
        }
        T[] actValues = (T[])new Object[this.actualValues.size()];
        for (int i = 0; i < actValues.length; ++i) {
            Object learned = this.actualValues.get(i);
            actValues[i] = (T)(learned == null ? defaultValue : learned);
        }
        retVal.setActualValues(actValues);
        for (int nSteps : this.steps.toArray()) {
            double[] predictDist = this.inferSingleStep(patternNZ, this.weightMatrix.get(nSteps));
            retVal.setStats(nSteps, predictDist);
        }
        return retVal;
    }

    /**
     * Computes the softmax-normalized bucket likelihoods for one step.
     *
     * @param patternNZ     indices of the active input bits
     * @param weightMatrix  weight matrix of the step being inferred
     * @return array of length {@code maxBucketIdx + 1} holding the likelihood of each bucket
     */
    private double[] inferSingleStep(int[] patternNZ, FlexCompRowMatrix weightMatrix) {
        int bucketCount = this.maxBucketIdx + 1;
        double[] outputActivation = new double[bucketCount];
        double maxActivation = Double.NEGATIVE_INFINITY;
        for (int row = 0; row < bucketCount; ++row) {
            double activation = 0.0;
            for (int bit : patternNZ) {
                activation += weightMatrix.get(row, bit);
            }
            outputActivation[row] = activation;
            if (activation > maxActivation) {
                maxActivation = activation;
            }
        }
        // Subtracting the maximum leaves the softmax mathematically unchanged
        // but keeps Math.exp from overflowing on large activations.
        double[] expOutputActivation = new double[bucketCount];
        for (int i = 0; i < bucketCount; ++i) {
            expOutputActivation[i] = Math.exp(outputActivation[i] - maxActivation);
        }
        double total = ArrayUtils.sum(expOutputActivation);
        double[] predictDist = new double[bucketCount];
        for (int i = 0; i < bucketCount; ++i) {
            predictDist[i] = expOutputActivation[i] / total;
        }
        return predictDist;
    }

    /**
     * Computes, for every step that has a matching entry in the pattern
     * history, the difference between the one-hot target distribution and the
     * distribution currently predicted from that historical pattern.
     *
     * @param classification  map whose {@code "bucketIdx"} entry names the target bucket
     * @return map from step to error array of length {@code maxBucketIdx + 1}
     */
    private Map<Integer, double[]> calculateError(Map<String, Object> classification) {
        Map<Integer, double[]> error = new HashMap<Integer, double[]>();
        int[] targetDist = new int[this.maxBucketIdx + 1];
        targetDist[(Integer)classification.get("bucketIdx")] = 1;
        for (Tuple t : this.patternNZHistory) {
            int iteration = (Integer)t.get(0);
            int nSteps = this.learnIteration - iteration;
            if (!this.steps.contains(nSteps)) continue;
            int[] learnPatternNZ = (int[])t.get(1);
            double[] predictDist = this.inferSingleStep(learnPatternNZ, this.weightMatrix.get(nSteps));
            double[] targetDistMinusPredictDist = new double[this.maxBucketIdx + 1];
            for (int i = 0; i <= this.maxBucketIdx; ++i) {
                targetDistMinusPredictDist[i] = (double)targetDist[i] - predictDist[i];
            }
            error.put(nSteps, targetDistMinusPredictDist);
        }
        return error;
    }

    /**
     * Formats an array for diagnostic output as {@code "[ a b c ]"}.
     * Numbers are printed with two decimals; other elements use their string form.
     *
     * @param arr  array to format; may be null
     * @return the formatted text, or an empty string when {@code arr} is null
     */
    private <T> String pFormatArray(T[] arr) {
        if (arr == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder("[");
        for (T t : arr) {
            sb.append(' ');
            if (t instanceof Number) {
                sb.append(String.format("%.2f", ((Number)t).doubleValue()));
            } else {
                sb.append(String.valueOf(t));
            }
        }
        sb.append(" ]");
        return sb.toString();
    }
}

