/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.util.fst;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.zip.GZIPInputStream;

import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.InputStreamDataInput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.LuceneTestCase;

/**
 * Tests lookup and size behavior of FSTs compiled with direct-addressing nodes enabled, and
 * offers a small command-line entry point ({@link #main(String...)}) for manual measurements
 * on FST or word-list files.
 */
public class TestFSTDirectAddressing extends LuceneTestCase {

  /**
   * Builds an FST over two-letter words whose first letters run from 'a' to 'g' with 'e'
   * missing, then verifies every word can be found again with {@code seekExact}.
   */
  public void testDenseWithGap() throws Exception {
    List<String> words = Arrays.asList("ah", "bi", "cj", "dk", "fl", "gm");
    List<BytesRef> entries = new ArrayList<>();
    for (String word : words) {
      entries.add(new BytesRef(word.getBytes(StandardCharsets.US_ASCII)));
    }
    final BytesRefFSTEnum<Object> fstEnum = new BytesRefFSTEnum<>(buildFST(entries));
    for (BytesRef entry : entries) {
      assertNotNull(entry.utf8ToString() + " not found", fstEnum.seekExact(entry));
    }
  }

  /**
   * Builds an FST over the 3-byte big-endian encodings of every multiple of 4 below 1,000,000
   * and checks that its size stays within 1% of the size obtained with list-encoding only.
   */
  public void testDeDupTails() throws Exception {
    List<BytesRef> entries = new ArrayList<>();
    for (int i = 0; i < 1000000; i += 4) {
      byte[] b = new byte[3];
      int val = i;
      // Fill from the last byte backwards so the most significant byte comes first.
      for (int j = b.length - 1; j >= 0; --j) {
        b[j] = (byte) (val & 0xff);
        val >>= 8;
      }
      entries.add(new BytesRef(b));
    }
    long size = buildFST(entries).ramBytesUsed();
    // Size is 1648 when we use only list-encoding. We were previously failing to ever de-dup
    // direct addressing, which led this case to blow up.
    // This test will fail if there is more than 1% size increase with direct addressing.
    assertTrue("FST size = " + size + " B", size <= 1648 * 1.01d);
  }

  /**
   * Compares the size of an FST built with and without direct addressing over random 5-byte
   * words whose bytes are all multiples of 4, and fails if direct addressing costs 1% or more.
   */
  @Nightly
  public void testWorstCaseForDirectAddressing() throws Exception {
    // This test will fail if there is more than 1% memory increase with direct addressing in this worst case.
    final double MEMORY_INCREASE_LIMIT_PERCENT = 1d;
    final int NUM_WORDS = 1000000;

    // Generate words with specially crafted bytes.
    // A set is used so that randomly generated duplicates collapse into one entry.
    Set<BytesRef> wordSet = new HashSet<>();
    for (int i = 0; i < NUM_WORDS; ++i) {
      byte[] b = new byte[5];
      random().nextBytes(b);
      for (int j = 0; j < b.length; ++j) {
        b[j] &= 0xfc; // Make this byte a multiple of 4.
      }
      wordSet.add(new BytesRef(b));
    }
    List<BytesRef> wordList = new ArrayList<>(wordSet);
    Collections.sort(wordList);

    // Disable direct addressing and measure the FST size.
    FSTCompiler<Object> fstCompiler = createFSTCompiler(-1f);
    FST<Object> fst = buildFST(wordList, fstCompiler);
    long ramBytesUsedNoDirectAddressing = fst.ramBytesUsed();

    // Enable direct addressing and measure the FST size.
    fstCompiler = createFSTCompiler(FSTCompiler.DIRECT_ADDRESSING_MAX_OVERSIZING_FACTOR);
    fst = buildFST(wordList, fstCompiler);
    long ramBytesUsed = fst.ramBytesUsed();

    // Compute the size increase in percents.
    double directAddressingMemoryIncreasePercent = ((double) ramBytesUsed / ramBytesUsedNoDirectAddressing - 1) * 100;

    // Uncomment to print detailed node statistics while debugging.
    // printStats(fstCompiler, ramBytesUsed, directAddressingMemoryIncreasePercent);

    // Verify the FST size does not exceed the limit.
    assertTrue("FST size exceeds limit, size = " + ramBytesUsed
            + ", increase = " + directAddressingMemoryIncreasePercent + " %"
            + ", limit = " + MEMORY_INCREASE_LIMIT_PERCENT + " %",
        directAddressingMemoryIncreasePercent < MEMORY_INCREASE_LIMIT_PERCENT);
  }

