/*
 * Decompiled with CFR 0.152.
 */
package org.numenta.nupic.algorithms;

import chaschev.lang.Pair;
import gnu.trove.TIntCollection;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.model.Pool;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.Condition;
import org.numenta.nupic.util.SparseBinaryMatrix;
import org.numenta.nupic.util.SparseMatrix;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Topology;

/**
 * Spatial pooler algorithm whose state lives in a {@link Connections} object.
 *
 * <p>The class declares no instance fields besides {@code serialVersionUID}: every method reads
 * and writes the column memory, permanence matrices, duty cycles, boost factors and parameters
 * stored in the {@code Connections} passed to it.
 *
 * <p>NOTE(review): the element-wise behaviour described below for {@code ArrayUtils} and
 * {@code Connections} helpers is inferred from their names and call sites; confirm against those
 * classes before relying on it.
 */
public class SpatialPooler implements Persistable {
    private static final long serialVersionUID = 1L;

    /**
     * Validates the inhibition parameters and builds all spatial pooler state in {@code c}.
     *
     * @param c the connections object holding parameters and receiving the initialized state
     * @throws InvalidSPParamValueException if {@code numActiveColumnsPerInhArea} is 0 and
     *         {@code localAreaDensity} is either 0 or greater than 0.5
     */
    public void init(Connections c) {
        if (c.getNumActiveColumnsPerInhArea() == 0.0
                && (c.getLocalAreaDensity() == 0.0 || c.getLocalAreaDensity() > 0.5)) {
            throw new InvalidSPParamValueException("Inhibition parameters are invalid");
        }
        c.doSpatialPoolerPostInit();
        initMatrices(c);
        connectAndConfigureInputs(c);
    }

    /**
     * Allocates the column, input, topology, potential-pool and duty-cycle structures in {@code c}.
     *
     * <p>An existing column memory in {@code c} is reused; otherwise one is created from the
     * configured column dimensions. Every column slot is filled with a fresh {@link Column}, all
     * duty-cycle arrays start at zero and every boost factor starts at 1.0.
     *
     * @param c the connections object to populate
     * @throws InvalidSPParamValueException if the computed number of columns or inputs is not positive
     */
    public void initMatrices(Connections c) {
        SparseObjectMatrix<Column> mem = c.getMemory();
        if (mem == null) {
            mem = new SparseObjectMatrix<>(c.getColumnDimensions());
        }
        c.setMemory(mem);
        c.setInputMatrix(new SparseBinaryMatrix(c.getInputDimensions()));
        c.setColumnTopology(new Topology(c.getColumnDimensions()));
        c.setInputTopology(new Topology(c.getInputDimensions()));

        int numInputs = c.getInputMatrix().getMaxIndex() + 1;
        int numColumns = c.getMemory().getMaxIndex() + 1;
        if (numColumns <= 0) {
            throw new InvalidSPParamValueException("Invalid number of columns: " + numColumns);
        }
        if (numInputs <= 0) {
            throw new InvalidSPParamValueException("Invalid number of inputs: " + numInputs);
        }
        c.setNumInputs(numInputs);
        c.setNumColumns(numColumns);

        for (int i = 0; i < numColumns; ++i) {
            mem.set(i, new Column(c.getCellsPerColumn(), i));
        }

        c.setPotentialPools(new SparseObjectMatrix<Pool>(c.getMemory().getDimensions()));
        c.setConnectedMatrix(new SparseBinaryMatrix(new int[] {numColumns, numInputs}));
        c.setOverlapDutyCycles(new double[numColumns]);
        c.setActiveDutyCycles(new double[numColumns]);
        c.setMinOverlapDutyCycles(new double[numColumns]);
        c.setMinActiveDutyCycles(new double[numColumns]);
        c.setBoostFactors(new double[numColumns]);
        Arrays.fill(c.getBoostFactors(), 1.0);
    }

    /**
     * Builds each column's potential pool and initial permanences, then sets the inhibition radius.
     *
     * @param c the connections object whose columns are configured
     */
    public void connectAndConfigureInputs(Connections c) {
        int numColumns = c.getNumColumns();
        for (int i = 0; i < numColumns; ++i) {
            int[] potential = mapPotential(c, i, c.isWrapAround());
            Column column = c.getColumn(i);
            c.getPotentialPools().set(i, column.createPotentialPool(c, potential));
            double[] perm = initPermanence(c, potential, i, c.getInitConnectedPct());
            updatePermanencesForColumn(c, perm, column, potential, true);
        }
        updateInhibitionRadius(c);
    }

