/*-
 *  * Copyright 2016 Skymind, Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 */

package org.datavec.api.records.reader.impl.csv;

import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataLineInterval;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.primitives.Triple;

import java.io.DataInputStream;
import java.io.IOException;
import java.net.URI;
import java.util.*;

/**
 * A sliding window of variable size across an entire CSV.
 *
 * In practice the sliding window size starts at 1, then linearly increase to maxLinesPer sequence, then
 * linearly decrease back to 1.
 *
 * @author Justin Long (crockpotveggies)
 */
public class CSVVariableSlidingWindowRecordReader extends CSVRecordReader implements SequenceRecordReader {

    public static final String LINES_PER_SEQUENCE = NAME_SPACE + ".nlinespersequence";

    private int maxLinesPerSequence;
    private String delimiter;
    private int stride;
    private LinkedList<List<Writable>> queue;
    private boolean exhausted;

    /**
     * No-arg constructor with the default number of lines per sequence (10)
     */
    public CSVVariableSlidingWindowRecordReader() {
        this(10, 1);
    }

    /**
     * @param maxLinesPerSequence Number of lines in each sequence, use default delemiter(,) between entries in the same line
     */
    public CSVVariableSlidingWindowRecordReader(int maxLinesPerSequence) {
        this(maxLinesPerSequence, 0, 1, String.valueOf(CSVRecordReader.DEFAULT_DELIMITER));
    }

    /**
     * @param maxLinesPerSequence Number of lines in each sequence, use default delemiter(,) between entries in the same line
     * @param stride Number of lines between records (increment window > 1 line)
     */
    public CSVVariableSlidingWindowRecordReader(int maxLinesPerSequence, int stride) {
        this(maxLinesPerSequence, 0, stride, String.valueOf(CSVRecordReader.DEFAULT_DELIMITER));
    }

    /**
     * @param maxLinesPerSequence Number of lines in each sequence, use default delemiter(,) between entries in the same line
     * @param stride Number of lines between records (increment window > 1 line)
     */
    public CSVVariableSlidingWindowRecordReader(int maxLinesPerSequence, int stride, String delimiter) {
        this(maxLinesPerSequence, 0, stride, String.valueOf(CSVRecordReader.DEFAULT_DELIMITER));
    }

    /**
     *
     * @param maxLinesPerSequence Number of lines in each sequences
     * @param skipNumLines Number of lines to skip at the start of the file (only skipped once, not per sequence)
     * @param stride Number of lines between records (increment window > 1 line)
     * @param delimiter Delimiter between entries in the same line, for example ","
     */
    public CSVVariableSlidingWindowRecordReader(int maxLinesPerSequence, int skipNumLines, int stride, String delimiter) {
        super(skipNumLines);
        if(stride < 1)
            throw new IllegalArgumentException("Stride must be greater than 1");

        this.delimiter = delimiter;
        this.maxLinesPerSequence = maxLinesPerSequence;
        this.stride = stride;
        this.queue = new LinkedList<>();
        this.exhausted = false;
    }

    @Override
    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        super.initialize(conf, split);
        this.maxLinesPerSequence = conf.getInt(LINES_PER_SEQUENCE, maxLinesPerSequence);
    }

    @Override
    public boolean hasNext() {
        boolean moreInCsv = super.hasNext();
        boolean moreInQueue = !queue.isEmpty();
        return moreInCsv || moreInQueue;
    }

    @Override
    public List<List<Writable>> sequenceRecord() {
        // try polling next(), otherwise empty the queue
        // loop according to stride size
        for(int i = 0; i < stride; i++) {
            if(super.hasNext())
                queue.addFirst(super.next());
            else
                exhausted = true;

            if (exhausted && queue.size() < 1)
                throw new NoSuchElementException("No next element");

            if (queue.size() > maxLinesPerSequence || exhausted)
                queue.pollLast();
        }

        List<List<Writable>> sequence = new ArrayList<>();
        for(List<Writable> line : queue) {
            sequence.add(line);
        }

        if(exhausted && queue.size()==1)
            queue.pollLast();

        return sequence;
    }

    @Override
    public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException {
        throw new UnsupportedOperationException("Reading CSV data from DataInputStream not yet implemented");
    }

    @Override
    public SequenceRecord nextSequence() {
        int lineBefore = lineIndex;
        List<List<Writable>> record = sequenceRecord();
        int lineAfter = lineIndex + queue.size();
        URI uri = (locations == null || locations.length < 1 ? null : locations[splitIndex]);
        RecordMetaData meta = new RecordMetaDataLineInterval(lineBefore, lineAfter - 1, uri,
                        CSVVariableSlidingWindowRecordReader.class);
        return new org.datavec.api.records.impl.SequenceRecord(record, meta);
    }

    @Override
    public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0);
    }

    @Override
    public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Record loadFromMetaData(RecordMetaData recordMetaData) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void reset() {
        super.reset();
        queue = new LinkedList<>();
        exhausted = false;
    }
}