/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.suggest.fst;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Set;

import org.apache.lucene.search.suggest.InputIterator;
import org.apache.lucene.search.suggest.Lookup;
import org.apache.lucene.search.suggest.SortedInputIterator;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ByteArrayDataOutput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.CharsRefBuilder;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.OfflineSorter.ByteSequencesWriter;
import org.apache.lucene.util.fst.FSTCompiler;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FST.Arc;
import org.apache.lucene.util.fst.FST.BytesReader;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
import org.apache.lucene.util.fst.Util.Result;
import org.apache.lucene.util.fst.Util.TopResults;

/**
 * Suggester based on a weighted FST: it first traverses the prefix, 
 * then walks the <i>n</i> shortest paths to retrieve top-ranked
 * suggestions.
 * <p>
 * <b>NOTE</b>:
 * Input weights must be between 0 and {@link Integer#MAX_VALUE}, any
 * other values will be rejected.
 * 
 * @lucene.experimental
 */
// redundant 'implements Accountable' to work around javadoc bugs
public class WFSTCompletionLookup extends Lookup implements Accountable {
  
  /**
   * {@code FST<Long>}; weights are encoded as costs: (Integer.MAX_VALUE - weight).
   * Initially {@code null}; assigned by {@code build} and {@code load}.
   */
  // NOTE: like FSTSuggester, this is really a WFSA, if you want to
  // customize the code to add some output you should use PairOutputs.
  private FST<Long> fst = null;
  
  /**
   * True if exact match suggestions should always be returned first.
   */
  private final boolean exactFirst;
  
  /** Number of entries the lookup was built with */
  private long count = 0;

  /**
   * Directory handed to the sorted input iterator that {@code build} creates.
   * NOTE(review): presumably where the offline sort keeps its temporary files - confirm
   * against SortedInputIterator.
   */
  private final Directory tempDir;
  /** File name prefix handed to the sorted input iterator together with {@code tempDir}. */
  private final String tempFileNamePrefix;

  /**
   * Calls {@link #WFSTCompletionLookup(Directory,String,boolean)
   * WFSTCompletionLookup(tempDir, tempFileNamePrefix, true)}, i.e. creates a suggester
   * that always returns exact matches first.
   *
   * @param tempDir directory forwarded unchanged to the three-argument constructor
   * @param tempFileNamePrefix file name prefix forwarded unchanged to the three-argument
   *        constructor
   */
  public WFSTCompletionLookup(Directory tempDir, String tempFileNamePrefix) {
    this(tempDir, tempFileNamePrefix, true);
  }
  
  /**
   * Creates a new suggester.
   * 
   * @param tempDir directory that is later passed to the sorted input iterator
   *        created by {@code build}
   * @param tempFileNamePrefix file name prefix that is later passed to the sorted
   *        input iterator created by {@code build}
   * @param exactFirst <code>true</code> if suggestions that match the 
   *        prefix exactly should always be returned first, regardless
   *        of score. This has no performance impact, but could result
   *        in low-quality suggestions.
   */
  public WFSTCompletionLookup(Directory tempDir, String tempFileNamePrefix, boolean exactFirst) {
    this.exactFirst = exactFirst;
    this.tempDir = tempDir;
    this.tempFileNamePrefix = tempFileNamePrefix;
  }
  
  /**
   * Builds the weighted FST from the given suggestions.
   * <p>
   * The input is wrapped in a {@code WFSTInputIterator}, which hands back each
   * suggestion together with its cost (the weight as transformed by
   * {@code encodeWeight}). Consecutive identical suggestions are collapsed so that
   * only the first occurrence is added to the FST.
   * <p>
   * The new FST and entry count are published together once the whole input has been
   * consumed, so a failure part-way through leaves the previously built state untouched.
   *
   * @param iterator source of suggestions and weights; must not carry payloads or contexts
   * @throws IllegalArgumentException if the iterator reports payloads or contexts
   * @throws IOException if reading the input or compiling the FST fails
   */
  @Override
  public void build(InputIterator iterator) throws IOException {
    if (iterator.hasPayloads()) {
      throw new IllegalArgumentException("this suggester doesn't support payloads");
    }
    if (iterator.hasContexts()) {
      throw new IllegalArgumentException("this suggester doesn't support contexts");
    }
    InputIterator iter = new WFSTInputIterator(tempDir, tempFileNamePrefix, iterator);
    IntsRefBuilder scratchInts = new IntsRefBuilder();
    PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
    FSTCompiler<Long> fstCompiler = new FSTCompiler<>(FST.INPUT_TYPE.BYTE1, outputs);

    // Counted locally so that this.count is only replaced together with this.fst.
    long newCount = 0;
    // Stays null until the first entry has been seen: an eagerly created (empty) builder
    // would make a leading empty suggestion look like a duplicate.
    BytesRefBuilder previous = null;
    BytesRef current;
    while ((current = iter.next()) != null) {
      long cost = iter.weight();

      if (previous == null) {
        previous = new BytesRefBuilder();
      } else if (current.equals(previous.get())) {
        continue; // for duplicate suggestions, the best weight is actually
                  // added
      }
      Util.toIntsRef(current, scratchInts);
      fstCompiler.add(scratchInts.get(), cost);
      previous.copyBytes(current);
      newCount++;
    }
    fst = fstCompiler.compile();
    count = newCount;
  }

  
  /**
   * Writes the entry count followed by the FST to {@code output}.
   * <p>
   * The count is written unconditionally; when no FST is available nothing else is
   * written and {@code false} is returned.
   *
   * @param output destination for the serialized suggester
   * @return {@code true} if the FST was written, {@code false} if there is no FST
   * @throws IOException if writing to {@code output} fails
   */
  @Override
  public boolean store(DataOutput output) throws IOException {
    output.writeVLong(count);
    final boolean hasFst = fst != null;
    if (hasFst) {
      // The same stream receives both the FST metadata and its body.
      fst.save(output, output);
    }
    return hasFst;
  }