    /**
     * Runs one spatial pooler step on {@code inputVector} and writes the result to {@code activeArray}.
     *
     * <p>The step first computes the column overlaps. When learning, it multiplies them by the
     * boost factors, then selects active columns by inhibition. When {@code learn} is true it also
     * adapts synapses and updates duty cycles and boost factors. On update rounds it additionally
     * refreshes the inhibition radius and the minimum duty cycles.
     *
     * @param c           the connections object holding the spatial pooler state
     * @param inputVector the input; its length must equal {@code c.getNumInputs()}
     * @param activeArray output; cleared to 0, then set to 1 at every active column index
     * @param learn       whether to apply learning updates
     * @throws InvalidSPParamValueException if {@code inputVector.length != c.getNumInputs()}
     */
    public void compute(Connections c, int[] inputVector, int[] activeArray, boolean learn) {
        if (inputVector.length != c.getNumInputs()) {
            throw new InvalidSPParamValueException("Input array must be same size as the defined number of inputs: From Params: " + c.getNumInputs() + ", From Input Vector: " + inputVector.length);
        }
        updateBookeepingVars(c, learn);
        int[] overlaps = c.setOverlaps(calculateOverlap(c, inputVector));
        double[] boostedOverlaps = learn
            ? ArrayUtils.multiply(c.getBoostFactors(), overlaps)
            : ArrayUtils.toDoubleArray(overlaps);
        int[] activeColumns = inhibitColumns(c, c.setBoostedOverlaps(boostedOverlaps));
        if (learn) {
            adaptSynapses(c, inputVector, activeColumns);
            updateDutyCycles(c, overlaps, activeColumns);
            bumpUpWeakColumns(c);
            updateBoostFactors(c);
            if (isUpdateRound(c)) {
                updateInhibitionRadius(c);
                updateMinDutyCycles(c);
            }
        }
        Arrays.fill(activeArray, 0);
        if (activeColumns.length > 0) {
            ArrayUtils.setIndexesTo(activeArray, activeColumns, 1);
        }
    }

    /**
     * Drops active columns whose active duty cycle is not positive.
     *
     * @param c             the connections object supplying active duty cycles
     * @param activeColumns indices of active columns
     * @return the entries of {@code activeColumns} whose active duty cycle is greater than 0,
     *         in their original order
     */
    public int[] stripUnlearnedColumns(Connections c, int[] activeColumns) {
        double[] activeDutyCycles = c.getActiveDutyCycles();
        return Arrays.stream(activeColumns).filter(i -> activeDutyCycles[i] > 0.0).toArray();
    }

    /**
     * Recomputes the minimum overlap and active duty cycles.
     *
     * <p>The global variant is used when global inhibition is enabled or the inhibition radius
     * exceeds the number of inputs; otherwise the per-neighborhood variant is used.
     *
     * @param c the connections object to update
     */
    public void updateMinDutyCycles(Connections c) {
        if (c.getGlobalInhibition() || c.getInhibitionRadius() > c.getNumInputs()) {
            updateMinDutyCyclesGlobal(c);
        } else {
            updateMinDutyCyclesLocal(c);
        }
    }

    /**
     * Sets every column's minimum duty cycles to the configured percentage of the largest duty
     * cycle across all columns.
     *
     * @param c the connections object to update
     */
    public void updateMinDutyCyclesGlobal(Connections c) {
        Arrays.fill(c.getMinOverlapDutyCycles(), c.getMinPctOverlapDutyCycles() * ArrayUtils.max(c.getOverlapDutyCycles()));
        Arrays.fill(c.getMinActiveDutyCycles(), c.getMinPctActiveDutyCycles() * ArrayUtils.max(c.getActiveDutyCycles()));
    }

