/*
 * Decompiled with CFR 0.152.
 */
package org.numenta.nupic.research;

import gnu.trove.TIntCollection;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.numenta.nupic.Connections;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.Pool;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.SparseBinaryMatrix;
import org.numenta.nupic.util.SparseMatrix;
import org.numenta.nupic.util.SparseObjectMatrix;

/**
 * Spatial pooler algorithm (modelled on NuPIC's spatial pooler).
 * <p>
 * This class holds no fields of its own: every parameter and every piece of
 * mutable state lives in the {@link Connections} object passed to each method.
 */
public class SpatialPooler {
    /**
     * Initializes all spatial-pooler state held in {@code c}. It allocates the
     * matrices and per-column arrays, then builds each column's potential pool
     * and initial permanences.
     *
     * @param c connections object holding parameters and state
     */
    public void init(Connections c) {
        this.initMatrices(c);
        this.connectAndConfigureInputs(c);
    }

    /**
     * Allocates the column memory, the input matrix, the potential-pool and
     * connected matrices, the tie-breaker values and the duty-cycle/boost arrays.
     * <p>
     * An existing column memory matrix is re-used when one is already set, but
     * every column index still receives a fresh {@link Column}. Boost factors
     * start at 1.0; all duty cycles start at 0.
     *
     * @param c connections object to initialize
     */
    public void initMatrices(Connections c) {
        SparseObjectMatrix<Column> mem = c.getMemory();
        if (mem == null) {
            mem = new SparseObjectMatrix<Column>(c.getColumnDimensions());
        }
        c.setMemory(mem);
        c.setInputMatrix(new SparseBinaryMatrix(c.getInputDimensions()));
        int numInputs = c.getInputMatrix().getMaxIndex() + 1;
        int numColumns = c.getMemory().getMaxIndex() + 1;
        c.setNumInputs(numInputs);
        c.setNumColumns(numColumns);
        for (int i = 0; i < numColumns; ++i) {
            mem.set(i, new Column(c.getCellsPerColumn(), i));
        }
        c.setPotentialPools(new SparseObjectMatrix<Pool>(c.getMemory().getDimensions()));
        c.setConnectedMatrix(new SparseBinaryMatrix(new int[]{numColumns, numInputs}));
        // Small random offsets (in [0, 0.01)) added to overlaps during inhibition to break ties.
        double[] tieBreaker = new double[numColumns];
        for (int i = 0; i < numColumns; ++i) {
            tieBreaker[i] = 0.01 * c.getRandom().nextDouble();
        }
        c.setTieBreaker(tieBreaker);
        c.setOverlapDutyCycles(new double[numColumns]);
        c.setActiveDutyCycles(new double[numColumns]);
        c.setMinOverlapDutyCycles(new double[numColumns]);
        c.setMinActiveDutyCycles(new double[numColumns]);
        c.setBoostFactors(new double[numColumns]);
        Arrays.fill(c.getBoostFactors(), 1.0);
    }

    /**
     * Builds every column's potential pool and initial permanences. It then
     * recomputes the inhibition radius.
     *
     * @param c connections object holding parameters and state
     */
    public void connectAndConfigureInputs(Connections c) {
        int numColumns = c.getNumColumns();
        for (int i = 0; i < numColumns; ++i) {
            int[] potential = this.mapPotential(c, i, true);
            Column column = c.getColumn(i);
            c.getPotentialPools().set(i, column.createPotentialPool(c, potential));
            double[] perm = this.initPermanence(c, potential, i, c.getInitConnectedPct());
            this.updatePermanencesForColumn(c, perm, column, potential, true);
        }
        this.updateInhibitionRadius(c);
    }