  /**
   * Prints the oversizing factor, the FST size and the node-type breakdown of a compiled FST.
   *
   * @param fstCompiler the compiler that produced the FST; its node counters are read
   * @param ramBytesUsed the size of the FST in bytes
   * @param directAddressingMemoryIncreasePercent the size increase, in percent, relative to the
   *     same FST built without direct addressing
   */
  private static void printStats(FSTCompiler<Object> fstCompiler, long ramBytesUsed, double directAddressingMemoryIncreasePercent) {
    System.out.println("directAddressingMaxOversizingFactor = " + fstCompiler.getDirectAddressingMaxOversizingFactor());
    System.out.println("ramBytesUsed = "
        + String.format(Locale.ENGLISH, "%.2f MB", ramBytesUsed / 1024d / 1024d)
        + String.format(Locale.ENGLISH, " (%.2f %% increase with direct addressing)", directAddressingMemoryIncreasePercent));
    System.out.println("num nodes = " + fstCompiler.nodeCount);
    // Fixed-length-arc nodes are the direct-addressing nodes plus the binary-search nodes.
    long fixedLengthArcNodeCount = fstCompiler.directAddressingNodeCount + fstCompiler.binarySearchNodeCount;
    System.out.println("num fixed-length-arc nodes = " + fixedLengthArcNodeCount
        + String.format(Locale.ENGLISH, " (%.2f %% of all nodes)",
        ((double) fixedLengthArcNodeCount / fstCompiler.nodeCount * 100)));
    System.out.println("num binary-search nodes = " + (fstCompiler.binarySearchNodeCount)
        + String.format(Locale.ENGLISH, " (%.2f %% of fixed-length-arc nodes)",
        ((double) (fstCompiler.binarySearchNodeCount) / fixedLengthArcNodeCount * 100)));
    System.out.println("num direct-addressing nodes = " + (fstCompiler.directAddressingNodeCount)
        + String.format(Locale.ENGLISH, " (%.2f %% of fixed-length-arc nodes)",
        ((double) (fstCompiler.directAddressingNodeCount) / fixedLengthArcNodeCount * 100)));
  }

