/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.jstorm.transactional.spout;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.jstorm.client.spout.IAckValueSpout;
import com.alibaba.jstorm.client.spout.IFailValueSpout;
import com.alibaba.jstorm.task.TaskBaseMetric;
import com.alibaba.jstorm.transactional.utils.AckPendingBatchTracker;

import backtype.storm.spout.SpoutOutputCollector;
import backtype.storm.spout.SpoutOutputCollectorCb;
import backtype.storm.task.ICollectorCallback;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.IRichSpout;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.utils.Utils;

/**
 * A transactional spout wrapper that provides compatibility with Storm's ack/fail
 * mechanism: every anchored tuple emitted by the wrapped {@link IRichSpout} is
 * tracked against its batch, and the spout's ack/fail callbacks fire when that
 * batch is committed or rolled back.
 */
public class AckTransactionSpout implements ITransactionSpoutExecutor {
    private static final long serialVersionUID = -6561817670963028414L;

    private static final Logger LOG = LoggerFactory.getLogger(AckTransactionSpout.class);

    private IRichSpout spoutExecutor;
    private boolean isCacheTuple;
    private TaskBaseMetric taskStats;

    private volatile long currBatchId = -1;
    // Map<BatchId, Map<StreamId, Map<MsgId, Value>>>
    private AckPendingBatchTracker<Map<Object, List<Object>>> tracker;

    private Random random;

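    /**
     * A collector wrapper that intercepts every emit from the user spout: anchored
     * tuples (non-null messageId) are registered in the pending-batch tracker under
     * the current batch id, and a randomly generated rootId is appended to the tuple
     * for downstream tracking. The original messageId is never forwarded; ack/fail
     * is driven by batch commit/rollback instead.
     */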
    private class AckSpoutOutputCollector extends SpoutOutputCollectorCb {
        private SpoutOutputCollectorCb delegate;

        public AckSpoutOutputCollector(SpoutOutputCollectorCb delegate) {
            this.delegate = delegate;
        }

        @Override
        public List<Integer> emit(String streamId, List<Object> tuple, Object messageId) {
            if (messageId != null) {
                addPendingTuple(currBatchId, streamId, messageId, tuple);
                tuple.add(Utils.generateId(random));
            } else {
                // for non-anchor tuples, use 0 as default rootId
                tuple.add(0L);
            }
            return delegate.emit(streamId, tuple, null);
        }

        @Override
        public void emitDirect(int taskId, String streamId, List<Object> tuple, Object messageId) {
            if (messageId != null) {
                addPendingTuple(currBatchId, streamId, messageId, tuple);
                tuple.add(Utils.generateId(random));
            } else {
                // for non-anchor tuples, use 0 as default rootId
                tuple.add(0L);
            }
            delegate.emitDirect(taskId, streamId, tuple, null);
        }

        @Override
        public List<Integer> emit(String streamId, List<Object> tuple, Object messageId, ICollectorCallback callback) {
            if (messageId != null) {
                addPendingTuple(currBatchId, streamId, messageId, tuple);
                tuple.add(Utils.generateId(random));
            } else {
                // for non-anchor tuples, use 0 as default rootId
                tuple.add(0L);
            }
            return delegate.emit(streamId, tuple, null, callback);
        }

        @Override
        public void emitDirect(int taskId, String streamId, List<Object> tuple, Object messageId, ICollectorCallback callback) {
            if (messageId != null) {
                addPendingTuple(currBatchId, streamId, messageId, tuple);
                tuple.add(Utils.generateId(random));
            } else {
                tuple.add(0L);
            }
            delegate.emitDirect(taskId, streamId, tuple, null, callback);
        }

        @Override
        public void reportError(Throwable error) {
            delegate.reportError(error);
        }

        @Override
        public void flush() {
            delegate.flush();
        }

        @Override
        public void emitDirectCtrl(int taskId, String streamId, List<Object> tuple, Object messageId) {
            delegate.emitDirectCtrl(taskId, streamId, tuple, messageId);
        }

        @Override
        public List<Integer> emitCtrl(String streamId, List<Object> tuple, Object messageId) {
            return delegate.emitCtrl(streamId, tuple, messageId);
        }
    }

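    /**
     * Wraps a plain {@link IRichSpout} for use in a transactional topology. Emitted
     * tuple values are cached only if the wrapped spout implements
     * {@link IAckValueSpout} or {@link IFailValueSpout}, i.e. only if it can actually
     * consume the values on ack/fail. A minimal usage sketch (MyUserSpout is a
     * hypothetical {@link IRichSpout} implementation, not part of this code base):
     *
     * <pre>
     * ITransactionSpoutExecutor spout = new AckTransactionSpout(new MyUserSpout());
     * </pre>
     */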
    public AckTransactionSpout(IRichSpout spout) {
        this.spoutExecutor = spout;
        // cache emitted values only when the wrapped spout can receive them back on ack/fail
        this.isCacheTuple = spout instanceof IAckValueSpout || spout instanceof IFailValueSpout;
    }

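    /**
     * Swaps the delegate collector for an {@link AckSpoutOutputCollector} before
     * handing it to the user spout, so that all emits are tracked per batch. The
     * rootId generator is seeded from a secure random source, which should keep
     * generated ids distinct across spout tasks.
     */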
    @Override
    public void open(Map conf, TopologyContext context, SpoutOutputCollector collector) {
        SpoutOutputCollectorCb ackOutput = new AckSpoutOutputCollector(collector.getDelegate());
        spoutExecutor.open(conf, context, new SpoutOutputCollector(ackOutput));
        tracker = new AckPendingBatchTracker<>();
        taskStats = new TaskBaseMetric(context.getTopologyId(), context.getThisComponentId(), context.getThisTaskId());
        random = new Random(Utils.secureRandomLong());
    }

    @Override
    public void close() {
        spoutExecutor.close();
    }

    @Override
    public void activate() {
        spoutExecutor.activate();
    }

    @Override
    public void deactivate() {
        spoutExecutor.deactivate();
    }

    @Override
    public void nextTuple() {
        spoutExecutor.nextTuple();
    }

    @Override
    public void ack(Object msgId) {
        spoutExecutor.ack(msgId);
    }

    @Override
    public void fail(Object msgId) {
        spoutExecutor.fail(msgId);
    }

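    /**
     * Value-carrying ack variant (the matching {@code fail} below is analogous):
     * used when the wrapped spout implements {@link IAckValueSpout} or
     * {@link IFailValueSpout} and therefore expects the originally emitted values
     * to be handed back.
     */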
    public void ack(Object msgId, List<Object> values) {
        ((IAckValueSpout) spoutExecutor).ack(msgId, values);
    }

    public void fail(Object msgId, List<Object> values) {
        ((IFailValueSpout) spoutExecutor).fail(msgId, values);
    }

    @Override
    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        spoutExecutor.declareOutputFields(declarer);
    }

    @Override
    public Map<String, Object> getComponentConfiguration() {
        return spoutExecutor.getComponentConfiguration();
    }

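    /**
     * Batch lifecycle: {@code initState} restores the current batch id from the saved
     * state (or starts at 1), {@code finishBatch} advances to the next batch,
     * {@code commit} publishes the batch id as the user state, {@code ackCommit}
     * acks all tuples pending in the committed batch, and {@code rollBack} fails
     * whatever is still outstanding.
     */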
    @Override
    public void initState(Object userState) {
        if (userState != null) {
            currBatchId = (Long) userState;
        } else {
            currBatchId = 1;
        }
    }

    @Override
    public Object finishBatch(long batchId) {
        currBatchId++;
        return null;
    }

    @Override
    public Object commit(long batchId, Object state) {
        return batchId;
    }

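    /**
     * Restores the batch id from the last committed state, drops batches that are
     * already obsolete, and fails every tuple still pending so the user spout can
     * replay it.
     */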
    @Override
    public void rollBack(Object userState) {
        if (userState != null) {
            currBatchId = (Long) userState;
        } else {
            currBatchId = 1;
        }

        removeObsoleteBatches(currBatchId);
        for (Long batchId : tracker.getBatchIds())
            ackOrFailBatch(batchId, false);
    }

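    /**
     * Called once a batch commit has been persisted: acks every tuple recorded for
     * this batch and discards any older batches that are now obsolete.
     */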
    @Override
    public void ackCommit(long batchId, long timeStamp) {
        ackOrFailBatch(batchId, true);
        removeObsoleteBatches(batchId);
    }

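    /**
     * Acks or fails all tuples recorded for the given batch, stream by stream, and
     * updates the spout's ack/fail metrics accordingly.
     */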
    private void ackOrFailBatch(long batchId, boolean isAck) {
        for (String streamId : tracker.getStreamIds(batchId)) {
            Map<Object, List<Object>> pendingBatch = tracker.getPendingBatch(batchId, streamId);
            if (pendingBatch == null)
                continue;

            for (Entry<Object, List<Object>> entry : pendingBatch.entrySet()) {
                ackOrFailTuple(entry.getKey(), entry.getValue(), isAck);
            }
            if (isAck)
                taskStats.spoutAckedTuple(streamId, pendingBatch.size());
            else
                taskStats.spoutFailedTuple(streamId, pendingBatch.size());
        }
    }

    private void ackOrFailTuple(Object msgId, List<Object> value, boolean isAck) {
        if (isAck) {
            if (spoutExecutor instanceof IAckValueSpout)
                ack(msgId, value);
            else
                ack(msgId);
        } else {
            if (spoutExecutor instanceof IFailValueSpout)
                fail(msgId, value);
            else
                fail(msgId);
        }
    }

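    /**
     * Records an anchored tuple under its batch and stream so it can be acked or
     * failed once the batch outcome is known. Values are kept only when
     * {@code isCacheTuple} is set.
     */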
    private void addPendingTuple(long batchId, String streamId, Object msgId, List<Object> value) {
        // createIfAbsent=true: the tracker must own the per-stream map, otherwise
        // the entry recorded here would be lost and the tuple never acked or failed
        Map<Object, List<Object>> pendingBatch = tracker.getPendingBatch(batchId, streamId, true);
        List<Object> cacheValue = isCacheTuple ? value : null;
        pendingBatch.put(msgId, cacheValue);
    }

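    /**
     * Drops every tracked batch whose id is strictly lower than the committed one;
     * such batches can no longer be acked or failed.
     */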
    private void removeObsoleteBatches(long commitBatchId) {
        TreeSet<Long> totalBatches = new TreeSet<>(tracker.getBatchIds());
        Set<Long> obsoleteBatches = totalBatches.headSet(commitBatchId);
        if (!obsoleteBatches.isEmpty()) {
            LOG.info("Remove obsolete batches: {}", obsoleteBatches);
            for (Long batchId : obsoleteBatches) {
                tracker.removeBatch(batchId);
            }
        }
    }
}