    /**
     * Runs one spatial-pooler step.
     * <p>
     * It computes column overlaps with {@code inputVector` and selects the
     * active columns through inhibition. When {@code learn} is true it also
     * adapts synapses and updates duty cycles and boost factors. On update
     * rounds it additionally recomputes the inhibition radius and the minimum
     * duty cycles.
     *
     * @param c                 connections object holding parameters and state
     * @param inputVector       input vector; its length must equal {@code c.getNumInputs()}
     * @param activeArray       output; overwritten with 1 at active columns and 0 elsewhere
     * @param learn             whether learning is applied during this step
     * @param stripNeverLearned when not learning, drop active columns whose
     *                          active duty cycle is not positive
     * @throws IllegalArgumentException if {@code inputVector} has the wrong length
     */
    public void compute(Connections c, int[] inputVector, int[] activeArray, boolean learn, boolean stripNeverLearned) {
        if (inputVector.length != c.getNumInputs()) {
            throw new IllegalArgumentException("Input array must be same size as the defined number of inputs: From Params: " + c.getNumInputs() + ", From Input Vector: " + inputVector.length);
        }
        this.updateBookeepingVars(c, learn);
        int[] overlaps = this.calculateOverlap(c, inputVector);
        // Boosting is only applied while learning.
        double[] boostedOverlaps = learn ? ArrayUtils.multiply(c.getBoostFactors(), overlaps) : ArrayUtils.toDoubleArray(overlaps);
        int[] activeColumns = this.inhibitColumns(c, boostedOverlaps);
        if (learn) {
            this.adaptSynapses(c, inputVector, activeColumns);
            this.updateDutyCycles(c, overlaps, activeColumns);
            this.bumpUpWeakColumns(c);
            this.updateBoostFactors(c);
            if (this.isUpdateRound(c)) {
                this.updateInhibitionRadius(c);
                this.updateMinDutyCycles(c);
            }
        } else if (stripNeverLearned) {
            activeColumns = this.stripUnlearnedColumns(c, activeColumns).toArray();
        }
        Arrays.fill(activeArray, 0);
        if (activeColumns.length > 0) {
            ArrayUtils.setIndexesTo(activeArray, activeColumns, 1);
        }
    }

    /**
     * Removes from {@code activeColumns} every column whose active duty cycle
     * is {@code <= 0}, i.e. columns that have never been active.
     *
     * @param c             connections object holding the active duty cycles
     * @param activeColumns indices of currently active columns
     * @return the remaining active column indices, sorted ascending
     */
    public TIntArrayList stripUnlearnedColumns(Connections c, int[] activeColumns) {
        TIntHashSet active = new TIntHashSet(activeColumns);
        TIntHashSet neverActive = new TIntHashSet();
        int numCols = c.getNumColumns();
        double[] colDutyCycles = c.getActiveDutyCycles();
        for (int i = 0; i < numCols; ++i) {
            if (colDutyCycles[i] <= 0.0) {
                neverActive.add(i);
            }
        }
        active.removeAll((TIntCollection)neverActive);
        TIntArrayList l = new TIntArrayList((TIntCollection)active);
        l.sort();
        return l;
    }

    /**
     * Updates the minimum duty cycles.
     * <p>
     * The global variant is used when global inhibition is on or the
     * inhibition radius exceeds the number of inputs. Otherwise the local
     * variant is used.
     *
     * @param c connections object holding parameters and state
     */
    public void updateMinDutyCycles(Connections c) {
        if (c.getGlobalInhibition() || c.getInhibitionRadius() > c.getNumInputs()) {
            this.updateMinDutyCyclesGlobal(c);
        } else {
            this.updateMinDutyCyclesLocal(c);
        }
    }

    /**
     * Sets every column's minimum overlap/active duty cycle. Each value is the
     * configured percentage of the maximum duty cycle over all columns.
     *
     * @param c connections object holding parameters and state
     */
    public void updateMinDutyCyclesGlobal(Connections c) {
        Arrays.fill(c.getMinOverlapDutyCycles(), c.getMinPctOverlapDutyCycles() * ArrayUtils.max(c.getOverlapDutyCycles()));
        Arrays.fill(c.getMinActiveDutyCycles(), c.getMinPctActiveDutyCycles() * ArrayUtils.max(c.getActiveDutyCycles()));
    }

    /**
     * Sets each column's minimum overlap/active duty cycle. Each value is the
     * configured percentage of the maximum duty cycle among that column's
     * neighbours within the inhibition radius (wrap-around enabled).
     *
     * @param c connections object holding parameters and state
     */
    public void updateMinDutyCyclesLocal(Connections c) {
        int len = c.getNumColumns();
        for (int i = 0; i < len; ++i) {
            int[] maskNeighbors = this.getNeighborsND(c, i, c.getMemory(), c.getInhibitionRadius(), true).toArray();
            c.getMinOverlapDutyCycles()[i] = ArrayUtils.max(ArrayUtils.sub(c.getOverlapDutyCycles(), maskNeighbors)) * c.getMinPctOverlapDutyCycles();
            c.getMinActiveDutyCycles()[i] = ArrayUtils.max(ArrayUtils.sub(c.getActiveDutyCycles(), maskNeighbors)) * c.getMinPctActiveDutyCycles();
        }
    }