    /**
     * Sets each column's minimum duty cycles to the configured percentage of the largest duty
     * cycle inside that column's neighborhood, using the current inhibition radius.
     *
     * @param c the connections object to update
     */
    public void updateMinDutyCyclesLocal(Connections c) {
        int numColumns = c.getNumColumns();
        int inhibitionRadius = c.getInhibitionRadius();
        double[] activeDutyCycles = c.getActiveDutyCycles();
        double[] overlapDutyCycles = c.getOverlapDutyCycles();
        double[] minActiveDutyCycles = c.getMinActiveDutyCycles();
        double[] minOverlapDutyCycles = c.getMinOverlapDutyCycles();
        double minPctActiveDutyCycles = c.getMinPctActiveDutyCycles();
        double minPctOverlapDutyCycles = c.getMinPctOverlapDutyCycles();
        for (int i = 0; i < numColumns; ++i) {
            int[] neighborhood = getColumnNeighborhood(c, i, inhibitionRadius);
            double maxActiveDuty = ArrayUtils.max(ArrayUtils.sub(activeDutyCycles, neighborhood));
            double maxOverlapDuty = ArrayUtils.max(ArrayUtils.sub(overlapDutyCycles, neighborhood));
            minActiveDutyCycles[i] = maxActiveDuty * minPctActiveDutyCycles;
            minOverlapDutyCycles[i] = maxOverlapDuty * minPctOverlapDutyCycles;
        }
    }

    /**
     * Folds this step's overlap and activity indicators into the running duty cycles.
     *
     * @param c             the connections object to update
     * @param overlaps      per-column overlap counts for this step
     * @param activeColumns indices of the columns active in this step
     */
    public void updateDutyCycles(Connections c, int[] overlaps, int[] activeColumns) {
        double[] overlapArray = new double[c.getNumColumns()];
        double[] activeArray = new double[c.getNumColumns()];
        // Presumably sets overlapArray[i] = 1.0 wherever overlaps[i] > 0.
        ArrayUtils.greaterThanXThanSetToYInB(overlaps, overlapArray, 0, 1.0);
        if (activeColumns.length > 0) {
            ArrayUtils.setIndexesTo(activeArray, activeColumns, 1.0);
        }
        // Until dutyCyclePeriod iterations have elapsed, average over the iterations seen so far.
        int period = Math.min(c.getDutyCyclePeriod(), c.getIterationNum());
        c.setOverlapDutyCycles(updateDutyCyclesHelper(c, c.getOverlapDutyCycles(), overlapArray, period));
        c.setActiveDutyCycles(updateDutyCyclesHelper(c, c.getActiveDutyCycles(), activeArray, period));
    }

    /**
     * Computes the moving average {@code (dutyCycles * (period - 1) + newInput) / period}.
     *
     * @param c          unused; kept for signature compatibility
     * @param dutyCycles current duty cycles
     * @param newInput   this step's per-column values
     * @param period     averaging period
     * @return the updated duty cycles as a new array
     */
    public double[] updateDutyCyclesHelper(Connections c, double[] dutyCycles, double[] newInput, double period) {
        return ArrayUtils.divide(ArrayUtils.d_add(ArrayUtils.multiply(dutyCycles, period - 1.0), newInput), period);
    }

    /**
     * Recomputes the inhibition radius.
     *
     * <p>With global inhibition the radius is the largest column dimension. Otherwise it is
     * derived from the mean connected span over all columns multiplied by
     * {@link #avgColumnsPerInput}: {@code max(1, (diameter - 1) / 2)}, rounded half-up.
     *
     * @param c the connections object to update
     */
    public void updateInhibitionRadius(Connections c) {
        if (c.getGlobalInhibition()) {
            c.setInhibitionRadius(ArrayUtils.max(c.getColumnDimensions()));
            return;
        }
        int numColumns = c.getNumColumns();
        double[] connectedSpans = new double[numColumns];
        for (int i = 0; i < numColumns; ++i) {
            connectedSpans[i] = avgConnectedSpanForColumnND(c, i);
        }
        double avgConnectedSpan = ArrayUtils.average(connectedSpans);
        double diameter = avgConnectedSpan * avgColumnsPerInput(c);
        double radius = Math.max(1.0, (diameter - 1.0) / 2.0);
        c.setInhibitionRadius((int) (radius + 0.5));
    }

    /**
     * Returns the mean, over dimensions, of {@code columnDimensions[i] / inputDimensions[i]}.
     *
     * @param c the connections object supplying both dimension arrays
     * @return the average number of columns per input along each dimension
     */
    public double avgColumnsPerInput(Connections c) {
        double[] columnDims = ArrayUtils.toDoubleArray(c.getColumnDimensions());
        double[] inputDims = ArrayUtils.toDoubleArray(c.getInputDimensions());
        // The trailing 0.0 arguments presumably control ArrayUtils.divide's zero-division handling.
        double[] columnsPerInput = ArrayUtils.divide(columnDims, inputDims, 0.0, 0.0);
        return ArrayUtils.average(columnsPerInput);
    }

