package org.apache.lucene.analysis.jate;

import org.apache.log4j.Logger;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.*;
import org.apache.lucene.util.Attribute;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.BytesRef;
import uk.ac.shef.dcs.jate.nlp.POSTagger;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;


/**
 * A {@link TokenFilter} that attaches a part-of-speech tag to every token of its input.
 * <p>
 * On the first call to {@link #incrementToken()} after a reset, the whole input stream is consumed
 * and a clone of each token's attributes is buffered. The buffered token texts are tagged in a
 * single call to {@link POSTagger#tag(String[])}. The tokens are then replayed one at a time, with
 * the tag stored under {@link MWEMetadataType#POS} in the token's serialized {@link MWEMetadata}
 * payload. Metadata already present in the payload is deserialized and kept.
 * <p>
 * Because the entire input is buffered, memory use grows with the number of tokens in the stream.
 */
public final class OpenNLPPOSTaggerFilter extends TokenFilter {
    protected static final Logger LOG = Logger.getLogger(OpenNLPPOSTaggerFilter.class.getSimpleName());

    /** Tagger that produces one POS tag per buffered token. */
    private final POSTagger tagger;
    /** Index of the next buffered token to emit. */
    private int tokenIdx = 0;
    /** {@code true} until the input has been buffered and tagged for the current stream. */
    protected boolean first = true;
    /** Cloned attributes of every token read from the input, in stream order. */
    protected List<AttributeSource> tokenAttrs = new ArrayList<>();

    private final PayloadAttribute exitingPayload = addAttribute(PayloadAttribute.class);

    /** Tags returned by the tagger; {@code posTags[i]} belongs to {@code tokenAttrs.get(i)}. */
    private String[] posTags;

    /**
     * Construct a token stream filtering the given input.
     *
     * @param input  the token stream whose tokens are to be POS-tagged
     * @param tagger the tagger used to assign a POS tag to each token
     */
    protected OpenNLPPOSTaggerFilter(TokenStream input, POSTagger tagger) {
        super(input);
        this.tagger = tagger;
    }

    /**
     * Emits the next buffered token with its POS tag added to the payload metadata.
     *
     * @return {@code true} if a token was emitted, {@code false} once all tokens have been emitted
     * @throws IOException           if reading the underlying input fails
     * @throws IllegalStateException if the tagger returns fewer tags than there are tokens
     */
    @Override
    public boolean incrementToken() throws IOException {
        if (first) {
            // The tagger needs the whole token sequence, so buffer the entire input first.
            String[] words = walkTokens();
            if (words.length == 0) {
                return false;
            }
            posTags = createTags(words);
            if (posTags == null || posTags.length < tokenAttrs.size()) {
                throw new IllegalStateException("POS tagger returned "
                        + (posTags == null ? "null" : posTags.length + " tags")
                        + " for " + tokenAttrs.size() + " tokens");
            }
            first = false;
            tokenIdx = 0;
        }

        if (tokenIdx == tokenAttrs.size()) {
            // Every buffered token has been emitted; drop the buffer and start over on the next call.
            resetParams();
            return false;
        }

        AttributeSource as = tokenAttrs.get(tokenIdx);
        // Ensure this stream has every attribute type held by the snapshot before copying into it;
        // addAttribute returns the existing instance when the type is already registered.
        Iterator<? extends Class<? extends Attribute>> it = as.getAttributeClassesIterator();
        while (it.hasNext()) {
            addAttribute(it.next());
        }
        as.copyTo(this);

        // Merge the POS tag into any metadata an earlier filter already stored in the payload.
        BytesRef payload = exitingPayload.getPayload();
        MWEMetadata metadata = payload == null ? new MWEMetadata()
                : MWEMetadata.deserialize(payload.utf8ToString());
        metadata.addMetaData(MWEMetadataType.POS, posTags[tokenIdx]);
        exitingPayload.setPayload(new BytesRef(MWEMetadata.serialize(metadata)));
        tokenIdx++;
        return true;
    }

    /**
     * Consumes the remaining input, appending a clone of each token's attributes to
     * {@link #tokenAttrs}.
     *
     * @return the term text of every consumed token, in stream order
     * @throws IOException if reading the underlying input fails
     */
    protected String[] walkTokens() throws IOException {
        List<String> wordList = new ArrayList<>();
        while (input.incrementToken()) {
            CharTermAttribute textAtt = input.getAttribute(CharTermAttribute.class);
            // Use the term's own length, not endOffset - startOffset: offsets point into the
            // original text and need not match the (possibly rewritten) term buffer.
            wordList.add(textAtt.toString());
            tokenAttrs.add(input.cloneAttributes());
        }
        return wordList.toArray(new String[0]);
    }

    /**
     * Computes the POS tags for the given words.
     *
     * @param words token texts in stream order
     * @return the tags produced by {@link #assignPOS(String[])}
     */
    protected String[] createTags(String[] words) {
        return assignPOS(words);
    }

    /**
     * Delegates tagging to the configured {@link POSTagger}.
     *
     * @param words token texts in stream order
     * @return the tagger's output for {@code words}
     */
    protected String[] assignPOS(String[] words) {
        return tagger.tag(words);
    }

    @Override
    public void reset() throws IOException {
        super.reset();
        // Clear any buffered state, including tokens left over from an unfinished previous stream.
        resetParams();
    }

    @Override
    public final void end() throws IOException {
        super.end();
        resetParams();
    }

    /**
     * Returns the filter to its initial state: the next {@link #incrementToken()} call
     * buffers and tags the input again.
     */
    protected void resetParams() {
        first = true;
        tokenIdx = 0;
        posTags = null;
        tokenAttrs.clear();
    }
}