  /**
   * Reads the entry count and the FST from {@code input}, in the order written by
   * {@code store}.
   * <p>
   * Both values are read completely before either field is replaced, so a failure
   * while decoding the FST does not leave a new count paired with the old FST.
   *
   * @param input source positioned at the serialized entry count
   * @return always {@code true}
   * @throws IOException if reading from {@code input} fails
   */
  @Override
  public boolean load(DataInput input) throws IOException {
    final long storedCount = input.readVLong();
    // The same stream supplies both the FST metadata and its body.
    final FST<Long> loaded = new FST<>(input, input, PositiveIntOutputs.getSingleton());
    this.count = storedCount;
    this.fst = loaded;
    return true;
  }

  /**
   * Returns up to {@code num} completions of {@code key}.
   * <p>
   * The key is first matched exactly against the FST. If {@code exactFirst} is set and
   * the key itself is a complete entry, it becomes the first result; the remaining slots
   * are filled from {@code Util.shortestPaths} starting at the arc reached by the key.
   *
   * @param key prefix to complete
   * @param contexts must be {@code null}; contexts are not supported
   * @param onlyMorePopular must be {@code false}
   * @param num maximum number of results; asserted to be positive
   * @return the suggestions, or an empty list if nothing was built or the prefix is absent
   * @throws IllegalArgumentException if {@code contexts} is non-null or
   *         {@code onlyMorePopular} is {@code true}
   */
  @Override
  public List<LookupResult> lookup(CharSequence key, Set<BytesRef> contexts, boolean onlyMorePopular, int num) {
    if (contexts != null) {
      throw new IllegalArgumentException("this suggester doesn't support contexts");
    }
    assert num > 0;

    if (onlyMorePopular) {
      throw new IllegalArgumentException("this suggester only works with onlyMorePopular=false");
    }

    if (fst == null) {
      return Collections.emptyList();
    }

    final BytesRefBuilder keyBytes = new BytesRefBuilder();
    keyBytes.copyChars(key);
    final int keyLength = keyBytes.length();
    final Arc<Long> prefixArc = new Arc<>();

    // Walk the key through the FST; prefixArc ends up on the last matched arc.
    final Long prefixCost;
    try {
      prefixCost = lookupPrefix(keyBytes.get(), prefixArc);
    } catch (IOException bogus) {
      throw new RuntimeException(bogus);
    }
    if (prefixCost == null) {
      return Collections.emptyList();
    }

    final List<LookupResult> suggestions = new ArrayList<>(num);
    final CharsRefBuilder chars = new CharsRefBuilder();
    int remaining = num;
    if (exactFirst && prefixArc.isFinal()) {
      // The key itself is an entry: report it ahead of every completion.
      chars.copyUTF8Bytes(keyBytes.get());
      final long exactCost = prefixCost + prefixArc.nextFinalOutput();
      suggestions.add(new LookupResult(chars.toString(), decodeWeight(exactCost)));
      remaining--;
      if (remaining == 0) {
        return suggestions; // that was quick
      }
    }

    // Fill the remaining slots. The empty suffix is only allowed when the exact match
    // was not already handled above.
    final boolean allowEmptySuffix = !exactFirst;
    final TopResults<Long> paths;
    try {
      paths = Util.shortestPaths(fst, prefixArc, prefixCost, weightComparator, remaining, allowEmptySuffix);
      assert paths.isComplete;
    } catch (IOException bogus) {
      throw new RuntimeException(bogus);
    }

    final BytesRefBuilder suffix = new BytesRefBuilder();
    for (Result<Long> path : paths) {
      // Rewind to the bare key, then append this path's suffix.
      keyBytes.setLength(keyLength);
      Util.toBytesRef(path.input, suffix);
      keyBytes.append(suffix);
      chars.copyUTF8Bytes(keyBytes.get());
      suggestions.add(new LookupResult(chars.toString(), decodeWeight(path.output)));
    }
    return suggestions;
  }
  