    /**
     * Updates the overlap and active duty cycles as moving averages.
     * <p>
     * A column contributes 1 to its overlap duty cycle when its overlap is
     * positive, and 1 to its active duty cycle when it is in
     * {@code activeColumns}. The averaging period is
     * {@code min(dutyCyclePeriod, iterationNum)}, clamped to at least 1.
     * {@code overlaps} is not modified.
     *
     * @param c             connections object holding parameters and state
     * @param overlaps      per-column overlap counts
     * @param activeColumns indices of active columns
     */
    public void updateDutyCycles(Connections c, int[] overlaps, int[] activeColumns) {
        int numColumns = c.getNumColumns();
        double[] overlapArray = new double[numColumns];
        double[] activeArray = new double[numColumns];
        // Indicator of "column had positive overlap"; the caller's array is left untouched.
        int limit = Math.min(overlaps.length, numColumns);
        for (int i = 0; i < limit; ++i) {
            if (overlaps[i] > 0) {
                overlapArray[i] = 1.0;
            }
        }
        if (activeColumns.length > 0) {
            ArrayUtils.setIndexesTo(activeArray, activeColumns, 1.0);
        }
        // Early on, average over the iterations seen so far; never divide by zero.
        int period = Math.max(1, Math.min(c.getDutyCyclePeriod(), c.getIterationNum()));
        c.setOverlapDutyCycles(this.updateDutyCyclesHelper(c, c.getOverlapDutyCycles(), overlapArray, period));
        c.setActiveDutyCycles(this.updateDutyCyclesHelper(c, c.getActiveDutyCycles(), activeArray, period));
    }

    /**
     * Computes one moving-average step: {@code (dutyCycles * (period - 1) + newInput) / period}.
     *
     * @param c          connections object (unused)
     * @param dutyCycles current duty cycles
     * @param newInput   new per-column values
     * @param period     averaging period
     * @return the updated duty cycles
     */
    public double[] updateDutyCyclesHelper(Connections c, double[] dutyCycles, double[] newInput, double period) {
        return ArrayUtils.divide(ArrayUtils.d_add(ArrayUtils.multiply(dutyCycles, period - 1.0), newInput), period);
    }

    /**
     * Averages, over all input dimensions, the extent of the column's connected
     * synapses: {@code max - min + 1} of their coordinates along each dimension.
     *
     * @param c           connections object holding parameters and state
     * @param columnIndex column to measure
     * @return the average connected span, or 0 when the column has no connected synapses
     */
    public double avgConnectedSpanForColumnND(Connections c, int columnIndex) {
        int[] dimensions = c.getInputDimensions();
        int[] connected = c.getColumn(columnIndex).getProximalDendrite().getConnectedSynapsesSparse(c);
        if (connected == null || connected.length == 0) {
            return 0.0;
        }
        int[] maxCoord = new int[c.getInputDimensions().length];
        int[] minCoord = new int[c.getInputDimensions().length];
        Arrays.fill(maxCoord, -1);
        Arrays.fill(minCoord, ArrayUtils.max(dimensions));
        SparseMatrix<?> inputMatrix = c.getInputMatrix();
        for (int i = 0; i < connected.length; ++i) {
            maxCoord = ArrayUtils.maxBetween(maxCoord, inputMatrix.computeCoordinates(connected[i]));
            minCoord = ArrayUtils.minBetween(minCoord, inputMatrix.computeCoordinates(connected[i]));
        }
        return ArrayUtils.average(ArrayUtils.add(ArrayUtils.subtract(maxCoord, minCoord), 1));
    }

    /**
     * Recomputes the inhibition radius.
     * <p>
     * With global inhibition the radius is the largest column dimension.
     * Otherwise the average connected span is converted from input space to
     * column space, turned into a radius, and clamped to at least 1.
     *
     * @param c connections object holding parameters and state
     */
    public void updateInhibitionRadius(Connections c) {
        if (c.getGlobalInhibition()) {
            c.setInhibitionRadius(ArrayUtils.max(c.getColumnDimensions()));
            return;
        }
        TDoubleArrayList avgCollected = new TDoubleArrayList();
        int len = c.getNumColumns();
        for (int i = 0; i < len; ++i) {
            avgCollected.add(this.avgConnectedSpanForColumnND(c, i));
        }
        double avgConnectedSpan = ArrayUtils.average(avgCollected.toArray());
        double diameter = avgConnectedSpan * this.avgColumnsPerInput(c);
        double radius = Math.max(1.0, (diameter - 1.0) / 2.0);
        c.setInhibitionRadius((int)Math.round(radius));
    }

