/*
 * Decompiled with CFR 0.152.
 */
package org.numenta.nupic.algorithms;

import chaschev.lang.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.ComputeCycle;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.DistalDendrite;
import org.numenta.nupic.model.Synapse;
import org.numenta.nupic.monitor.ComputeDecorator;
import org.numenta.nupic.util.GroupBy2;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;

public class TemporalMemory
implements ComputeDecorator,
Serializable {
    private static final long serialVersionUID = 1L;
    private static final double EPSILON = 1.0E-5;
    private static final int ACTIVE_COLUMNS = 1;

    public static void init(Connections c) {
        SparseObjectMatrix<Column> matrix = c.getMemory() == null ? new SparseObjectMatrix<Column>(c.getColumnDimensions()) : c.getMemory();
        c.setMemory(matrix);
        int numColumns = matrix.getMaxIndex() + 1;
        c.setNumColumns(numColumns);
        int cellsPerColumn = c.getCellsPerColumn();
        Cell[] cells = new Cell[numColumns * cellsPerColumn];
        Column colZero = matrix.getObject(0);
        for (int i = 0; i < numColumns; ++i) {
            Column column = colZero == null ? new Column(cellsPerColumn, i) : matrix.getObject(i);
            for (int j = 0; j < cellsPerColumn; ++j) {
                cells[i * cellsPerColumn + j] = column.getCell(j);
            }
            if (colZero != null) continue;
            matrix.set(i, (Object)column);
        }
        c.setCells(cells);
    }

    @Override
    public ComputeCycle compute(Connections connections, int[] activeColumns, boolean learn) {
        ComputeCycle cycle = new ComputeCycle();
        this.activateCells(connections, cycle, activeColumns, learn);
        this.activateDendrites(connections, cycle, learn);
        return cycle;
    }

    public void activateCells(Connections conn, ComputeCycle cycle, int[] activeColumnIndices, boolean learn) {
        ColumnData columnData = new ColumnData();
        Set<Cell> prevActiveCells = conn.getActiveCells();
        Set<Cell> prevWinnerCells = conn.getWinnerCells();
        List activeColumns = Arrays.stream(activeColumnIndices).sorted().mapToObj(i -> conn.getColumn(i)).collect(Collectors.toList());
        Function identity = Function.identity();
        Function<DistalDendrite, Column> segToCol = segment -> segment.getParentCell().getColumn();
        GroupBy2 grouper = GroupBy2.of(new Pair(activeColumns, identity), new Pair(new ArrayList<DistalDendrite>(conn.getActiveSegments()), segToCol), new Pair(new ArrayList<DistalDendrite>(conn.getMatchingSegments()), segToCol));
        double permanenceIncrement = conn.getPermanenceIncrement();
        double permanenceDecrement = conn.getPermanenceDecrement();
        for (Tuple t : grouper) {
            if ((columnData = columnData.set(t)).isNotNone(1)) {
                if (!columnData.activeSegments().isEmpty()) {
                    List<Cell> cellsToAdd = this.activatePredictedColumn(conn, columnData.activeSegments(), columnData.matchingSegments(), prevActiveCells, prevWinnerCells, permanenceIncrement, permanenceDecrement, learn);
                    cycle.activeCells.addAll(cellsToAdd);
                    cycle.winnerCells.addAll(cellsToAdd);
                    continue;
                }
                Tuple cellsXwinnerCell = this.burstColumn(conn, columnData.column(), columnData.matchingSegments(), prevActiveCells, prevWinnerCells, permanenceIncrement, permanenceDecrement, conn.getRandom(), learn);
                cycle.activeCells.addAll((List)cellsXwinnerCell.get(0));
                cycle.winnerCells.add((Cell)cellsXwinnerCell.get(1));
                continue;
            }
            if (!learn) continue;
            this.punishPredictedColumn(conn, columnData.activeSegments(), columnData.matchingSegments(), prevActiveCells, prevWinnerCells, conn.getPredictedSegmentDecrement());
        }
    }

    public void activateDendrites(Connections conn, ComputeCycle cycle, boolean learn) {
        Connections.Activity activity = conn.computeActivity(cycle.activeCells, conn.getConnectedPermanence());
        List<DistalDendrite> activeSegments = IntStream.range(0, activity.numActiveConnected.length).filter(i -> activity.numActiveConnected[i] >= conn.getActivationThreshold()).mapToObj(i -> conn.segmentForFlatIdx(i)).collect(Collectors.toList());
        List<DistalDendrite> matchingSegments = IntStream.range(0, activity.numActiveConnected.length).filter(i -> activity.numActivePotential[i] >= conn.getMinThreshold()).mapToObj(i -> conn.segmentForFlatIdx(i)).collect(Collectors.toList());
        Collections.sort(activeSegments, conn.segmentPositionSortKey);
        Collections.sort(matchingSegments, conn.segmentPositionSortKey);
        cycle.activeSegments = activeSegments;
        cycle.matchingSegments = matchingSegments;
        conn.lastActivity = activity;
        conn.setActiveCells(new LinkedHashSet<Cell>(cycle.activeCells));
        conn.setWinnerCells(new LinkedHashSet<Cell>(cycle.winnerCells));
        conn.setActiveSegments(activeSegments);
        conn.setMatchingSegments(matchingSegments);
        conn.clearPredictiveCells();
        conn.getPredictiveCells();
        if (learn) {
            activeSegments.stream().forEach(s -> conn.recordSegmentActivity((DistalDendrite)s));
            conn.startNewIteration();
        }
    }

    @Override
    public void reset(Connections connections) {
        connections.getActiveCells().clear();
        connections.getWinnerCells().clear();
        connections.getActiveSegments().clear();
        connections.getMatchingSegments().clear();
    }

    public List<Cell> activatePredictedColumn(Connections conn, List<DistalDendrite> activeSegments, List<DistalDendrite> matchingSegments, Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells, double permanenceIncrement, double permanenceDecrement, boolean learn) {
        ArrayList<Cell> cellsToAdd = new ArrayList<Cell>();
        Cell previousCell = null;
        for (DistalDendrite segment : activeSegments) {
            Cell currCell = segment.getParentCell();
            if (currCell != previousCell) {
                cellsToAdd.add(currCell);
                previousCell = currCell;
            }
            if (!learn) continue;
            this.adaptSegment(conn, segment, prevActiveCells, permanenceIncrement, permanenceDecrement);
            int numActive = conn.getLastActivity().numActivePotential[segment.getIndex()];
            int nGrowDesired = conn.getMaxNewSynapseCount() - numActive;
            if (nGrowDesired <= 0) continue;
            this.growSynapses(conn, prevWinnerCells, segment, conn.getInitialPermanence(), nGrowDesired, conn.getRandom());
        }
        return cellsToAdd;
    }

    public Tuple burstColumn(Connections conn, Column column, List<DistalDendrite> matchingSegments, Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells, double permanenceIncrement, double permanenceDecrement, Random random, boolean learn) {
        List<Cell> cells = column.getCells();
        Cell bestCell = null;
        if (!matchingSegments.isEmpty()) {
            int[] numPoten = conn.getLastActivity().numActivePotential;
            Comparator cmp = (dd1, dd2) -> numPoten[dd1.getIndex()] - numPoten[dd2.getIndex()];
            DistalDendrite bestSegment = (DistalDendrite)matchingSegments.stream().max(cmp).get();
            bestCell = bestSegment.getParentCell();
            if (learn) {
                this.adaptSegment(conn, bestSegment, prevActiveCells, permanenceIncrement, permanenceDecrement);
                int nGrowDesired = conn.getMaxNewSynapseCount() - numPoten[bestSegment.getIndex()];
                if (nGrowDesired > 0) {
                    this.growSynapses(conn, prevWinnerCells, bestSegment, conn.getInitialPermanence(), nGrowDesired, random);
                }
            }
        } else {
            int nGrowExact;
            bestCell = this.leastUsedCell(conn, cells, random);
            if (learn && (nGrowExact = Math.min(conn.getMaxNewSynapseCount(), prevWinnerCells.size())) > 0) {
                DistalDendrite bestSegment = conn.createSegment(bestCell);
                this.growSynapses(conn, prevWinnerCells, bestSegment, conn.getInitialPermanence(), nGrowExact, random);
            }
        }
        return new Tuple(cells, bestCell);
    }

    public void punishPredictedColumn(Connections conn, List<DistalDendrite> activeSegments, List<DistalDendrite> matchingSegments, Set<Cell> prevActiveCells, Set<Cell> prevWinnerCells, double predictedSegmentDecrement) {
        if (predictedSegmentDecrement > 0.0) {
            for (DistalDendrite segment : matchingSegments) {
                this.adaptSegment(conn, segment, prevActiveCells, -conn.getPredictedSegmentDecrement(), 0.0);
            }
        }
    }

    public Cell leastUsedCell(Connections conn, List<Cell> cells, Random random) {
        ArrayList<Cell> leastUsedCells = new ArrayList<Cell>();
        int minNumSegments = Integer.MAX_VALUE;
        for (Cell cell : cells) {
            int numSegments = conn.numSegments(cell);
            if (numSegments < minNumSegments) {
                minNumSegments = numSegments;
                leastUsedCells.clear();
            }
            if (numSegments != minNumSegments) continue;
            leastUsedCells.add(cell);
        }
        int i = random.nextInt(leastUsedCells.size());
        return (Cell)leastUsedCells.get(i);
    }

    public void growSynapses(Connections conn, Set<Cell> prevWinnerCells, DistalDendrite segment, double initialPermanence, int nDesiredNewSynapses, Random random) {
        ArrayList<Cell> candidates = new ArrayList<Cell>(prevWinnerCells);
        Collections.sort(candidates);
        for (Synapse synapse : conn.getSynapses(segment)) {
            Cell presynapticCell = synapse.getPresynapticCell();
            int index = candidates.indexOf(presynapticCell);
            if (index == -1) continue;
            candidates.remove(index);
        }
        int candidatesLength = candidates.size();
        int nActual = nDesiredNewSynapses < candidatesLength ? nDesiredNewSynapses : candidatesLength;
        for (int i = 0; i < nActual; ++i) {
            int rand = random.nextInt(candidates.size());
            conn.createSynapse(segment, (Cell)candidates.get(rand), initialPermanence);
            candidates.remove(rand);
        }
    }

    public void adaptSegment(Connections conn, DistalDendrite segment, Set<Cell> prevActiveCells, double permanenceIncrement, double permanenceDecrement) {
        ArrayList<Synapse> synapsesToDestroy = new ArrayList<Synapse>();
        for (Synapse synapse : conn.getSynapses(segment)) {
            double permanence = synapse.getPermanence();
            permanence = prevActiveCells.contains(synapse.getPresynapticCell()) ? (permanence += permanenceIncrement) : (permanence -= permanenceDecrement);
            double d = permanence < 0.0 ? 0.0 : (permanence = permanence > 1.0 ? 1.0 : permanence);
            if (permanence < 1.0E-5) {
                synapsesToDestroy.add(synapse);
                continue;
            }
            synapse.setPermanence(conn, permanence);
        }
        for (Synapse s : synapsesToDestroy) {
            conn.destroySynapse(s);
        }
        if (conn.numSynapses(segment) == 0L) {
            conn.destroySegment(segment);
        }
    }

    public static class ColumnData
    implements Serializable {
        private static final long serialVersionUID = 1L;
        Tuple t;

        public ColumnData() {
        }

        public ColumnData(Tuple t) {
            this.t = t;
        }

        public Column column() {
            return (Column)this.t.get(0);
        }

        public List<Column> activeColumns() {
            return (List)this.t.get(1);
        }

        public List<DistalDendrite> activeSegments() {
            return ((List)this.t.get(2)).get(0).equals(GroupBy2.Slot.empty()) ? Collections.emptyList() : (List)this.t.get(2);
        }

        public List<DistalDendrite> matchingSegments() {
            return ((List)this.t.get(3)).get(0).equals(GroupBy2.Slot.empty()) ? Collections.emptyList() : (List)this.t.get(3);
        }

        public ColumnData set(Tuple t) {
            this.t = t;
            return this;
        }

        public boolean isNotNone(int memberIndex) {
            return !((List)this.t.get(memberIndex)).get(0).equals(GroupBy2.Slot.NONE);
        }
    }
}