    /**
     * Returns the mean extent of a column's connected inputs across the input dimensions.
     *
     * <p>For each input dimension the extent is {@code max - min + 1} over the coordinates of
     * the column's connected synapses.
     *
     * @param c           the connections object
     * @param columnIndex the column to measure
     * @return the average extent, or 0.0 when the column has no connected synapses
     */
    public double avgConnectedSpanForColumnND(Connections c, int columnIndex) {
        int[] dimensions = c.getInputDimensions();
        int[] connected = c.getColumn(columnIndex).getProximalDendrite().getConnectedSynapsesSparse(c);
        if (connected == null || connected.length == 0) {
            return 0.0;
        }
        int[] maxCoord = new int[dimensions.length];
        int[] minCoord = new int[dimensions.length];
        Arrays.fill(maxCoord, -1);
        Arrays.fill(minCoord, ArrayUtils.max(dimensions));
        SparseMatrix<?> inputMatrix = c.getInputMatrix();
        for (int inputIndex : connected) {
            int[] coords = inputMatrix.computeCoordinates(inputIndex);
            maxCoord = ArrayUtils.maxBetween(maxCoord, coords);
            minCoord = ArrayUtils.minBetween(minCoord, coords);
        }
        return ArrayUtils.average(ArrayUtils.add(ArrayUtils.subtract(maxCoord, minCoord), 1));
    }

    /**
     * Adjusts the permanences of every active column toward the current input.
     *
     * <p>The change vector is {@code +synPermActiveInc} at inputs greater than 0 and
     * {@code -synPermInactiveDec} elsewhere. It is added to each active column's dense
     * permanences, which are then passed through {@link #updatePermanencesForColumn}.
     *
     * @param c             the connections object
     * @param inputVector   the current input
     * @param activeColumns indices of the active columns
     */
    public void adaptSynapses(Connections c, int[] inputVector, int[] activeColumns) {
        int[] inputIndices = ArrayUtils.where(inputVector, ArrayUtils.INT_GREATER_THAN_0);
        double[] permChanges = new double[c.getNumInputs()];
        Arrays.fill(permChanges, -1.0 * c.getSynPermInactiveDec());
        ArrayUtils.setIndexesTo(permChanges, inputIndices, c.getSynPermActiveInc());
        for (int activeColumn : activeColumns) {
            Pool pool = c.getPotentialPools().get(activeColumn);
            double[] perm = pool.getDensePermanences(c);
            int[] indexes = pool.getSparsePotential();
            // NOTE(review): this adds permChanges across all inputs, not just the potential pool;
            // presumably setProximalPermanences restricts storage to the pool — confirm.
            ArrayUtils.raiseValuesBy(permChanges, perm);
            Column col = c.getColumn(activeColumn);
            updatePermanencesForColumn(c, perm, col, indexes, true);
        }
    }

    /**
     * Raises all potential permanences of weak columns by {@code synPermBelowStimulusInc}.
     *
     * <p>A column is weak when its overlap duty cycle is below its minimum overlap duty cycle.
     *
     * @param c the connections object
     */
    public void bumpUpWeakColumns(final Connections c) {
        int[] weakColumns = ArrayUtils.where(c.getMemory().get1DIndexes(), new Condition.Adapter<Integer>() {

            @Override
            public boolean eval(int i) {
                return c.getOverlapDutyCycles()[i] < c.getMinOverlapDutyCycles()[i];
            }
        });
        for (int weakColumn : weakColumns) {
            Pool pool = c.getPotentialPools().get(weakColumn);
            double[] perm = pool.getSparsePermanences();
            ArrayUtils.raiseValuesBy(c.getSynPermBelowStimulusInc(), perm);
            int[] indexes = pool.getSparsePotential();
            Column col = c.getColumn(weakColumn);
            updatePermanencesForColumnSparse(c, perm, col, indexes, true);
        }
    }