    /**
     * Averages, over dimensions, the ratio of column dimension size to input
     * dimension size.
     *
     * @param c connections object holding the dimensions
     * @return the average columns-per-input ratio
     */
    public double avgColumnsPerInput(Connections c) {
        double[] colDim = ArrayUtils.toDoubleArray(c.getColumnDimensions());
        double[] inputDim = ArrayUtils.toDoubleArray(c.getInputDimensions());
        double[] columnsPerInput = ArrayUtils.divide(colDim, inputDim, 0.0, 0.0);
        return ArrayUtils.average(columnsPerInput);
    }

    /**
     * Adapts the permanences of each active column's potential synapses.
     * <p>
     * Synapses on inputs that are {@code > 0} are incremented by
     * {@code synPermActiveInc}; all other potential synapses are decremented
     * by {@code synPermInactiveDec}. Inputs outside the column's potential
     * pool are left unchanged.
     *
     * @param c             connections object holding parameters and state
     * @param inputVector   current input vector
     * @param activeColumns indices of active columns
     */
    public void adaptSynapses(Connections c, int[] inputVector, int[] activeColumns) {
        int[] inputIndices = ArrayUtils.where(inputVector, ArrayUtils.INT_GREATER_THAN_0);
        double[] permChanges = new double[c.getNumInputs()];
        Arrays.fill(permChanges, -1.0 * c.getSynPermInactiveDec());
        ArrayUtils.setIndexesTo(permChanges, inputIndices, c.getSynPermActiveInc());
        for (int i = 0; i < activeColumns.length; ++i) {
            Pool pool = c.getPotentialPools().getObject(activeColumns[i]);
            double[] perm = pool.getDensePermanences(c);
            int[] indexes = pool.getSparseConnections();
            // Only the column's potential synapses are adjusted.
            for (int idx : indexes) {
                perm[idx] += permChanges[idx];
            }
            Column col = c.getColumn(activeColumns[i]);
            this.updatePermanencesForColumn(c, perm, col, indexes, true);
        }
    }

    /**
     * Raises all potential-synapse permanences of "weak" columns by
     * {@code synPermBelowStimulusInc}. A weak column is one whose overlap duty
     * cycle is below its minimum overlap duty cycle.
     *
     * @param c connections object holding parameters and state
     */
    public void bumpUpWeakColumns(final Connections c) {
        int[] weakColumns = ArrayUtils.where(c.getMemory().get1DIndexes(), new Condition.Adapter<Integer>(){

            @Override
            public boolean eval(int i) {
                return c.getOverlapDutyCycles()[i] < c.getMinOverlapDutyCycles()[i];
            }
        });
        for (int i = 0; i < weakColumns.length; ++i) {
            Pool pool = c.getPotentialPools().getObject(weakColumns[i]);
            double[] perm = pool.getSparsePermanences();
            ArrayUtils.raiseValuesBy(c.getSynPermBelowStimulusInc(), perm);
            int[] indexes = pool.getSparseConnections();
            Column col = c.getColumn(weakColumns[i]);
            this.updatePermanencesForColumnSparse(c, perm, col, indexes, true);
        }
    }

    /**
     * Clips {@code perm} to [synPermMin, synPermMax]. Then, until at least
     * {@code stimulusThreshold} entries at {@code maskPotential} exceed
     * {@code synPermConnected}, it raises those entries by
     * {@code synPermBelowStimulusInc}.
     *
     * @param c             connections object holding parameters
     * @param perm          dense permanence array, modified in place
     * @param maskPotential indices into {@code perm} that may be raised
     * @throws IllegalStateException if the threshold cannot be reached: the mask
     *         is smaller than the stimulus threshold, or the increment is not positive
     */
    public void raisePermanenceToThreshold(Connections c, double[] perm, int[] maskPotential) {
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        double threshold = c.getStimulusThreshold();
        if (ArrayUtils.valueGreaterCountAtIndex(c.getSynPermConnected(), perm, maskPotential) >= threshold) {
            return;
        }
        // Without these guards the loop below would never terminate.
        if (maskPotential.length < threshold) {
            throw new IllegalStateException("Cannot reach stimulusThreshold " + threshold + " with only " + maskPotential.length + " potential synapses; stimulusThreshold is likely too large relative to the input size");
        }
        if (!(c.getSynPermBelowStimulusInc() > 0.0)) {
            throw new IllegalStateException("synPermBelowStimulusInc must be positive to raise permanences to the stimulus threshold");
        }
        do {
            ArrayUtils.raiseValuesBy(c.getSynPermBelowStimulusInc(), perm, maskPotential);
        } while (ArrayUtils.valueGreaterCountAtIndex(c.getSynPermConnected(), perm, maskPotential) < threshold);
    }

