/*
 * Decompiled with CFR 0.152.
 */
package org.numenta.nupic.network;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.joda.time.DateTime;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.encoders.MultiEncoder;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.network.CheckPointOp;
import org.numenta.nupic.network.Inference;
import org.numenta.nupic.network.Layer;
import org.numenta.nupic.network.Region;
import org.numenta.nupic.network.sensor.HTMSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Subscriber;

/**
 * Top-level container for a chain of {@link Region}s.
 *
 * <p>Regions are registered with {@link #add(Region)} and linked with
 * {@link #connect(String, String)}. Throughout this class the <em>tail</em> is
 * the most-upstream region (the one input is handed to by
 * {@link #compute(Object)} and the one {@link #start()} is invoked on) and the
 * <em>head</em> is the most-downstream region (the one {@link #observe()}
 * subscribes to). When exactly one region has been added it serves as both.
 *
 * <p>Most lifecycle and query methods simply delegate to the head or tail
 * region and therefore fail with a {@link NullPointerException} if that region
 * has not been established yet (no regions added, or several regions added but
 * never connected).
 */
public class Network implements Persistable {
    private static final long serialVersionUID = 1L;
    private static final Logger LOGGER = LoggerFactory.getLogger(Network.class);

    /** Name supplied at construction; never null or empty for publicly constructed instances. */
    private String name;
    /** Parameters supplied at construction; also handed to the sensor in {@link #setSensor(HTMSensor)}. */
    private Parameters parameters;
    private HTMSensor<?> sensor;
    private MultiEncoder encoder;
    /** Most-downstream region; lazily defaults to the sole region of a single-region network. */
    private Region head;
    /** Most-upstream region; lazily defaults to the sole region of a single-region network. */
    private Region tail;
    private Region sensorRegion;
    private volatile Publisher publisher;
    private volatile boolean isLearn = true;
    /** Result of the most recent {@code start()} / {@code restart()} call on the tail region. */
    private volatile boolean isThreadRunning;
    private List<Region> regions = new ArrayList<>();
    /** Not serialized; must be re-installed via {@link #setCheckPointFunction(Function)}. */
    private transient Function<Persistable, ?> checkPointFunction;
    /**
     * When {@code true}, {@link #preSerialize()} halts a running network. It is
     * switched off for the duration of {@link #internalCheckPointOp()}.
     */
    boolean shouldDoHalt = true;

    /** Package-private no-arg constructor; leaves name and parameters unset. */
    Network() {
    }

    /**
     * Creates a named network.
     *
     * @param name       non-null, non-empty network name
     * @param parameters non-null parameters for this network
     * @throws IllegalStateException    if {@code name} is null or empty
     * @throws IllegalArgumentException if {@code parameters} is null
     */
    public Network(String name, Parameters parameters) {
        // Validate both arguments before touching any field.
        if (name == null || name.isEmpty()) {
            throw new IllegalStateException("All Networks must have a name. Increases digestion, and overall happiness!");
        }
        if (parameters == null) {
            throw new IllegalArgumentException("Network Parameters were null.");
        }
        this.name = name;
        this.parameters = parameters;
    }

    /**
     * Static factory equivalent to {@link #Network(String, Parameters)}.
     *
     * @param name       non-null, non-empty network name
     * @param parameters non-null parameters
     * @return the new network
     */
    public static Network create(String name, Parameters parameters) {
        return new Network(name, parameters);
    }

    /**
     * Creates a {@link Region} that is not yet attached to any network.
     *
     * @param name region name; must not contain {@code ':'}
     * @return the new region
     * @throws IllegalArgumentException if {@code name} contains {@code ':'}
     */
    public static Region createRegion(String name) {
        checkName(name);
        return new Region(name, null);
    }

    /**
     * Creates a {@link Layer} that is not yet attached to any network.
     *
     * @param name layer name; must not contain {@code ':'}
     * @param p    parameters passed to the layer constructor
     * @return the new layer
     * @throws IllegalArgumentException if {@code name} contains {@code ':'}
     */
    public static Layer<?> createLayer(String name, Parameters p) {
        checkName(name);
        return new Layer<>(name, null, p);
    }

    /**
     * Prepares this network for serialization.
     *
     * <p>If halting is enabled and the network reports a running thread, the
     * network is halted; otherwise the tail region is closed. In both cases
     * every region's own {@code preSerialize()} is invoked afterwards.
     *
     * @return this network
     */
    public Network preSerialize() {
        if (this.shouldDoHalt && this.isThreadRunning) {
            halt();
        } else {
            useSoleRegionAsTail();
            this.tail.close();
        }
        this.regions.forEach(r -> r.preSerialize());
        return this;
    }