    /**
     * Raises masked permanences until at least {@code stimulusThreshold} of them are connected.
     *
     * <p>{@code perm} is clipped to [synPermMin, synPermMax] first. Then every permanence at
     * {@code maskPotential} is raised by {@code synPermBelowStimulusInc} until enough of them
     * exceed {@code synPermConnected}. The raises are not re-clipped here.
     *
     * <p>NOTE(review): this presumably never terminates if synPermBelowStimulusInc is not positive.
     *
     * @param c             the connections object supplying the thresholds
     * @param perm          dense permanences, modified in place
     * @param maskPotential indices eligible for raising and counting
     * @throws IllegalStateException if the mask has fewer entries than stimulusThreshold
     */
    public void raisePermanenceToThreshold(Connections c, double[] perm, int[] maskPotential) {
        if ((double) maskPotential.length < c.getStimulusThreshold()) {
            throw new IllegalStateException("This is likely due to a value of stimulusThreshold that is too large relative to the input size. [len(mask) < self._stimulusThreshold]");
        }
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        while (true) {
            int numConnected = ArrayUtils.valueGreaterCountAtIndex(c.getSynPermConnected(), perm, maskPotential);
            if (numConnected >= c.getStimulusThreshold()) {
                return;
            }
            ArrayUtils.raiseValuesBy(c.getSynPermBelowStimulusInc(), perm, maskPotential);
        }
    }

    /**
     * Sparse variant of {@link #raisePermanenceToThreshold}: raises every entry of {@code perm}.
     *
     * @param c    the connections object supplying the thresholds
     * @param perm sparse permanences, modified in place
     * @throws IllegalStateException if {@code perm} has fewer entries than stimulusThreshold,
     *         which would otherwise make the loop below spin forever
     */
    public void raisePermanenceToThresholdSparse(Connections c, double[] perm) {
        if ((double) perm.length < c.getStimulusThreshold()) {
            throw new IllegalStateException("Cannot reach stimulusThreshold " + c.getStimulusThreshold()
                + " with only " + perm.length + " permanence values");
        }
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        while (true) {
            int numConnected = ArrayUtils.valueGreaterCount(c.getSynPermConnected(), perm);
            if (numConnected >= c.getStimulusThreshold()) {
                return;
            }
            ArrayUtils.raiseValuesBy(c.getSynPermBelowStimulusInc(), perm);
        }
    }

    /**
     * Finalizes dense permanences and stores them on {@code column}.
     *
     * <p>The permanences are optionally raised to threshold first. Entries at or below
     * synPermTrimThreshold are then zeroed, and the rest are clipped to [synPermMin, synPermMax].
     *
     * @param c             the connections object
     * @param perm          dense permanences, modified in place
     * @param column        the column receiving the permanences
     * @param maskPotential the column's potential input indices
     * @param raisePerm     whether to call {@link #raisePermanenceToThreshold} first
     */
    public void updatePermanencesForColumn(Connections c, double[] perm, Column column, int[] maskPotential, boolean raisePerm) {
        if (raisePerm) {
            raisePermanenceToThreshold(c, perm, maskPotential);
        }
        ArrayUtils.lessThanOrEqualXThanSetToY(perm, c.getSynPermTrimThreshold(), 0.0);
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        column.setProximalPermanences(c, perm);
    }

    /**
     * Sparse variant of {@link #updatePermanencesForColumn}.
     *
     * @param c             the connections object
     * @param perm          sparse permanences aligned with {@code maskPotential}, modified in place
     * @param column        the column receiving the permanences
     * @param maskPotential the input indices the entries of {@code perm} belong to
     * @param raisePerm     whether to call {@link #raisePermanenceToThresholdSparse} first
     */
    public void updatePermanencesForColumnSparse(Connections c, double[] perm, Column column, int[] maskPotential, boolean raisePerm) {
        if (raisePerm) {
            raisePermanenceToThresholdSparse(c, perm);
        }
        ArrayUtils.lessThanOrEqualXThanSetToY(perm, c.getSynPermTrimThreshold(), 0.0);
        ArrayUtils.clip(perm, c.getSynPermMin(), c.getSynPermMax());
        column.setProximalPermanencesSparse(c, perm, maskPotential);
    }

    /**
     * Draws a random permanence between synPermConnected and synPermMax, truncated to five
     * decimal places.
     *
     * <p>NOTE(review): this reads the {@code c.random} field, while
     * {@link #initPermNonConnected} calls {@code c.getRandom()}. Presumably both are the same
     * generator — confirm.
     *
     * @param c the connections object supplying bounds and randomness
     * @return the drawn permanence
     */
    public static double initPermConnected(Connections c) {
        double p = c.getSynPermConnected() + (c.getSynPermMax() - c.getSynPermConnected()) * c.random.nextDouble();
        return (double) ((int) (p * 100000.0)) / 100000.0;
    }