    /**
     * Sparse counterpart of {@link #raisePermanenceToThreshold}. Every entry of
     * {@code perm} is a candidate for raising.
     *
     * @param c    connections object holding parameters
     * @param perm sparse permanence array, modified in place
     * @throws IllegalStateException if the threshold cannot be reached: the array
     *         is shorter than the stimulus threshold, or the increment is not positive
     */
    public void raisePermanenceToThresholdSparse(Connections c, double[] perm) {
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        double threshold = c.getStimulusThreshold();
        if (ArrayUtils.valueGreaterCount(c.getSynPermConnected(), perm) >= threshold) {
            return;
        }
        // Without these guards the loop below would never terminate.
        if (perm.length < threshold) {
            throw new IllegalStateException("Cannot reach stimulusThreshold " + threshold + " with only " + perm.length + " potential synapses; stimulusThreshold is likely too large relative to the input size");
        }
        if (!(c.getSynPermBelowStimulusInc() > 0.0)) {
            throw new IllegalStateException("synPermBelowStimulusInc must be positive to raise permanences to the stimulus threshold");
        }
        do {
            ArrayUtils.raiseValuesBy(c.getSynPermBelowStimulusInc(), perm);
        } while (ArrayUtils.valueGreaterCount(c.getSynPermConnected(), perm) < threshold);
    }

    /**
     * Finalizes a column's dense permanences and stores them on the column.
     * <p>
     * It optionally raises them to the stimulus threshold, zeroes values
     * {@code <= synPermTrimThreshold}, clips to [synPermMin, synPermMax], and
     * then stores the result.
     *
     * @param c             connections object holding parameters
     * @param perm          dense permanence array, modified in place
     * @param column        column receiving the permanences
     * @param maskPotential potential-pool input indices
     * @param raisePerm     whether to raise permanences to the stimulus threshold first
     */
    public void updatePermanencesForColumn(Connections c, double[] perm, Column column, int[] maskPotential, boolean raisePerm) {
        if (raisePerm) {
            this.raisePermanenceToThreshold(c, perm, maskPotential);
        }
        ArrayUtils.lessThanOrEqualXThanSetToY(perm, c.getSynPermTrimThreshold(), 0.0);
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        column.setProximalPermanences(c, perm);
    }

    /**
     * Sparse counterpart of {@link #updatePermanencesForColumn}. {@code perm}
     * holds only the potential synapses, aligned with {@code maskPotential}.
     *
     * @param c             connections object holding parameters
     * @param perm          sparse permanence array, modified in place
     * @param column        column receiving the permanences
     * @param maskPotential potential-pool input indices matching {@code perm}
     * @param raisePerm     whether to raise permanences to the stimulus threshold first
     */
    public void updatePermanencesForColumnSparse(Connections c, double[] perm, Column column, int[] maskPotential, boolean raisePerm) {
        if (raisePerm) {
            this.raisePermanenceToThresholdSparse(c, perm);
        }
        ArrayUtils.lessThanOrEqualXThanSetToY(perm, c.getSynPermTrimThreshold(), 0.0);
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        column.setProximalPermanencesSparse(c, perm, maskPotential);
    }

    /**
     * Returns a random initial permanence at or just above the connected
     * threshold: {@code synPermConnected + U[0,1) * synPermActiveInc / 4},
     * truncated to 5 decimal places.
     *
     * @param c connections object holding parameters and the random source
     * @return the initial permanence
     */
    public static double initPermConnected(Connections c) {
        double p = c.getSynPermConnected() + c.getRandom().nextDouble() * c.getSynPermActiveInc() / 4.0;
        p = (double)((int)(p * 100000.0)) / 100000.0;
        return p;
    }

    /**
     * Returns a random initial permanence below the connected threshold:
     * {@code synPermConnected * U[0,1)}, truncated to 5 decimal places.
     *
     * @param c connections object holding parameters and the random source
     * @return the initial permanence
     */
    public static double initPermNonConnected(Connections c) {
        double p = c.getSynPermConnected() * c.getRandom().nextDouble();
        p = (double)((int)(p * 100000.0)) / 100000.0;
        return p;
    }