    /**
     * Restores runtime links after deserialization.
     *
     * <p>Every region is first re-pointed at this network, then every region's
     * own {@code postDeSerialize()} runs, and finally - for a multi-region
     * network - each region is re-connected to its upstream neighbour, walking
     * from the head towards the tail.
     *
     * @return this network
     */
    public Network postDeSerialize() {
        // Two separate passes: all back-references are set before any region's hook runs.
        this.regions.forEach(r -> r.setNetwork(this));
        this.regions.forEach(r -> r.postDeSerialize());
        if (this.isMultiRegion()) {
            Region curr = this.head;
            Region nxt = curr.getUpstreamRegion();
            // Re-issue connect() for each (downstream, upstream) pair. The loop
            // condition advances both cursors and stops once no further
            // upstream region exists.
            do {
                curr.connect(nxt);
            } while ((curr = nxt) != null && (nxt = nxt.getUpstreamRegion()) != null);
        }
        return this;
    }

    /**
     * Runs the installed check-point function against this network with
     * halting temporarily disabled, so that {@link #preSerialize()} closes the
     * tail instead of halting the network.
     *
     * @return the value produced by the check-point function, cast to {@code byte[]}
     * @throws NullPointerException if no check-point function has been set
     */
    byte[] internalCheckPointOp() {
        this.shouldDoHalt = false;
        try {
            return (byte[]) this.checkPointFunction.apply(this);
        } finally {
            // Always re-enable halting, even when the check-point function throws.
            this.shouldDoHalt = true;
        }
    }

    /**
     * Installs the function used by {@link #internalCheckPointOp()}.
     *
     * @param f   function that is applied to this network; its result is later
     *            cast to {@code byte[]}
     * @param <T> input type accepted by the function
     * @param <R> result type of the function
     */
    public <T extends Persistable, R> void setCheckPointFunction(Function<T, R> f) {
        // The function is only ever applied to this Network instance, so the
        // widening of its input type is checked by callers, not the compiler.
        @SuppressWarnings("unchecked")
        Function<Persistable, ?> widened = (Function<Persistable, ?>) f;
        this.checkPointFunction = widened;
    }

    /**
     * Returns the tail region's check-point operator, logging the request at
     * debug level.
     *
     * @return the {@link CheckPointOp} of the tail region
     */
    CheckPointOp<byte[]> getCheckPointOperator() {
        LOGGER.debug("Network [{}] called checkPoint() at: {}", this.getName(), new DateTime());
        useSoleRegionAsTail();
        return this.tail.getCheckPointOperator();
    }

    /**
     * Equivalent to {@code restart(true)}.
     *
     * @throws IllegalStateException if no regions have been added
     */
    public void restart() {
        this.restart(true);
    }

    /**
     * Restarts the network by delegating to {@code Region.restart(boolean)} on
     * the most-upstream region reachable from the first added region, and
     * records the returned flag as the thread-running state.
     *
     * @param startAtIndex forwarded unchanged to the region's {@code restart}
     * @throws IllegalStateException if no regions have been added
     */
    public void restart(boolean startAtIndex) {
        this.isThreadRunning = findStartRegion().restart(startAtIndex);
    }

    /**
     * Installs the publisher and points it back at this network.
     *
     * @param p the publisher; must not be null
     */
    void setPublisher(Publisher p) {
        this.publisher = p;
        p.setNetwork(this);
    }

    /**
     * Returns the publisher previously installed on this network.
     *
     * @return the publisher, never null
     * @throws NullPointerException if no publisher has been installed
     */
    public Publisher getPublisher() {
        // Read the volatile field once so the check and the return agree.
        Publisher current = this.publisher;
        if (current == null) {
            throw new NullPointerException("A Supplier must be built first. please see Network.getPublisherSupplier()");
        }
        return current;
    }

    /**
     * @return {@code true} if more than one region has been added
     */
    public boolean isMultiRegion() {
        return this.regions.size() > 1;
    }

    /**
     * @return the name of this network
     */
    public String getName() {
        return this.name;
    }

    /**
     * Starts the network by delegating to {@code Region.start()} on the
     * most-upstream region reachable from the first added region, and records
     * the returned flag as the thread-running state.
     *
     * @throws IllegalStateException if no regions have been added
     */
    public void start() {
        this.isThreadRunning = findStartRegion().start();
    }

    /**
     * @return the flag returned by the most recent {@link #start()} or
     *         {@link #restart(boolean)}; {@code false} if neither was called
     */
    public boolean isThreadedOperation() {
        return this.isThreadRunning;
    }

    /**
     * Signals completion on the publisher, if one is installed, then halts the
     * tail region.
     */
    public void halt() {
        Publisher current = this.publisher;
        if (current != null) {
            current.onComplete();
        }
        useSoleRegionAsTail();
        this.tail.halt();
    }