    /**
     * Draws a random permanence between 0 and synPermConnected, truncated to five decimal places.
     *
     * @param c the connections object supplying bounds and randomness
     * @return the drawn permanence
     */
    public static double initPermNonConnected(Connections c) {
        double p = c.getSynPermConnected() * c.getRandom().nextDouble();
        return (double) ((int) (p * 100000.0)) / 100000.0;
    }

    /**
     * Creates initial dense permanences for one column's potential pool.
     *
     * <p>Each potential input gets a connected permanence when a random draw is at most
     * {@code connectedPct}, otherwise a non-connected one. Values below synPermTrimThreshold
     * become 0.
     *
     * @param c             the connections object
     * @param potentialPool the column's potential input indices
     * @param index         the column index; its proximal permanences are set to the result
     * @param connectedPct  threshold for the random draw that picks a connected permanence
     * @return dense permanences of length {@code c.getNumInputs()}
     */
    public double[] initPermanence(Connections c, int[] potentialPool, int index, double connectedPct) {
        double[] perm = new double[c.getNumInputs()];
        for (int idx : potentialPool) {
            double p = c.random.nextDouble() <= connectedPct ? initPermConnected(c) : initPermNonConnected(c);
            perm[idx] = p < c.getSynPermTrimThreshold() ? 0.0 : p;
        }
        c.getColumn(index).setProximalPermanences(c, perm);
        return perm;
    }

    /**
     * Maps a column index to the index of its center input.
     *
     * <p>The column's coordinates are scaled by {@code inputDims / columnDims}. Half of that ratio
     * is added as an offset, and the result is clipped to the input dimensions.
     *
     * @param c           the connections object
     * @param columnIndex the column to map
     * @return the flat index of the center input
     */
    public int mapColumn(Connections c, int columnIndex) {
        double[] colCoords = ArrayUtils.toDoubleArray(c.getMemory().computeCoordinates(columnIndex));
        double[] columnDims = ArrayUtils.toDoubleArray(c.getColumnDimensions());
        double[] inputDims = ArrayUtils.toDoubleArray(c.getInputDimensions());
        double[] ratios = ArrayUtils.divide(colCoords, columnDims, 0.0, 0.0);
        double[] inputCoords = ArrayUtils.multiply(inputDims, ratios, 0.0, 0.0);
        double[] halfStep = ArrayUtils.multiply(ArrayUtils.divide(inputDims, columnDims, 0.0, 0.0), 0.5);
        inputCoords = ArrayUtils.d_add(inputCoords, halfStep);
        int[] inputCoordInts = ArrayUtils.clip(ArrayUtils.toIntArray(inputCoords), c.getInputDimensions(), -1);
        return c.getInputMatrix().computeIndex(inputCoordInts);
    }

    /**
     * Selects a random subset of the inputs around a column's center as its potential pool.
     *
     * <p>The candidates are the inputs within potentialRadius of {@link #mapColumn}'s center. The
     * sample size is {@code round(potentialPct * candidates)}.
     *
     * @param c           the connections object
     * @param columnIndex the column whose pool is built
     * @param wrapAround  whether the input neighborhood wraps at the input edges
     * @return the sampled input indices
     */
    public int[] mapPotential(Connections c, int columnIndex, boolean wrapAround) {
        int centerInput = mapColumn(c, columnIndex);
        // Honour the wrapAround argument rather than re-reading c.isWrapAround().
        int[] columnInputs = wrapAround
            ? c.getInputTopology().wrappingNeighborhood(centerInput, c.getPotentialRadius())
            : c.getInputTopology().neighborhood(centerInput, c.getPotentialRadius());
        int numPotential = (int) ((double) columnInputs.length * c.getPotentialPct() + 0.5);
        return ArrayUtils.sample(columnInputs, new int[numPotential], c.getRandom());
    }