  /**
   * Follows the bytes of {@code scratch} through the FST, starting at the first arc.
   *
   * @param scratch bytes to match
   * @param arc reused arc; on success it is left on the last arc that was followed
   * @return the sum of the outputs along the matched arcs, or {@code null} if some byte
   *         has no matching arc
   */
  private Long lookupPrefix(BytesRef scratch, Arc<Long> arc) throws /*Bogus*/IOException {
    assert 0 == fst.outputs.getNoOutput().longValue();
    final BytesReader reader = fst.getBytesReader();
    fst.getFirstArc(arc);

    long accumulated = 0;
    final int limit = scratch.offset + scratch.length;
    for (int i = scratch.offset; i < limit; i++) {
      // Labels are unsigned byte values.
      final int label = scratch.bytes[i] & 0xff;
      if (fst.findTargetArc(label, arc, arc, reader) == null) {
        return null;
      }
      accumulated += arc.output().longValue();
    }
    return accumulated;
  }
  
  /**
   * Returns the weight associated with an input string,
   * or null if it does not exist.
   * <p>
   * A non-null result is an {@link Integer}. {@code null} is also returned when no FST
   * has been built or loaded, or when the key only matches as a proper prefix.
   */
  public Object get(CharSequence key) {
    if (fst == null) {
      return null;
    }
    final Arc<Long> arc = new Arc<>();
    final Long prefixCost;
    try {
      prefixCost = lookupPrefix(new BytesRef(key), arc);
    } catch (IOException bogus) {
      throw new RuntimeException(bogus);
    }
    if (prefixCost != null && arc.isFinal()) {
      final long totalCost = prefixCost + arc.nextFinalOutput();
      return Integer.valueOf(decodeWeight(totalCost));
    }
    return null;
  }
  
  /** cost -&gt; weight: the inverse of {@code encodeWeight}, i.e. Integer.MAX_VALUE - cost. */
  private static int decodeWeight(long encoded) {
    final long weight = Integer.MAX_VALUE - encoded;
    return (int) weight;
  }
  
  /**
   * weight -&gt; cost: Integer.MAX_VALUE - weight, so that a larger weight yields a
   * smaller cost.
   *
   * @throws UnsupportedOperationException if the weight is negative or exceeds
   *         {@link Integer#MAX_VALUE}
   */
  private static int encodeWeight(long value) {
    final boolean representable = value >= 0 && value <= Integer.MAX_VALUE;
    if (!representable) {
      throw new UnsupportedOperationException("cannot encode value: " + value);
    }
    return Integer.MAX_VALUE - (int) value;
  }
  
  /**
   * Sorted input iterator whose records consist of the suggestion bytes followed by a
   * 4-byte int holding the cost computed by {@code encodeWeight}.
   */
  private static final class WFSTInputIterator extends SortedInputIterator {

    WFSTInputIterator(Directory tempDir, String tempFileNamePrefix, InputIterator source) throws IOException {
      super(tempDir, tempFileNamePrefix, source);
      assert source.hasPayloads() == false;
    }

    /**
     * Writes one record: the bytes of {@code spare} followed by the int-encoded cost of
     * {@code weight}. {@code payload} and {@code contexts} are not used.
     */
    @Override
    protected void encode(ByteSequencesWriter writer, ByteArrayDataOutput output, byte[] buffer, BytesRef spare, BytesRef payload, Set<BytesRef> contexts, long weight) throws IOException {
      final int required = spare.length + Integer.BYTES;
      if (buffer.length <= required) {
        buffer = ArrayUtil.grow(buffer, required);
      }
      output.reset(buffer);
      output.writeBytes(spare.bytes, spare.offset, spare.length);
      output.writeInt(encodeWeight(weight));
      writer.write(buffer, 0, output.getPosition());
    }

    /**
     * Splits a record produced by {@code encode}: shortens {@code scratch} so that it
     * covers only the suggestion bytes and returns the trailing int (the cost).
     */
    @Override
    protected long decode(BytesRef scratch, ByteArrayDataInput tmpInput) {
      // Drop the trailing int from the visible range of scratch ...
      scratch.length -= Integer.BYTES;
      // ... which leaves the cost sitting right behind the suggestion bytes.
      final int costStart = scratch.offset + scratch.length;
      tmpInput.reset(scratch.bytes, costStart, Integer.BYTES);
      return tmpInput.readInt();
    }
  }
  
  /**
   * Compares costs in their natural (ascending) order; {@code lookup} passes it to
   * {@code Util.shortestPaths} as the path comparator.
   */
  static final Comparator<Long> weightComparator = Long::compareTo;

  /** Returns byte size of the underlying FST, or 0 if there is no FST. */
  @Override
  public long ramBytesUsed() {
    if (fst == null) {
      return 0;
    }
    return fst.ramBytesUsed();
  }
  
  /**
   * Exposes the FST as the single child resource, named "fst"; returns an empty
   * collection while there is no FST.
   */
  @Override
  public Collection<Accountable> getChildResources() {
    if (fst != null) {
      return Collections.singleton(Accountables.namedAccountable("fst", fst));
    }
    return Collections.emptyList();
  }

  /**
   * Returns the number of entries the lookup was built with: the count of distinct
   * suggestions added by {@code build}, or the value read back by {@code load}.
   */
  @Override
  public long getCount() {
    return count;
  }
}