/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.kgraph.pregel;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.TimestampedKeyValueStore;
import org.apache.kafka.streams.state.ValueAndTimestamp;

import io.kgraph.EdgeWithValue;
import io.kgraph.VertexWithValue;
import io.kgraph.pregel.PregelComputation.AggregatorWrapper;
import io.kgraph.pregel.aggregators.Aggregator;

/**
 * The user-defined compute function for a Pregel computation.
 *
 * @param <K> The type of the vertex key (the vertex identifier).
 * @param <VV> The type of the vertex value (the state of the vertex).
 * @param <EV> The type of the values that are associated with the edges.
 * @param <Message> The type of the message sent between vertices along the edges.
 */
@FunctionalInterface
public interface ComputeFunction<K, VV, EV, Message> {

    /**
     * Initializes the ComputeFunction. This is the place to register aggregators
     * through the supplied callback. The default implementation does nothing.
     *
     * @param configs configuration parameters
     * @param cb a callback for registering aggregators
     */
    default void init(Map<String, ?> configs, InitCallback cb) {
    }

    /**
     * A function for performing sequential computations between supersteps.
     * The default implementation does nothing.
     *
     * @param superstep the superstep
     * @param cb a callback for reading and writing aggregator values or halting the computation
     */
    default void masterCompute(int superstep, MasterCallback cb) {
    }

    /**
     * Prepare for computation.  This method is executed exactly once prior to compute() being called
     * for any of the vertices in the partition.  The default implementation does nothing.
     *
     * @param superstep the superstep
     * @param aggregators the aggregators, readable and writable by name
     */
    default void preSuperstep(int superstep, Aggregators aggregators) {
    }

    /**
     * The function for computing a new vertex value or sending messages to the next superstep.
     * This is the single abstract method of the interface.
     *
     * @param superstep the count of the current superstep
     * @param vertex the current vertex with its value
     * @param messages the messages sent to this vertex from the previous superstep
     * @param edges the adjacent edges with their values
     * @param cb a callback for setting a new vertex value or sending messages to the next superstep
     */
    void compute(int superstep,
                 VertexWithValue<K, VV> vertex,
                 Iterable<Message> messages,
                 Iterable<EdgeWithValue<K, EV>> edges,
                 Callback<K, VV, EV, Message> cb);

    /**
     * Finish computation.  This method is executed exactly once after computation
     * for all vertices in the partition is complete.  The default implementation does nothing.
     *
     * @param superstep the superstep
     * @param aggregators the aggregators, readable and writable by name
     */
    default void postSuperstep(int superstep, Aggregators aggregators) {
    }


    /**
     * Callback passed to {@link ComputeFunction#init(Map, InitCallback)} so that a
     * compute function can register aggregators by name.
     */
    final class InitCallback {

        /** Registry that receives one wrapper per registered aggregator name. */
        protected final Map<String, AggregatorWrapper<?>> aggregators;

        /**
         * Creates a callback that records registrations in the given map.
         *
         * @param aggregators the map that registrations are written into
         */
        public InitCallback(Map<String, AggregatorWrapper<?>> aggregators) {
            this.aggregators = aggregators;
        }

        /**
         * Registers an aggregator under the given name, replacing any wrapper
         * previously stored under that name.
         *
         * @param name the aggregator name
         * @param aggregatorClass the aggregator implementation class
         * @param persistent flag handed unchanged to the {@code AggregatorWrapper}
         * @param <T> the type of value being aggregated
         */
        public <T> void registerAggregator(String name,
                                           Class<? extends Aggregator<T>> aggregatorClass,
                                           boolean persistent) {
            final AggregatorWrapper<?> wrapper = new AggregatorWrapper<>(aggregatorClass, persistent);
            this.aggregators.put(name, wrapper);
        }

        /**
         * Registers an aggregator under the given name with the persistent flag
         * set to {@code false}.
         *
         * @param name the aggregator name
         * @param aggregatorClass the aggregator implementation class
         * @param <T> the type of value being aggregated
         */
        public <T> void registerAggregator(String name,
                                           Class<? extends Aggregator<T>> aggregatorClass) {
            final boolean persistent = false;
            this.registerAggregator(name, aggregatorClass, persistent);
        }
    }

    /**
     * Read-only access to aggregated values, looked up by aggregator name.
     */
    interface ReadAggregators {
        /**
         * Returns the aggregated value associated with the given name.
         *
         * @param name the aggregator name
         * @param <T> the type the caller expects the value to have
         * @return the aggregated value
         */
        <T> T getAggregatedValue(String name);
    }

    /**
     * Extends {@link ReadAggregators} with the ability to contribute values to
     * an aggregator identified by name.
     */
    interface ReadWriteAggregators extends ReadAggregators {
        /**
         * Contributes a value to the aggregator with the given name.
         *
         * @param name the aggregator name
         * @param value the value to contribute
         * @param <T> the type of the contributed value
         */
        <T> void aggregate(String name, T value);
    }