    /**
     * Selects the active columns from the (boosted) overlaps.
     *
     * <p>When localAreaDensity is not positive, the density is derived as
     * {@code numActiveColumnsPerInhArea / min(numColumns, (2 * radius + 1)^dims)} and capped at 0.5.
     * Global inhibition is used when it is enabled or the radius exceeds every column dimension.
     * Otherwise local inhibition is used.
     *
     * @param c        the connections object
     * @param overlaps per-column overlaps; not modified
     * @return indices of the winning columns
     */
    public int[] inhibitColumns(Connections c, double[] overlaps) {
        overlaps = Arrays.copyOf(overlaps, overlaps.length);
        double density = c.getLocalAreaDensity();
        if (density <= 0.0) {
            double inhibitionArea = Math.pow(2 * c.getInhibitionRadius() + 1, c.getColumnDimensions().length);
            inhibitionArea = Math.min((double) c.getNumColumns(), inhibitionArea);
            density = c.getNumActiveColumnsPerInhArea() / inhibitionArea;
            density = Math.min(density, 0.5);
        }
        if (c.getGlobalInhibition() || c.getInhibitionRadius() > ArrayUtils.max(c.getColumnDimensions())) {
            return inhibitColumnsGlobal(c, overlaps, density);
        }
        return inhibitColumnsLocal(c, overlaps, density);
    }

    /**
     * Picks the top {@code density * numColumns} columns whose overlap meets stimulusThreshold.
     *
     * @param c        the connections object supplying the inhibition comparator
     * @param overlaps per-column overlaps
     * @param density  fraction of columns to activate
     * @return indices of the winning columns, in comparator order
     */
    public int[] inhibitColumnsGlobal(Connections c, double[] overlaps, double density) {
        int numCols = c.getNumColumns();
        int numActive = (int) (density * (double) numCols);
        // Winners are taken from the tail of the sorted order, so c.inhibitionComparator is
        // expected to order by ascending overlap.
        int[] sortedWinnerIndices = IntStream.range(0, overlaps.length)
            .mapToObj(i -> new Pair<Integer, Double>(i, overlaps[i]))
            .sorted(c.inhibitionComparator)
            .mapToInt(Pair<Integer, Double>::getFirst)
            .toArray();
        double stimulusThreshold = c.getStimulusThreshold();
        // Clamp so a numActive larger than the array cannot yield a negative start index.
        int start = Math.max(0, sortedWinnerIndices.length - numActive);
        while (start < sortedWinnerIndices.length) {
            if (overlaps[sortedWinnerIndices[start]] >= stimulusThreshold) {
                break;
            }
            ++start;
        }
        return Arrays.copyOfRange(sortedWinnerIndices, start, sortedWinnerIndices.length);
    }

    /**
     * Picks winners per column neighborhood.
     *
     * <p>A column whose overlap meets stimulusThreshold wins when fewer than
     * {@code round(density * neighborhoodSize)} neighbors have a strictly larger tie-broken overlap.
     * Each winner's tie-broken overlap is then bumped by a small amount. That bump makes later
     * columns with an equal overlap see the winner as bigger.
     *
     * @param c        the connections object
     * @param overlaps per-column overlaps
     * @param density  fraction of each neighborhood to activate
     * @return indices of the winning columns, ascending
     */
    public int[] inhibitColumnsLocal(Connections c, double[] overlaps, double density) {
        double addToWinners = ArrayUtils.max(overlaps) / 1000.0;
        if (addToWinners == 0.0) {
            addToWinners = 0.001;
        }
        double[] tieBrokenOverlaps = Arrays.copyOf(overlaps, overlaps.length);
        TIntArrayList winners = new TIntArrayList();
        double stimulusThreshold = c.getStimulusThreshold();
        int inhibitionRadius = c.getInhibitionRadius();
        for (int column = 0; column < overlaps.length; ++column) {
            double overlap = overlaps[column];
            if (overlap >= stimulusThreshold) {
                int[] neighborhood = getColumnNeighborhood(c, column, inhibitionRadius);
                int numBigger = 0;
                for (double neighborOverlap : ArrayUtils.sub(tieBrokenOverlaps, neighborhood)) {
                    if (neighborOverlap > overlap) {
                        ++numBigger;
                    }
                }
                int numActive = (int) (0.5 + density * (double) neighborhood.length);
                if (numBigger < numActive) {
                    winners.add(column);
                    tieBrokenOverlaps[column] += addToWinners;
                }
            }
        }
        return winners.toArray();
    }