    /**
     * @return whether the tail region reports itself as halted
     */
    public boolean isHalted() {
        useSoleRegionAsTail();
        return this.tail.isHalted();
    }

    /**
     * @return the record number reported by the tail layer of the tail region
     */
    public int getRecordNum() {
        useSoleRegionAsTail();
        return this.tail.getTail().getRecordNum();
    }

    /**
     * Sets the learning flag on this network and on every added region.
     *
     * @param isLearn the new learning flag
     */
    public void setLearn(boolean isLearn) {
        this.isLearn = isLearn;
        this.regions.forEach(r -> r.setLearn(isLearn));
    }

    /**
     * @return the learning flag last set on this network ({@code true} by default)
     */
    public boolean isLearn() {
        return this.isLearn;
    }

    /** Calls {@code reset()} on every added region, in insertion order. */
    public void reset() {
        this.regions.forEach(r -> r.reset());
    }

    /** Calls {@code resetRecordNum()} on every added region, in insertion order. */
    public void resetRecordNum() {
        this.regions.forEach(r -> r.resetRecordNum());
    }

    /**
     * @return the {@link Observable} exposed by the head region
     */
    public Observable<Inference> observe() {
        useSoleRegionAsHead();
        return this.head.observe();
    }

    /**
     * @return the head (most-downstream) region, or {@code null} if it has not
     *         been established
     */
    public Region getHead() {
        useSoleRegionAsHead();
        return this.head;
    }

    /**
     * @return the tail (most-upstream) region, or {@code null} if it has not
     *         been established
     */
    public Region getTail() {
        useSoleRegionAsTail();
        return this.tail;
    }

    /**
     * @param l the layer to test
     * @return {@code true} if {@code l} is the same instance as the tail layer
     *         of the tail region
     */
    boolean isTail(Layer<?> l) {
        useSoleRegionAsTail();
        return this.tail.getTail() == l;
    }

    /**
     * @return an iterator over a snapshot of the added regions
     */
    public Iterator<Region> iterator() {
        return this.getRegions().iterator();
    }

    /**
     * Hands {@code input} to the tail region. If no head has been established
     * yet, a no-op subscriber is attached to the network's output first.
     *
     * @param input the value passed to the tail region's {@code compute}
     * @param <T>   input type
     */
    public <T> void compute(T input) {
        prepareForCompute();
        this.tail.compute(input);
    }

    /**
     * Hands {@code input} to the tail region and returns the inference held by
     * the head layer of the head region afterwards.
     *
     * @param input the value passed to the tail region's {@code compute}
     * @param <T>   input type
     * @return the inference read from the head region's head layer
     * @throws IllegalStateException if the network has been started
     */
    public <T> Inference computeImmediate(T input) {
        if (this.isThreadRunning) {
            throw new IllegalStateException("Cannot call computeImmediate() when Network has been started.");
        }
        prepareForCompute();
        this.tail.compute(input);
        return this.head.getHead().getInference();
    }

    /**
     * Subscribes a subscriber to {@link #observe()} that ignores items and
     * completion and logs errors.
     */
    void addDummySubscriber() {
        this.observe().subscribe(new Subscriber<Inference>() {
            @Override
            public void onCompleted() {
                // Intentionally empty.
            }

            @Override
            public void onError(Throwable e) {
                LOGGER.error("Error on inference stream of Network [" + Network.this.name + "]", e);
            }

            @Override
            public void onNext(Inference i) {
                // Intentionally empty.
            }
        });
    }

    /**
     * Connects two previously added regions so that {@code regionSource}
     * becomes the upstream neighbour of {@code regionSink}, then recomputes
     * the network's tail and head as the two ends of the resulting chain.
     *
     * @param regionSink   name of the downstream (receiving) region
     * @param regionSource name of the upstream (feeding) region
     * @return this network
     * @throws IllegalArgumentException if either name has not been added
     */
    public Network connect(String regionSink, String regionSource) {
        Region source = this.lookup(regionSource);
        if (source == null) {
            throw new IllegalArgumentException("Region with name: " + regionSource + " not added to Network.");
        }
        Region sink = this.lookup(regionSink);
        if (sink == null) {
            throw new IllegalArgumentException("Region with name: " + regionSink + " not added to Network.");
        }
        sink.connect(source);

        // Always take the ends of the chain through the new link. Keeping a
        // previously recorded tail/head would leave them stale whenever the new
        // source or sink itself becomes the end of the chain.
        this.tail = mostUpstreamFrom(source);
        this.head = mostDownstreamFrom(sink);
        return this;
    }

    /**
     * Adds a region and points it back at this network.
     *
     * @param region the region to add
     * @return this network
     */
    public Network add(Region region) {
        this.regions.add(region);
        region.setNetwork(this);
        return this;
    }