    /**
     * Builds a column's initial dense permanences.
     * <p>
     * A random {@code connectedPct} fraction of the potential pool starts
     * connected and the rest starts unconnected. Values below
     * {@code synPermTrimThreshold} are zeroed. The result is stored on the
     * column and also returned.
     *
     * @param c             connections object holding parameters and state
     * @param potentialPool potential-pool input indices
     * @param index         column index
     * @param connectedPct  fraction of the pool to start connected (clamped to the pool size)
     * @return dense permanence array of length {@code numInputs}
     */
    public double[] initPermanence(Connections c, int[] potentialPool, int index, double connectedPct) {
        // Clamp so the sampling loop terminates even if connectedPct > 1.
        int count = Math.min(potentialPool.length, (int)Math.round((double)potentialPool.length * connectedPct));
        TIntHashSet pick = new TIntHashSet();
        Random random = c.getRandom();
        while (pick.size() < count) {
            int randIdx = random.nextInt(potentialPool.length);
            pick.add(potentialPool[randIdx]);
        }
        double[] perm = new double[c.getNumInputs()];
        for (int idx : potentialPool) {
            perm[idx] = pick.contains(idx) ? SpatialPooler.initPermConnected(c) : SpatialPooler.initPermNonConnected(c);
            perm[idx] = perm[idx] < c.getSynPermTrimThreshold() ? 0.0 : perm[idx];
        }
        c.getColumn(index).setProximalPermanences(c, perm);
        return perm;
    }

    /**
     * Maps a column index to the flat index of its "natural centre" input.
     * The column coordinates are scaled into input space and offset by half an
     * input-per-column step.
     *
     * @param c           connections object holding the dimensions
     * @param columnIndex column to map
     * @return flat input index
     */
    public int mapColumn(Connections c, int columnIndex) {
        int[] columnCoords = c.getMemory().computeCoordinates(columnIndex);
        double[] colCoords = ArrayUtils.toDoubleArray(columnCoords);
        double[] ratios = ArrayUtils.divide(colCoords, ArrayUtils.toDoubleArray(c.getColumnDimensions()), 0.0, 0.0);
        double[] inputCoords = ArrayUtils.multiply(ArrayUtils.toDoubleArray(c.getInputDimensions()), ratios, 0.0, 0.0);
        inputCoords = ArrayUtils.d_add(inputCoords, ArrayUtils.multiply(ArrayUtils.divide(ArrayUtils.toDoubleArray(c.getInputDimensions()), ArrayUtils.toDoubleArray(c.getColumnDimensions()), 0.0, 0.0), 0.5));
        int[] inputCoordInts = ArrayUtils.clip(ArrayUtils.toIntArray(inputCoords), c.getInputDimensions(), -1);
        return c.getInputMatrix().computeIndex(inputCoordInts);
    }

    /**
     * Returns a random sample of the inputs within {@code potentialRadius} of
     * the column's centre input, including the centre itself. The sample size
     * is {@code round(size * potentialPct)}.
     *
     * @param c           connections object holding parameters and the random source
     * @param columnIndex column whose potential pool is computed
     * @param wrapAround  whether the neighbourhood wraps around input edges
     * @return sampled input indices
     */
    public int[] mapPotential(Connections c, int columnIndex, boolean wrapAround) {
        int inputIndex = this.mapColumn(c, columnIndex);
        TIntArrayList indices = this.getNeighborsND(c, inputIndex, c.getInputMatrix(), c.getPotentialRadius(), wrapAround);
        indices.add(inputIndex);
        indices.sort();
        return ArrayUtils.sample((int)Math.round((double)indices.size() * c.getPotentialPct()), indices, c.getRandom());
    }