    /**
     * Recomputes boost factors from the active and minimum active duty cycles.
     *
     * <p>If any minimum active duty cycle is positive, the factors become
     * {@code (1 - maxBoost) / minActive * active + maxBoost}. Otherwise the current factors array
     * is reused. Either way, columns whose active duty cycle exceeds their minimum are set to 1.0.
     *
     * @param c the connections object to update
     */
    public void updateBoostFactors(Connections c) {
        double[] activeDutyCycles = c.getActiveDutyCycles();
        double[] minActiveDutyCycles = c.getMinActiveDutyCycles();
        int[] positiveMinimums = ArrayUtils.where(minActiveDutyCycles, ArrayUtils.GREATER_THAN_0);
        double[] boostInterim;
        if (positiveMinimums.length < 1) {
            // Reuses (and below, modifies in place) the existing boost factor array.
            boostInterim = c.getBoostFactors();
        } else {
            double[] numerator = new double[c.getNumColumns()];
            Arrays.fill(numerator, 1.0 - c.getMaxBoost());
            boostInterim = ArrayUtils.divide(numerator, minActiveDutyCycles, 0.0, 0.0);
            boostInterim = ArrayUtils.multiply(boostInterim, activeDutyCycles, 0.0, 0.0);
            boostInterim = ArrayUtils.d_add(boostInterim, c.getMaxBoost());
        }
        int[] aboveMinimum = IntStream.range(0, activeDutyCycles.length)
            .filter(i -> activeDutyCycles[i] > minActiveDutyCycles[i])
            .toArray();
        ArrayUtils.setIndexesTo(boostInterim, aboveMinimum, 1.0);
        c.setBoostFactors(boostInterim);
    }

    /**
     * Computes per-column overlap with the input via the connected-counts matrix.
     *
     * @param c           the connections object
     * @param inputVector the current input
     * @return per-column overlaps (length {@code c.getNumColumns()})
     */
    public int[] calculateOverlap(Connections c, int[] inputVector) {
        int[] overlaps = new int[c.getNumColumns()];
        c.getConnectedCounts().rightVecSumAtNZ(inputVector, overlaps, c.getStimulusThreshold());
        return overlaps;
    }

    /**
     * Divides each column's overlap by its number of connected synapses.
     *
     * @param c        the connections object
     * @param overlaps per-column overlaps
     * @return per-column overlap fractions
     */
    public double[] calculateOverlapPct(Connections c, int[] overlaps) {
        return ArrayUtils.divide(overlaps, c.getConnectedCounts().getTrueCounts());
    }

    /**
     * Reports whether the current iteration is a multiple of the update period.
     *
     * @param c the connections object
     * @return {@code true} on update rounds
     */
    public boolean isUpdateRound(Connections c) {
        return c.getIterationNum() % c.getUpdatePeriod() == 0;
    }

    /**
     * Increments the iteration counter, and the learning counter when {@code learn} is true.
     *
     * @param c     the connections object
     * @param learn whether this step learns
     */
    public void updateBookeepingVars(Connections c, boolean learn) {
        ++c.spIterationNum;
        if (learn) {
            ++c.spIterationLearnNum;
        }
    }

    /**
     * Returns the column indices within {@code inhibitionRadius} of {@code centerColumn}, wrapping
     * at the edges when {@code c.isWrapAround()} is true.
     *
     * @param c                the connections object
     * @param centerColumn     the center column index
     * @param inhibitionRadius the neighborhood radius
     * @return the neighborhood's column indices
     */
    public int[] getColumnNeighborhood(Connections c, int centerColumn, int inhibitionRadius) {
        return c.isWrapAround()
            ? c.getColumnTopology().wrappingNeighborhood(centerColumn, inhibitionRadius)
            : c.getColumnTopology().neighborhood(centerColumn, inhibitionRadius);
    }

    /**
     * Returns the input indices within {@code potentialRadius} of {@code centerInput}, wrapping
     * at the edges when {@code c.isWrapAround()} is true.
     *
     * @param c               the connections object
     * @param centerInput     the center input index
     * @param potentialRadius the neighborhood radius
     * @return the neighborhood's input indices
     */
    public int[] getInputNeighborhood(Connections c, int centerInput, int potentialRadius) {
        return c.isWrapAround()
            ? c.getInputTopology().wrappingNeighborhood(centerInput, potentialRadius)
            : c.getInputTopology().neighborhood(centerInput, potentialRadius);
    }

    /**
     * Thrown when spatial pooler parameters or inputs are invalid.
     *
     * <p>Declared {@code static} so instances do not hold a hidden reference to the pooler.
     */
    static class InvalidSPParamValueException
    extends RuntimeException {
        private static final long serialVersionUID = 1L;

        public InvalidSPParamValueException(String message) {
            super(message);
        }
    }
}

