/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hugegraph.computer.algorithm.sampling;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.hugegraph.computer.algorithm.sampling.RandomWalkMessage;
import org.apache.hugegraph.computer.core.common.exception.ComputerException;
import org.apache.hugegraph.computer.core.config.Config;
import org.apache.hugegraph.computer.core.graph.edge.Edge;
import org.apache.hugegraph.computer.core.graph.edge.Edges;
import org.apache.hugegraph.computer.core.graph.id.Id;
import org.apache.hugegraph.computer.core.graph.value.DoubleValue;
import org.apache.hugegraph.computer.core.graph.value.IdList;
import org.apache.hugegraph.computer.core.graph.value.IdListList;
import org.apache.hugegraph.computer.core.graph.value.Value;
import org.apache.hugegraph.computer.core.graph.vertex.Vertex;
import org.apache.hugegraph.computer.core.worker.Computation;
import org.apache.hugegraph.computer.core.worker.ComputationContext;
import org.apache.hugegraph.util.Log;
import org.slf4j.Logger;

public class RandomWalk
implements Computation<RandomWalkMessage> {
    private static final Logger LOG = Log.logger(RandomWalk.class);
    public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node";
    public static final String OPTION_WALK_LENGTH = "random_walk.walk_length";
    public static final String OPTION_WEIGHT_PROPERTY = "random_walk.weight_property";
    public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight";
    public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold";
    public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold";
    public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor";
    public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor";
    private Random random;
    private Integer walkPerNode;
    private Integer walkLength;
    private String weightProperty;
    private Double defaultWeight;
    private Double minWeightThreshold;
    private Double maxWeightThreshold;
    private Double returnFactor;
    private Double inOutFactor;

    public String category() {
        return "sampling";
    }

    public String name() {
        return "random_walk";
    }

    public void init(Config config) {
        this.random = new Random();
        this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3);
        if (this.walkPerNode <= 0) {
            throw new ComputerException("The param %s must be greater than 0, actual got '%s'", new Object[]{OPTION_WALK_PER_NODE, this.walkPerNode});
        }
        this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3);
        if (this.walkLength <= 0) {
            throw new ComputerException("The param %s must be greater than 0, actual got '%s'", new Object[]{OPTION_WALK_LENGTH, this.walkLength});
        }
        this.weightProperty = config.getString(OPTION_WEIGHT_PROPERTY, "");
        this.defaultWeight = config.getDouble(OPTION_DEFAULT_WEIGHT, 1.0);
        if (this.defaultWeight <= 0.0) {
            throw new ComputerException("The param %s must be greater than 0, actual got '%s'", new Object[]{OPTION_DEFAULT_WEIGHT, this.defaultWeight});
        }
        this.minWeightThreshold = config.getDouble(OPTION_MIN_WEIGHT_THRESHOLD, 0.0);
        if (this.minWeightThreshold < 0.0) {
            throw new ComputerException("The param %s must be greater than or equal 0, actual got '%s'", new Object[]{OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold});
        }
        this.maxWeightThreshold = config.getDouble(OPTION_MAX_WEIGHT_THRESHOLD, Double.MAX_VALUE);
        if (this.maxWeightThreshold < 0.0) {
            throw new ComputerException("The param %s must be greater than or equal 0, actual got '%s'", new Object[]{OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold});
        }
        if (this.minWeightThreshold > this.maxWeightThreshold) {
            throw new ComputerException("%s must be greater than or equal %s, ", new Object[]{OPTION_MAX_WEIGHT_THRESHOLD, OPTION_MIN_WEIGHT_THRESHOLD});
        }
        this.returnFactor = config.getDouble(OPTION_RETURN_FACTOR, 1.0);
        if (this.returnFactor <= 0.0) {
            throw new ComputerException("The param %s must be greater than 0, actual got '%s'", new Object[]{OPTION_RETURN_FACTOR, this.returnFactor});
        }
        this.inOutFactor = config.getDouble(OPTION_INOUT_FACTOR, 1.0);
        if (this.inOutFactor <= 0.0) {
            throw new ComputerException("The param %s must be greater than 0, actual got '%s'", new Object[]{OPTION_INOUT_FACTOR, this.inOutFactor});
        }
    }

    public void compute0(ComputationContext context, Vertex vertex) {
        vertex.value((Value)new IdListList());
        RandomWalkMessage message = new RandomWalkMessage();
        message.addToPath(vertex);
        if (vertex.numEdges() <= 0) {
            this.savePath(vertex, message.path());
            vertex.inactivate();
            return;
        }
        vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));
        for (int i = 0; i < this.walkPerNode; ++i) {
            Edge selectedEdge = this.randomSelectEdge(null, null, vertex.edges());
            context.sendMessage(selectedEdge.targetId(), (Value)message);
        }
        vertex.inactivate();
    }

    public void compute(ComputationContext context, Vertex vertex, Iterator<RandomWalkMessage> messages) {
        while (messages.hasNext()) {
            RandomWalkMessage message = messages.next();
            Id preVertexId = (Id)message.path().getLast();
            if (message.isFinish()) {
                this.savePath(vertex, message.path());
                vertex.inactivate();
                continue;
            }
            message.addToPath(vertex);
            if (vertex.numEdges() <= 0) {
                message.finish();
                context.sendMessage(this.getSourceId(message.path()), (Value)message);
                vertex.inactivate();
                continue;
            }
            if (message.path().size() >= this.walkLength + 1) {
                message.finish();
                Id sourceId = this.getSourceId(message.path());
                if (vertex.id().equals(sourceId)) {
                    this.savePath(vertex, message.path());
                } else {
                    context.sendMessage(sourceId, (Value)message);
                }
                vertex.inactivate();
                continue;
            }
            vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));
            Edge selectedEdge = this.randomSelectEdge(preVertexId, message.preVertexAdjacence(), vertex.edges());
            context.sendMessage(selectedEdge.targetId(), (Value)message);
        }
        vertex.inactivate();
    }

    private Edge randomSelectEdge(Id preVertexId, IdList preVertexAdjacenceIdList, Edges edges) {
        ArrayList<Double> weightList = new ArrayList<Double>();
        for (Edge edge : edges) {
            double weight = this.getEdgeWeight(edge);
            double finalWeight = this.calculateEdgeWeight(preVertexId, preVertexAdjacenceIdList, edge.targetId(), weight);
            weightList.add(finalWeight);
        }
        int selectedIndex = this.randomSelectIndex(weightList);
        Edge selectedEdge = this.selectEdge(edges.iterator(), selectedIndex);
        return selectedEdge;
    }

    private double getEdgeWeight(Edge edge) {
        double weight = this.defaultWeight;
        Value property = edge.property(this.weightProperty);
        if (property != null) {
            if (!property.isNumber()) {
                throw new ComputerException("The value of %s must be a numeric value, actual got '%s'", new Object[]{this.weightProperty, property.string()});
            }
            weight = ((DoubleValue)property).doubleValue();
        }
        if (weight < this.minWeightThreshold) {
            weight = this.minWeightThreshold;
        }
        if (weight > this.maxWeightThreshold) {
            weight = this.maxWeightThreshold;
        }
        return weight;
    }

    private double calculateEdgeWeight(Id preVertexId, IdList preVertexAdjacenceIdList, Id nextVertexId, double weight) {
        double finalWeight = 0.0;
        finalWeight = preVertexId != null && preVertexId.equals(nextVertexId) ? 1.0 / this.returnFactor * weight : (preVertexAdjacenceIdList != null && preVertexAdjacenceIdList.contains((Value.Tvalue)nextVertexId) ? 1.0 * weight : 1.0 / this.inOutFactor * weight);
        return finalWeight;
    }

    private int randomSelectIndex(List<Double> weightList) {
        int selectedIndex = 0;
        double totalWeight = weightList.stream().mapToDouble(Double::doubleValue).sum();
        double randomNum = this.random.nextDouble() * totalWeight;
        double cumulativeWeight = 0.0;
        for (int i = 0; i < weightList.size(); ++i) {
            if (!(randomNum < (cumulativeWeight += weightList.get(i).doubleValue()))) continue;
            selectedIndex = i;
            break;
        }
        return selectedIndex;
    }

    private Edge selectEdge(Iterator<Edge> iterator, int selectedIndex) {
        Edge selectedEdge = null;
        int index = 0;
        while (iterator.hasNext()) {
            selectedEdge = iterator.next();
            if (index == selectedIndex) break;
            ++index;
        }
        return selectedEdge;
    }

    private Id getSourceId(IdList path) {
        return (Id)path.getFirst();
    }

    private void savePath(Vertex sourceVertex, IdList path) {
        IdListList curValue = (IdListList)sourceVertex.value();
        curValue.add((Value.Tvalue)path.copy());
    }
}