  /**
   * Creates a BYTE1 compiler without outputs, using the given direct-addressing oversizing
   * factor. The tests in this class pass {@code -1f} to disable direct addressing.
   */
  private static FSTCompiler<Object> createFSTCompiler(float directAddressingMaxOversizingFactor) {
    return new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, NoOutputs.getSingleton())
        .directAddressingMaxOversizingFactor(directAddressingMaxOversizingFactor)
        .build();
  }

  /** Builds an FST from the entries with the default direct-addressing oversizing factor. */
  private FST<Object> buildFST(List<BytesRef> entries) throws Exception {
    return buildFST(entries, createFSTCompiler(FSTCompiler.DIRECT_ADDRESSING_MAX_OVERSIZING_FACTOR));
  }

  /**
   * Adds the entries to the compiler, skipping any entry equal to the one just before it, and
   * compiles the FST.
   *
   * @param entries the inputs to add; callers in this class pass them in sorted order
   * @param fstCompiler the compiler that receives the entries
   * @return the compiled FST
   */
  private static FST<Object> buildFST(List<BytesRef> entries, FSTCompiler<Object> fstCompiler) throws Exception {
    // One scratch builder is reused for all entries instead of allocating one per entry;
    // Util.toIntsRef refills it on every call.
    final IntsRefBuilder scratch = new IntsRefBuilder();
    final Object noOutput = NoOutputs.getSingleton().getNoOutput();
    BytesRef last = null;
    for (BytesRef entry : entries) {
      if (entry.equals(last) == false) {
        fstCompiler.add(Util.toIntsRef(entry, scratch), noOutput);
      }
      last = entry;
    }
    return fstCompiler.compile();
  }

  /**
   * Command-line entry point. Expects a mode flag followed by a file path:
   * {@code -countFSTArcs <fstFile>}, {@code -measureFSTOversizing <wordsFile>} or
   * {@code -recompileAndWalk <fstFile>}.
   *
   * @throws IllegalArgumentException if fewer than two arguments are given or the mode is unknown
   */
  public static void main(String... args) throws Exception {
    if (args.length < 2) {
      throw new IllegalArgumentException("Missing argument");
    }
    switch (args[0]) {
      case "-countFSTArcs":
        countFSTArcs(args[1]);
        break;
      case "-measureFSTOversizing":
        measureFSTOversizing(args[1]);
        break;
      case "-recompileAndWalk":
        recompileAndWalk(args[1]);
        break;
      default:
        throw new IllegalArgumentException("Invalid argument " + args[0]);
    }
  }

  /**
   * Loads an FST with byte-sequence outputs from the file, enumerates it, and prints how many
   * enumeration steps ended on a list, direct-addressing or binary-search arc.
   */
  private static void countFSTArcs(String fstFilePath) throws IOException {
    byte[] buf = Files.readAllBytes(Paths.get(fstFilePath));
    DataInput in = new ByteArrayDataInput(buf);
    FST<BytesRef> fst = new FST<>(in, in, ByteSequenceOutputs.getSingleton());
    BytesRefFSTEnum<BytesRef> fstEnum = new BytesRefFSTEnum<>(fst);
    int binarySearchArcCount = 0, directAddressingArcCount = 0, listArcCount = 0;
    while (fstEnum.next() != null) {
      // Classify the arc the enum currently sits on: zero bytes-per-arc is counted as a list
      // arc; otherwise the node flags separate direct addressing from binary search.
      if (fstEnum.arcs[fstEnum.upto].bytesPerArc() == 0) {
        listArcCount++;
      } else if (fstEnum.arcs[fstEnum.upto].nodeFlags() == FST.ARCS_FOR_DIRECT_ADDRESSING) {
        directAddressingArcCount++;
      } else {
        binarySearchArcCount++;
      }
    }
    System.out.println("direct addressing arcs = " + directAddressingArcCount
        + ", binary search arcs = " + binarySearchArcCount
        + ", list arcs = " + listArcCount);
  }

  /**
   * Reads up to 1,000,000 lines from the file as words, builds the FST with and without direct
   * addressing, and prints the statistics of the direct-addressing build.
   */
  private static void measureFSTOversizing(String wordsFilePath) throws Exception {
    final int MAX_NUM_WORDS = 1000000;

    // Read real english words.
    List<BytesRef> wordList = new ArrayList<>();
    try (BufferedReader reader = Files.newBufferedReader(Paths.get(wordsFilePath), StandardCharsets.UTF_8)) {
      while (wordList.size() < MAX_NUM_WORDS) {
        String word = reader.readLine();
        if (word == null) {
          break;
        }
        wordList.add(new BytesRef(word));
      }
    }
    Collections.sort(wordList);

    // Disable direct addressing and measure the FST size.
    FSTCompiler<Object> fstCompiler = createFSTCompiler(-1f);
    FST<Object> fst = buildFST(wordList, fstCompiler);
    long ramBytesUsedNoDirectAddressing = fst.ramBytesUsed();

    // Enable direct addressing and measure the FST size.
    fstCompiler = createFSTCompiler(FSTCompiler.DIRECT_ADDRESSING_MAX_OVERSIZING_FACTOR);
    fst = buildFST(wordList, fstCompiler);
    long ramBytesUsed = fst.ramBytesUsed();

    // Compute the size increase in percents.
    double directAddressingMemoryIncreasePercent = ((double) ramBytesUsed / ramBytesUsedNoDirectAddressing - 1) * 100;

    printStats(fstCompiler, ramBytesUsed, directAddressingMemoryIncreasePercent);
  }

  /**
   * Loads an FST with char-sequence outputs from the (optionally gzip-compressed) file, then
   * repeatedly recompiles it with different oversizing factors and enumerates the result,
   * printing the elapsed time of each step and the size of each recompiled FST.
   */
  private static void recompileAndWalk(String fstFilePath) throws IOException {
    try (InputStreamDataInput in = new InputStreamDataInput(newInputStream(Paths.get(fstFilePath)))) {

      System.out.println("Reading FST");
      long startNs = System.nanoTime();
      FST<CharsRef> originalFst = new FST<>(in, in, CharSequenceOutputs.getSingleton());
      System.out.println("time = " + elapsedMs(startNs) + " ms");

      // Each factor is measured three times in a row (presumably to smooth out warm-up effects).
      for (float oversizingFactor : List.of(0f, 0f, 0f, 1f, 1f, 1f)) {
        System.out.println("\nFST construction (oversizingFactor=" + oversizingFactor + ")");
        startNs = System.nanoTime();
        FST<CharsRef> fst = recompile(originalFst, oversizingFactor);
        System.out.println("time = " + elapsedMs(startNs) + " ms");
        System.out.println("FST RAM = " + fst.ramBytesUsed() + " B");

        System.out.println("FST enum");
        startNs = System.nanoTime();
        walk(fst);
        System.out.println("time = " + elapsedMs(startNs) + " ms");
      }
    }
  }

  /**
   * Returns the whole milliseconds elapsed since {@code startNs}, a value previously obtained
   * from {@link System#nanoTime()}. The monotonic clock is used so that wall-clock adjustments
   * cannot distort the measurement.
   */
  private static long elapsedMs(long startNs) {
    return (System.nanoTime() - startNs) / 1_000_000L;
  }

  /**
   * Opens the file for reading, wrapping it in a {@link GZIPInputStream} when the file name ends
   * with "gz" or "zip".
   *
   * <p>NOTE(review): the "zip" suffix also matches names such as "x.gzip"; a real ZIP archive is
   * not in gzip format and will be rejected by {@link GZIPInputStream}.
   *
   * @param path the file to open
   * @return a stream over the (decompressed) file content; the caller must close it
   * @throws IOException if the file cannot be opened or its gzip header cannot be read
   */
  private static InputStream newInputStream(Path path) throws IOException {
    InputStream raw = Files.newInputStream(path);
    String fileName = path.getFileName().toString();
    if (fileName.endsWith("gz") || fileName.endsWith("zip")) {
      try {
        return new GZIPInputStream(raw);
      } catch (IOException | RuntimeException e) {
        // The GZIP constructor reads the header eagerly; if that fails nobody else holds the
        // raw stream, so close it here instead of leaking the file handle.
        try {
          raw.close();
        } catch (IOException closeFailure) {
          e.addSuppressed(closeFailure);
        }
        throw e;
      }
    }
    return raw;
  }

  /**
   * Rebuilds the FST by enumerating all its input/output pairs and adding them to a new BYTE4
   * compiler configured with the given oversizing factor.
   *
   * @param fst the FST to enumerate
   * @param oversizingFactor the direct-addressing max oversizing factor for the new compiler
   * @return the newly compiled FST
   */
  private static FST<CharsRef> recompile(FST<CharsRef> fst, float oversizingFactor) throws IOException {
    FSTCompiler<CharsRef> fstCompiler = new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE4, CharSequenceOutputs.getSingleton())
        .directAddressingMaxOversizingFactor(oversizingFactor)
        .build();
    IntsRefFSTEnum<CharsRef> fstEnum = new IntsRefFSTEnum<>(fst);
    IntsRefFSTEnum.InputOutput<CharsRef> inputOutput;
    while ((inputOutput = fstEnum.next()) != null) {
      // The output is deep-copied before it is handed to the compiler.
      fstCompiler.add(inputOutput.input, CharsRef.deepCopyOf(inputOutput.output));
    }
    return fstCompiler.compile();
  }

  /**
   * Enumerates every entry of the FST.
   *
   * @return the sum of the input and output lengths over all entries
   */
  private static int walk(FST<CharsRef> read) throws IOException {
    IntsRefFSTEnum<CharsRef> fstEnum = new IntsRefFSTEnum<>(read);
    IntsRefFSTEnum.InputOutput<CharsRef> inputOutput;
    int terms = 0;
    while ((inputOutput = fstEnum.next()) != null) {
      terms += inputOutput.input.length;
      terms += inputOutput.output.length;
    }
    return terms;
  }
}