    /**
     * Calls {@code close()} on every added region, in insertion order.
     *
     * @return this network
     */
    public Network close() {
        this.regions.forEach(region -> region.close());
        return this;
    }

    /**
     * @return a new list containing the added regions; changes to it do not
     *         affect this network
     */
    public List<Region> getRegions() {
        return new ArrayList<>(this.regions);
    }

    /**
     * @param r the region to record as the sensor region
     */
    public void setSensorRegion(Region r) {
        this.sensorRegion = r;
    }

    /**
     * @return the region recorded via {@link #setSensorRegion(Region)}, or {@code null}
     */
    public Region getSensorRegion() {
        return this.sensorRegion;
    }

    /**
     * Finds an added region by exact name.
     *
     * @param regionName the name to search for
     * @return the first added region whose name equals {@code regionName}, or
     *         {@code null} if there is none
     */
    public Region lookup(String regionName) {
        for (Region candidate : this.regions) {
            if (candidate.getName().equals(regionName)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * @return the parameters supplied at construction
     */
    public Parameters getParameters() {
        return this.parameters;
    }

    /**
     * Records the sensor and initialises its encoder with this network's
     * parameters.
     *
     * @param sensor the sensor; must not be null
     */
    public void setSensor(HTMSensor<?> sensor) {
        this.sensor = sensor;
        sensor.initEncoder(this.parameters);
    }

    /**
     * @return the sensor recorded via {@link #setSensor(HTMSensor)}, or {@code null}
     */
    public HTMSensor<?> getSensor() {
        return this.sensor;
    }

    /**
     * @param e the encoder to record
     */
    public void setEncoder(MultiEncoder e) {
        this.encoder = e;
    }

    /**
     * @return the encoder recorded via {@link #setEncoder(MultiEncoder)}, or {@code null}
     */
    public MultiEncoder getEncoder() {
        return this.encoder;
    }

    /**
     * Rejects names containing the reserved {@code ':'} character.
     *
     * @throws IllegalArgumentException if {@code name} contains {@code ':'}
     */
    private static void checkName(String name) {
        if (name.contains(":")) {
            throw new IllegalArgumentException("\":\" is a reserved character.");
        }
    }

    /** In a single-region network, that region is the tail. */
    private void useSoleRegionAsTail() {
        if (this.regions.size() == 1) {
            this.tail = this.regions.get(0);
        }
    }

    /** In a single-region network, that region is the head. */
    private void useSoleRegionAsHead() {
        if (this.regions.size() == 1) {
            this.head = this.regions.get(0);
        }
    }

    /**
     * Shared set-up for the compute methods: defaults an unset tail to the
     * sole region and attaches a no-op subscriber when no head exists yet.
     */
    private void prepareForCompute() {
        if (this.tail == null && this.regions.size() == 1) {
            this.tail = this.regions.get(0);
        }
        if (this.head == null) {
            this.addDummySubscriber();
        }
    }

    /**
     * Returns the region {@code start()} / {@code restart()} operate on: the
     * most-upstream region reachable from the first added region.
     *
     * @throws IllegalStateException if no regions have been added
     */
    private Region findStartRegion() {
        if (this.regions.isEmpty()) {
            throw new IllegalStateException("Nothing to start - 0 regions");
        }
        return mostUpstreamFrom(this.regions.get(0));
    }

    /** Follows upstream links from {@code start} and returns the last region reached. */
    private static Region mostUpstreamFrom(Region start) {
        Region bottom = start;
        for (Region up = bottom.getUpstreamRegion(); up != null; up = up.getUpstreamRegion()) {
            bottom = up;
        }
        return bottom;
    }

    /** Follows downstream links from {@code start} and returns the last region reached. */
    private static Region mostDownstreamFrom(Region start) {
        Region top = start;
        for (Region down = top.getDownstreamRegion(); down != null; down = down.getDownstreamRegion()) {
            top = down;
        }
        return top;
    }

    /**
     * Hash over the learning flag, name, parameters, regions and sensor - the
     * same fields compared by {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        // Objects.hash applies the same 31-based fold (Boolean hashes to
        // 1231/1237, null to 0) as the hand-written version it replaces.
        return Objects.hash(this.isLearn, this.name, this.parameters, this.regions, this.sensor);
    }

    /**
     * Two networks are equal when they are of the same class and agree on the
     * learning flag, name, parameters, regions and sensor.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        Network other = (Network) obj;
        return this.isLearn == other.isLearn
            && Objects.equals(this.name, other.name)
            && Objects.equals(this.parameters, other.parameters)
            && Objects.equals(this.regions, other.regions)
            && Objects.equals(this.sensor, other.sensor);
    }
}