    /**
     * Callback passed to {@link ComputeFunction#masterCompute(int, MasterCallback)}.
     * It exposes the aggregators by name for reading and overwriting, and lets the
     * master request that the computation be halted.
     */
    final class MasterCallback implements ReadAggregators {

        /** Aggregators addressable by name. */
        protected final Map<String, Aggregator<?>> previousAggregators;

        /** Set to {@code true} once {@link #haltComputation()} has been called. */
        protected boolean haltComputation = false;

        /**
         * Creates a callback backed by the given aggregators.
         *
         * @param previousAggregators the aggregators, keyed by name
         */
        public MasterCallback(Map<String, Aggregator<?>> previousAggregators) {
            this.previousAggregators = previousAggregators;
        }

        /**
         * Returns the current aggregate of the named aggregator.
         *
         * @param name the aggregator name
         * @param <T> the type the caller expects; the cast is unchecked
         * @return the aggregate held by the aggregator
         * @throws NullPointerException if no aggregator exists under {@code name}
         */
        @Override
        @SuppressWarnings("unchecked")
        public final <T> T getAggregatedValue(String name) {
            return (T) lookup(name).getAggregate();
        }

        /**
         * Overwrites the aggregate of the named aggregator.
         *
         * @param name the aggregator name
         * @param value the new aggregate
         * @param <T> the type of the aggregate; the cast is unchecked
         * @throws NullPointerException if no aggregator exists under {@code name}
         */
        @SuppressWarnings("unchecked")
        public final <T> void setAggregatedValue(String name, T value) {
            ((Aggregator<T>) lookup(name)).setAggregate(value);
        }

        /** Records that the computation should be halted. */
        public void haltComputation() {
            haltComputation = true;
        }

        /**
         * Finds the aggregator stored under the given name.
         * Fails with the same exception type as a bare dereference would,
         * but names the missing aggregator so the failure is diagnosable.
         */
        private Aggregator<?> lookup(String name) {
            return Objects.requireNonNull(previousAggregators.get(name),
                () -> "Unknown aggregator: " + name);
        }
    }

    /**
     * Aggregator access passed to
     * {@link ComputeFunction#preSuperstep(int, Aggregators)} and
     * {@link ComputeFunction#postSuperstep(int, Aggregators)}.
     * Reads are served from a map of previous aggregate values; writes go to a
     * separate map of aggregators.
     */
    final class Aggregators implements ReadWriteAggregators {

        /** Aggregate values served by {@link #getAggregatedValue(String)}, keyed by name. */
        protected final Map<String, ?> previousAggregates;

        /** Aggregators that receive values from {@link #aggregate(String, Object)}, keyed by name. */
        protected final Map<String, Aggregator<?>> aggregators;

        /**
         * Creates the aggregator view.
         *
         * @param previousAggregates the aggregate values to read from, keyed by name
         * @param aggregators the aggregators to write to, keyed by name
         */
        public Aggregators(Map<String, ?> previousAggregates, Map<String, Aggregator<?>> aggregators) {
            this.previousAggregates = previousAggregates;
            this.aggregators = aggregators;
        }

        /**
         * Returns the previous aggregate stored under the given name.
         *
         * @param name the aggregator name
         * @param <T> the type the caller expects; the cast is unchecked
         * @return the stored value, or {@code null} if the map has no value for {@code name}
         */
        @Override
        @SuppressWarnings("unchecked")
        public final <T> T getAggregatedValue(String name) {
            return (T) previousAggregates.get(name);
        }

        /**
         * Contributes a value to the named aggregator.
         *
         * @param name the aggregator name
         * @param value the value to contribute
         * @param <T> the type of the contributed value
         * @throws NullPointerException if no aggregator exists under {@code name}
         */
        @Override
        public final <T> void aggregate(String name, T value) {
            aggregator(name).aggregate(value);
        }

        /**
         * Finds the aggregator stored under the given name.
         * Fails with the same exception type as a bare dereference would,
         * but names the missing aggregator so the failure is diagnosable.
         */
        @SuppressWarnings("unchecked")
        private <T> Aggregator<T> aggregator(String name) {
            Aggregator<?> found = Objects.requireNonNull(aggregators.get(name),
                () -> "Unknown aggregator: " + name);
            return (Aggregator<T>) found;
        }
    }

    /**
     * Callback passed to {@link ComputeFunction#compute} for a single vertex.
     * It collects the vertex's new value, its outgoing messages and its halt vote,
     * edits the edges stored under the vertex key, and gives access to aggregators.
     *
     * @param <K> The type of the vertex key (the vertex identifier).
     * @param <VV> The type of the vertex value (the state of the vertex).
     * @param <EV> The type of the values that are associated with the edges.
     * @param <Message> The type of the message sent between vertices along the edges.
     */
    final class Callback<K, VV, EV, Message> implements ReadWriteAggregators {

        /** Processor context; supplies the timestamp used when writing edges. */
        protected final ProcessorContext context;

        /** Key under which this callback reads and writes edges and aggregator contributions. */
        protected final K key;

        /** Store mapping a key to its edges, each edge being a target key with an edge value. */
        protected final TimestampedKeyValueStore<K, Map<K, EV>> edgesStore;