    /**
     * Returns the flat indices, in {@code topology}'s index space, of all
     * points within {@code inhibitionRadius} of {@code columnIndex} along every
     * dimension. The point itself is excluded.
     *
     * @param c                connections object (kept for API compatibility; not used)
     * @param columnIndex      flat index of the centre point in {@code topology}
     * @param topology         matrix defining the dimensions and index space
     * @param inhibitionRadius neighbourhood radius per dimension
     * @param wrapAround       wrap coordinates around the edges instead of discarding them
     * @return neighbour indices (centre excluded)
     */
    public TIntArrayList getNeighborsND(Connections c, int columnIndex, SparseMatrix<?> topology, int inhibitionRadius, boolean wrapAround) {
        final int[] dimensions = topology.getDimensions();
        int[] columnCoords = topology.computeCoordinates(columnIndex);
        ArrayList<int[]> dimensionCoords = new ArrayList<int[]>();
        for (int i = 0; i < dimensions.length; ++i) {
            int[] range = ArrayUtils.range(columnCoords[i] - inhibitionRadius, columnCoords[i] + inhibitionRadius + 1);
            int[] curRange;
            if (wrapAround) {
                curRange = new int[range.length];
                for (int j = 0; j < curRange.length; ++j) {
                    curRange[j] = (int)ArrayUtils.positiveRemainder(range[j], dimensions[i]);
                }
            } else {
                final int dimSize = dimensions[i];
                curRange = ArrayUtils.retainLogicalAnd(range, new Condition[]{ArrayUtils.GREATER_OR_EQUAL_0, new Condition.Adapter<Integer>(){

                    @Override
                    public boolean eval(int n) {
                        return n < dimSize;
                    }
                }});
            }
            dimensionCoords.add(ArrayUtils.unique(curRange));
        }
        List<int[]> neighborList = ArrayUtils.dimensionsToCoordinateList(dimensionCoords);
        TIntArrayList neighbors = new TIntArrayList(neighborList.size());
        int size = neighborList.size();
        for (int i = 0; i < size; ++i) {
            // Coordinates are in topology space, so they must be flattened with topology's
            // dimensions (not the input matrix's) to yield comparable indices.
            int flatIndex = topology.computeIndex(neighborList.get(i), false);
            if (flatIndex == columnIndex) continue;
            neighbors.add(flatIndex);
        }
        return neighbors;
    }

    /**
     * Reports whether this iteration is a periodic update round.
     *
     * @param c connections object holding the iteration count and update period
     * @return true when {@code iterationNum % updatePeriod == 0}
     */
    public boolean isUpdateRound(Connections c) {
        return c.getIterationNum() % c.getUpdatePeriod() == 0;
    }

    /**
     * Increments the iteration counter. The learning-iteration counter is
     * incremented as well when {@code learn} is true.
     *
     * @param c     connections object holding the counters
     * @param learn whether this iteration learns
     */
    public void updateBookeepingVars(Connections c, boolean learn) {
        ++c.iterationNum;
        if (learn) {
            ++c.iterationLearnNum;
        }
    }

    /**
     * Computes each column's overlap with the input through its connected
     * synapses. Overlaps below the (integer-truncated) stimulus threshold are
     * set to 0.
     *
     * @param c           connections object holding the connected counts
     * @param inputVector current input vector
     * @return per-column overlap counts
     */
    public int[] calculateOverlap(Connections c, int[] inputVector) {
        int[] overlaps = new int[c.getNumColumns()];
        c.getConnectedCounts().rightVecSumAtNZ(inputVector, overlaps);
        ArrayUtils.lessThanXThanSetToY(overlaps, (int)c.getStimulusThreshold(), 0);
        return overlaps;
    }

    /**
     * Divides each overlap by the column's number of connected synapses.
     *
     * @param c        connections object holding the connected counts
     * @param overlaps per-column overlap counts
     * @return per-column overlap percentages
     */
    public double[] calculateOverlapPct(Connections c, int[] overlaps) {
        return ArrayUtils.divide(overlaps, c.getConnectedCounts().getTrueCounts());
    }

    /**
     * Selects the active columns from their (boosted) overlaps.
     * <p>
     * When {@code localAreaDensity <= 0}, the target density is derived as
     * {@code numActiveColumnsPerInhArea / inhibitionArea}, capped at 0.5.
     * Tie-breaker noise is added to a copy of the overlaps. Global inhibition
     * is used when configured or when the inhibition radius exceeds every
     * column dimension; otherwise local inhibition is used.
     *
     * @param c        connections object holding parameters and state
     * @param overlaps per-column overlaps; not modified
     * @return sorted indices of the winning columns
     */
    public int[] inhibitColumns(Connections c, double[] overlaps) {
        overlaps = Arrays.copyOf(overlaps, overlaps.length);
        double density = c.getLocalAreaDensity();
        if (density <= 0.0) {
            double inhibitionArea = Math.pow(2 * c.getInhibitionRadius() + 1, c.getColumnDimensions().length);
            inhibitionArea = Math.min((double)c.getNumColumns(), inhibitionArea);
            density = c.getNumActiveColumnsPerInhArea() / inhibitionArea;
            density = Math.min(density, 0.5);
        }
        // Use the returned array so the tie-breaker is applied whether or not d_add works in place.
        overlaps = ArrayUtils.d_add(overlaps, c.getTieBreaker());
        if (c.getGlobalInhibition() || c.getInhibitionRadius() > ArrayUtils.max(c.getColumnDimensions())) {
            return this.inhibitColumnsGlobal(c, overlaps, density);
        }
        return this.inhibitColumnsLocal(c, overlaps, density);
    }

