package org.numenta.nupic.flink.streaming.api.operator;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.AverageAccumulator;
import org.apache.flink.api.common.accumulators.Histogram;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.numenta.nupic.flink.streaming.api.NetworkFactory;
import org.numenta.nupic.flink.streaming.api.NetworkInference;
import org.numenta.nupic.flink.streaming.api.ResetFunction;
import org.numenta.nupic.flink.streaming.api.codegen.GenerateEncoderInputFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.encoders.MultiEncoder;
import org.numenta.nupic.network.Inference;
import org.numenta.nupic.network.Network;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Base class for HTM inference operators.  The operator uses a {@link Network} to
 * make inferences about input stream elements, emitting each element paired with the
 * {@link NetworkInference} computed for it (elements for which the network produces no
 * inference are not emitted).
 *
 * <p>Subclasses choose the network that handles an element via {@link #getInputNetwork()}.
 * {@link #initInputFunction()} must have run before the first element is processed; otherwise
 * element conversion fails with an {@link IllegalStateException}.
 *
 * @param <IN> Type of the input elements
 *
 * @author Eron Wright
 */
public abstract class AbstractHTMInferenceOperator<IN>
        extends AbstractUdfStreamOperator<Tuple2<IN, NetworkInference>, ResetFunction<IN>>
        implements OneInputStreamOperator<IN, Tuple2<IN, NetworkInference>> {

    protected static final Logger LOG = LoggerFactory.getLogger(AbstractHTMInferenceOperator.class);

    protected static final int INITIAL_PRIORITY_QUEUE_CAPACITY = 11;

    /** Conversion factor used to report {@link System#nanoTime()} durations in milliseconds. */
    private static final long NANOS_PER_MILLI = 1000000L;

    /**
     * Fallback used when no reset function is supplied: never requests a network reset.
     * It is created in a static context, so it holds no reference to an operator instance.
     */
    private static final ResetFunction<?> NEVER_RESET_FUNCTION = new ResetFunction<Object>() {
        @Override
        public boolean reset(Object value) throws Exception {
            return false;
        }
    };

    /** Counter registered under "networks" in {@link #open()}; runtime-only, hence transient. */
    protected transient IntCounter networkCounter;

    /** Average per-element processing time in ms, registered in {@link #open()}; runtime-only. */
    private transient AverageAccumulator avgProcessingTime;

    private final ExecutionConfig executionConfig;
    private final TypeInformation<IN> inputType;
    private final TypeSerializer<IN> inputSerializer;

    /**
     * Whether the job runs in processing time. It is not consulted yet: event-time ordering of
     * input elements is still a TODO (see {@link #processElement(StreamRecord)}).
     */
    private final boolean isProcessingTime;

    private final NetworkFactory<IN> networkFactory;

    private transient EncoderInputFunction<IN> inputFunction;

    /**
     * Creates the operator.
     *
     * @param executionConfig  execution config used to create the input serializer and the
     *                         encoder input function
     * @param inputType        type information of the input elements
     * @param isProcessingTime whether the job runs in processing time
     * @param networkFactory   factory used to create HTM networks
     * @param resetFunction    decides, per element, whether the network is reset before the element
     *                         is fed to it; may be {@code null}, meaning "never reset"
     */
    public AbstractHTMInferenceOperator(
            final ExecutionConfig executionConfig,
            final TypeInformation<IN> inputType,
            final boolean isProcessingTime,
            final NetworkFactory<IN> networkFactory,
            final ResetFunction<IN> resetFunction
    ) {
        super(resetFunctionOrNever(resetFunction));

        this.executionConfig = executionConfig;
        this.inputType = inputType;
        this.isProcessingTime = isProcessingTime;
        this.networkFactory = networkFactory;

        this.inputSerializer = inputType.createSerializer(executionConfig);
    }

    /**
     * Returns the given reset function, or the shared "never reset" fallback when it is null.
     */
    private static <T> ResetFunction<T> resetFunctionOrNever(ResetFunction<T> resetFunction) {
        if (resetFunction != null) {
            return resetFunction;
        }
        // Safe: the fallback ignores its argument and always returns false, so it is valid
        // for every element type.
        @SuppressWarnings("unchecked")
        ResetFunction<T> neverReset = (ResetFunction<T>) NEVER_RESET_FUNCTION;
        return neverReset;
    }

    /**
     * @return the serializer for input elements, created from the input type in the constructor
     */
    public TypeSerializer<IN> getInputSerializer() {
        return inputSerializer;
    }

    /**
     * Opens the operator and registers the "networks" counter and the
     * "processing time (ms)" average accumulator with the runtime context.
     */
    @Override
    public void open() throws Exception {
        super.open();

        networkCounter = getRuntimeContext().getIntCounter("networks");
        avgProcessingTime = new AverageAccumulator();
        getRuntimeContext().addAccumulator("processing time (ms)", avgProcessingTime);
    }

    /**
     * Returns the network that should process the current input element.
     *
     * @return the network to feed the current element into
     * @throws Exception if the network cannot be obtained
     */
    protected abstract Network getInputNetwork() throws Exception;

    /**
     * Feeds the element into the network returned by {@link #getInputNetwork()} and records the
     * elapsed time (in whole milliseconds) in the processing-time accumulator.
     */
    @Override
    public void processElement(StreamRecord<IN> element) throws Exception {
        // nanoTime is monotonic; currentTimeMillis is wall-clock time and may jump.
        final long startNanos = System.nanoTime();

        // In processing time there can be no out-of-order elements.
        // TODO in event time, order input elements by timestamp (HTM causality);
        // see the Flink CEP code for an example.
        Network network = getInputNetwork();
        processInput(network, element.getValue(), element.getTimestamp());

        avgProcessingTime.add((System.nanoTime() - startNanos) / NANOS_PER_MILLI);
    }

    /**
     * Forwards the watermark downstream unchanged.
     */
    @Override
    public void processWatermark(Watermark mark) throws Exception {
        output.emitWatermark(mark);
    }

    /**
     * Runs a single record through the given network.
     *
     * <p>If the user's reset function returns true for the record, the network is reset first.
     * The record is then converted to encoder input and computed immediately. When the network
     * yields an inference, {@code (record, inference)} is emitted with the given timestamp.
     *
     * @param network   network to feed
     * @param record    input element
     * @param timestamp timestamp attached to the emitted record
     * @throws Exception if the reset function or the network computation fails
     */
    protected void processInput(Network network, IN record, long timestamp) throws Exception {

        if (userFunction.reset(record)) {
            network.reset();
            LOG.debug("network reset");
        }

        Object input = convertToInput(record, timestamp);
        Inference inference = network.computeImmediate(input);

        if (inference != null) {
            NetworkInference outputInference = NetworkInference.fromInference(inference);
            StreamRecord<Tuple2<IN, NetworkInference>> streamRecord = new StreamRecord<>(
                    new Tuple2<>(record, outputInference),
                    timestamp);
            output.collect(streamRecord);
        }
    }

    /**
     * Initialize the input function to map input elements to HTM encoder input.
     *
     * <p>A throwaway network is created (with a {@code null} key) only to obtain its
     * {@link MultiEncoder}. If that encoder has no sub-encoders yet, they are built from the
     * network's {@code FIELD_ENCODING_MAP} parameter.
     *
     * @throws IllegalStateException    if the input type is not a {@link CompositeType}, or if
     *                                  no field encoding map is available
     * @throws IllegalArgumentException if the network has no encoder
     * @throws Exception                if network creation or code generation fails
     */
    protected void initInputFunction() throws Exception {

        // Fail fast with a clear message instead of a ClassCastException further down.
        if (!(inputType instanceof CompositeType<?>)) {
            throw new IllegalStateException(
                    "input type must be a composite type (e.g. a POJO or tuple) to derive encoder input, but was: "
                            + inputType);
        }
        final CompositeType<IN> compositeInputType = (CompositeType<IN>) inputType;

        // it is premature to call getInputNetwork, because no 'key' is available
        // when the operator is first opened.
        Network network = networkFactory.createNetwork(null);
        MultiEncoder encoder = network.getEncoder();

        if (encoder == null) {
            throw new IllegalArgumentException("a network encoder must be provided");
        }

        // handle the situation where an encoder parameter map was supplied rather than a fully-baked encoder.
        if (encoder.getEncoders(encoder) == null || encoder.getEncoders(encoder).size() < 1) {
            @SuppressWarnings("unchecked")
            Map<String, Map<String, Object>> encoderParams =
                    (Map<String, Map<String, Object>>) network.getParameters().get(Parameters.KEY.FIELD_ENCODING_MAP);
            if (encoderParams == null || encoderParams.size() < 1) {
                throw new IllegalStateException("No field encoding map found for MultiEncoder");
            }
            encoder.addMultipleEncoders(encoderParams);
        }

        // generate the encoder input function
        final GenerateEncoderInputFunction<IN> generator =
                new GenerateEncoderInputFunction<>(compositeInputType, encoder, executionConfig);
        inputFunction = generator.generate();
    }

    /**
     * Maps a record to the field map consumed by the network's encoder.
     *
     * @param record    input element
     * @param timestamp element timestamp (currently unused)
     * @return the encoder input for {@code record}
     * @throws IllegalStateException if {@link #initInputFunction()} has not been called
     */
    private Object convertToInput(IN record, long timestamp) {
        if (inputFunction == null) {
            throw new IllegalStateException("inputFunction is null");
        }
        return inputFunction.map(record);
    }
}