        /** Value set through {@link #setNewVertexValue(Object)}; {@code null} until set. */
        protected VV newVertexValue = null;

        /** Messages queued through {@link #sendMessageTo(Object, Object)}, grouped by target. */
        protected final Map<K, List<Message>> outgoingMessages = new HashMap<>();

        /** Set to {@code true} once {@link #voteToHalt()} has been called. */
        protected boolean voteToHalt = false;

        /** Aggregate values served by {@link #getAggregatedValue(String)}, keyed by name. */
        protected final Map<String, ?> previousAggregates;

        /** Per-name maps that collect one contributed value per key. */
        protected final Map<String, Map<K, ?>> aggregators;

        /**
         * Creates a callback for the given key.
         *
         * @param context the processor context
         * @param key the key this callback operates on
         * @param edgesStore the store holding the edges of each key
         * @param previousAggregates the aggregate values to read from, keyed by name
         * @param aggregators the per-name maps that collect contributed values
         */
        public Callback(ProcessorContext context,
                        K key,
                        TimestampedKeyValueStore<K, Map<K, EV>> edgesStore,
                        Map<String, ?> previousAggregates,
                        Map<String, Map<K, ?>> aggregators) {
            this.context = context;
            this.previousAggregates = previousAggregates;
            this.aggregators = aggregators;
            this.key = key;
            this.edgesStore = edgesStore;
        }

        /**
         * Queues a message for the given target.
         *
         * @param target the key of the receiving vertex
         * @param m the message
         */
        public final void sendMessageTo(K target, Message m) {
            outgoingMessages.computeIfAbsent(target, k -> new ArrayList<>()).add(m);
        }

        /**
         * Records the new value for the vertex.
         *
         * @param vertexValue the new vertex value
         */
        public final void setNewVertexValue(VV vertexValue) {
            newVertexValue = vertexValue;
        }

        /**
         * Adds an edge to {@code target}, or overwrites its value if one already exists,
         * and writes the edges back to the store.
         *
         * @param target the key of the edge's target
         * @param value the edge value
         */
        public final void addEdge(K target, EV value) {
            Map<K, EV> edges = loadEdges();
            if (edges == null) {
                // Nothing stored under this key yet: start a fresh edge map.
                edges = new HashMap<>();
            }
            edges.put(target, value);
            storeEdges(edges);
        }

        /**
         * Removes the edge to {@code target} and writes the edges back to the store.
         * Does nothing if no edges are stored under this callback's key.
         *
         * @param target the key of the edge's target
         */
        public final void removeEdge(K target) {
            Map<K, EV> edges = loadEdges();
            if (edges == null) {
                return;
            }
            edges.remove(target);
            storeEdges(edges);
        }

        /**
         * Replaces the value of the existing edge to {@code target} and writes the edges
         * back to the store. If there is no edge to {@code target}, the edge map is written
         * back unchanged. Does nothing if no edges are stored under this callback's key.
         *
         * @param target the key of the edge's target
         * @param value the new edge value
         */
        public final void setNewEdgeValue(K target, EV value) {
            Map<K, EV> edges = loadEdges();
            if (edges == null) {
                return;
            }
            edges.replace(target, value);
            storeEdges(edges);
        }

        /** Records this vertex's vote to halt. */
        public void voteToHalt() {
            voteToHalt = true;
        }

        /**
         * Returns the previous aggregate stored under the given name.
         *
         * @param name the aggregator name
         * @param <T> the type the caller expects; the cast is unchecked
         * @return the stored value, or {@code null} if the map has no value for {@code name}
         */
        @Override
        @SuppressWarnings("unchecked")
        public final <T> T getAggregatedValue(String name) {
            return (T) previousAggregates.get(name);
        }

        /**
         * Stores this key's contribution to the named aggregator, replacing any
         * earlier contribution made under the same key.
         *
         * @param name the aggregator name
         * @param value the value to contribute
         * @param <T> the type of the contributed value
         * @throws NullPointerException if no aggregator exists under {@code name}
         */
        @Override
        public final <T> void aggregate(String name, T value) {
            aggregator(name).put(key, value);
        }

        /**
         * Finds the contribution map stored under the given name.
         * Fails with the same exception type as a bare dereference would,
         * but names the missing aggregator so the failure is diagnosable.
         */
        @SuppressWarnings("unchecked")
        private <T> Map<K, T> aggregator(String name) {
            Map<K, ?> contributions = Objects.requireNonNull(aggregators.get(name),
                () -> "Unknown aggregator: " + name);
            return (Map<K, T>) contributions;
        }

        /** Reads the edge map stored under this callback's key, or {@code null} if there is none. */
        private Map<K, EV> loadEdges() {
            return ValueAndTimestamp.getValueOrNull(edgesStore.get(key));
        }

        /** Writes the edge map under this callback's key, stamped with the context's timestamp. */
        private void storeEdges(Map<K, EV> edges) {
            edgesStore.put(key, ValueAndTimestamp.make(edges, context.timestamp()));
        }
    }
}