    /**
     * Picks the {@code (int)(density * numColumns)} columns with the greatest
     * overlaps.
     *
     * @param c        connections object holding the column count
     * @param overlaps per-column overlaps
     * @param density  target fraction of active columns
     * @return sorted indices of the winning columns
     */
    public int[] inhibitColumnsGlobal(Connections c, double[] overlaps, double density) {
        int numCols = c.getNumColumns();
        int numActive = (int)(density * (double)numCols);
        int[] winners = ArrayUtils.nGreatest(overlaps, numActive);
        Arrays.sort(winners);
        return winners;
    }

    /**
     * Local inhibition: a column wins when fewer than the target number of its
     * neighbours (no wrap-around) have a strictly greater overlap.
     * <p>
     * Each winner's overlap is bumped by {@code max(overlaps) / 1000}, so later
     * columns in the same neighbourhood see it as stronger. {@code overlaps} is
     * modified in place.
     *
     * @param c        connections object holding parameters and state
     * @param overlaps per-column overlaps, modified in place
     * @param density  target fraction of active columns per neighbourhood
     * @return sorted indices of the winning columns
     */
    public int[] inhibitColumnsLocal(Connections c, double[] overlaps, double density) {
        int numCols = c.getNumColumns();
        int[] activeColumns = new int[numCols];
        double addToWinners = ArrayUtils.max(overlaps) / 1000.0;
        for (int i = 0; i < numCols; ++i) {
            TIntArrayList maskNeighbors = this.getNeighborsND(c, i, c.getMemory(), c.getInhibitionRadius(), false);
            double[] overlapSlice = ArrayUtils.sub(overlaps, maskNeighbors.toArray());
            int numActive = (int)(0.5 + density * (double)(maskNeighbors.size() + 1));
            int numBigger = ArrayUtils.valueGreaterCount(overlaps[i], overlapSlice);
            if (numBigger >= numActive) continue;
            activeColumns[i] = 1;
            overlaps[i] += addToWinners;
        }
        return ArrayUtils.where(activeColumns, ArrayUtils.INT_GREATER_THAN_0);
    }

    /**
     * Recomputes the boost factors.
     * <p>
     * When any minimum active duty cycle is positive, each factor is
     * {@code (1 - maxBoost) / minActiveDutyCycle * activeDutyCycle + maxBoost};
     * otherwise the current factors are kept. Columns whose active duty cycle
     * exceeds their minimum are then reset to 1.0.
     *
     * @param c connections object holding parameters and state
     */
    public void updateBoostFactors(Connections c) {
        double[] boostInterim;
        double[] activeDutyCycles = c.getActiveDutyCycles();
        final double[] minActiveDutyCycles = c.getMinActiveDutyCycles();
        int[] mask = ArrayUtils.where(minActiveDutyCycles, ArrayUtils.GREATER_THAN_0);
        if (mask.length < 1) {
            boostInterim = c.getBoostFactors();
        } else {
            double[] numerator = new double[c.getNumColumns()];
            Arrays.fill(numerator, 1.0 - c.getMaxBoost());
            boostInterim = ArrayUtils.divide(numerator, minActiveDutyCycles, 0.0, 0.0);
            boostInterim = ArrayUtils.multiply(boostInterim, activeDutyCycles, 0.0, 0.0);
            boostInterim = ArrayUtils.d_add(boostInterim, c.getMaxBoost());
        }
        // The condition keeps a positional counter, so it assumes 'where' visits entries in index order.
        ArrayUtils.setIndexesTo(boostInterim, ArrayUtils.where(activeDutyCycles, new Condition.Adapter<Object>(){
            int i = 0;

            @Override
            public boolean eval(double d) {
                return d > minActiveDutyCycles[this.i++];
            }
        }), 1.0);
        c.setBoostFactors(boostInterim);
    }
}

