package com.spotify.annoy;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.*;

/**
 * Read-only Approximate Nearest Neighbor Index which queries
 * databases created by annoy.
 */
public class ANNIndex implements AnnoyIndex {


  private final ArrayList<Long> roots;
  private MappedByteBuffer[] buffers;

  private final int DIMENSION, MIN_LEAF_SIZE;
  private final IndexType INDEX_TYPE;
  private final int INDEX_TYPE_OFFSET;

  // size of C structs in bytes (initialized in init)
  private final int K_NODE_HEADER_STYLE;
  private final long NODE_SIZE;

  private final int INT_SIZE = 4;
  private final int FLOAT_SIZE = 4;
  private final int MAX_NODES_IN_BUFFER;
  private final int BLOCK_SIZE;
  private RandomAccessFile memoryMappedFile;


  /**
   * Construct and load an Annoy index of a specific type (euclidean / angular).
   *
   * @param dimension dimensionality of tree, e.g. 40
   * @param filename  filename of tree
   * @param indexType type of index
   * @throws IOException if file can't be loaded
   */
  public ANNIndex(final int dimension,
                  final String filename,
                  IndexType indexType) throws IOException {
    this(dimension, filename, indexType, 0);
  }

  /**
   * Construct and load an (Angular) Annoy index.
   *
   * @param dimension dimensionality of tree, e.g. 40
   * @param filename  filename of tree
   * @throws IOException if file can't be loaded
   */
  public ANNIndex(final int dimension,
                  final String filename) throws IOException {
    this(dimension, filename, IndexType.ANGULAR);
  }

  ANNIndex(final int dimension,
                  final String filename,
                  IndexType indexType,
                  final int blockSize) throws IOException {
    DIMENSION = dimension;
    INDEX_TYPE = indexType;
    INDEX_TYPE_OFFSET = INDEX_TYPE.getOffset();
    K_NODE_HEADER_STYLE = INDEX_TYPE.getkNodeHeaderStyle();
    // we can store up to MIN_LEAF_SIZE children in leaf nodes (we put
    // them where the separating plane normally goes)
    this.MIN_LEAF_SIZE = DIMENSION + 2;
    this.NODE_SIZE = K_NODE_HEADER_STYLE + FLOAT_SIZE * DIMENSION;
    this.MAX_NODES_IN_BUFFER = (int) (blockSize == 0 ?
            Integer.MAX_VALUE / NODE_SIZE : blockSize * NODE_SIZE);
    BLOCK_SIZE = (int) (this.MAX_NODES_IN_BUFFER * NODE_SIZE);
    roots = new ArrayList<>();
    load(filename);
  }

  private void load(final String filename) throws IOException {
    memoryMappedFile = new RandomAccessFile(filename, "r");
    long fileSize = memoryMappedFile.length();
    if (fileSize == 0L) {
      throw new IOException("Index is a 0-byte file?");
    }

    int numNodes = (int) (fileSize / NODE_SIZE);
    int buffIndex =  (numNodes - 1) / MAX_NODES_IN_BUFFER;
    int rest = (int) (fileSize % BLOCK_SIZE);
    int blockSize = (rest > 0 ? rest : BLOCK_SIZE);
    // Two valid relations between dimension and file size:
    // 1) rest % NODE_SIZE == 0 makes sure either everything fits into buffer or rest is a multiple of NODE_SIZE;
    // 2) (file_size - rest) % NODE_SIZE == 0 makes sure everything else is a multiple of NODE_SIZE.
    if (rest % NODE_SIZE != 0 || (fileSize - rest) % NODE_SIZE != 0) {
      throw new RuntimeException("ANNIndex initiated with wrong dimension size");
    }
    long position = fileSize - blockSize;
    buffers = new MappedByteBuffer[buffIndex + 1];
    boolean process = true;
    int m = -1;
    long index = fileSize;
    while (position >= 0) {
      MappedByteBuffer annBuf = memoryMappedFile.getChannel().map(
              FileChannel.MapMode.READ_ONLY, position, blockSize);
      annBuf.order(ByteOrder.LITTLE_ENDIAN);

      buffers[buffIndex--] = annBuf;

      for (int i = blockSize - (int) NODE_SIZE; process && i >= 0; i -= NODE_SIZE) {
        index -= NODE_SIZE;
        int k = annBuf.getInt(i);  // node[i].n_descendants
        if (m == -1 || k == m) {
          roots.add(index);
          m = k;
        } else {
          process = false;
        }
      }
      blockSize = BLOCK_SIZE;
      position -= blockSize;
    }
  }

  private float getFloatInAnnBuf(long pos) {
    int b = (int) (pos / BLOCK_SIZE);
    int f = (int) (pos % BLOCK_SIZE);
    return buffers[b].getFloat(f);
  }

  private int getIntInAnnBuf(long pos) {
    int b = (int) (pos / BLOCK_SIZE);
    int i = (int) (pos % BLOCK_SIZE);
    return buffers[b].getInt(i);
  }

  @Override
  public void getNodeVector(final long nodeOffset, float[] v) {
    MappedByteBuffer nodeBuf = buffers[(int) (nodeOffset / BLOCK_SIZE)];
    int offset = (int) ((nodeOffset % BLOCK_SIZE) + K_NODE_HEADER_STYLE);
    for (int i = 0; i < DIMENSION; i++) {
      v[i] = nodeBuf.getFloat(offset + i * FLOAT_SIZE);
    }
  }

  @Override
  public void getItemVector(int itemIndex, float[] v) {
    getNodeVector(itemIndex * NODE_SIZE, v);
  }

  private float getNodeBias(final long nodeOffset) { // euclidean-only
    return getFloatInAnnBuf(nodeOffset + 4);
  }

  private float getDotFactor(final long nodeOffset) { // dot-only
    return getFloatInAnnBuf(nodeOffset + 12);
  }

  public final float[] getItemVector(final int itemIndex) {
    return getNodeVector(itemIndex * NODE_SIZE);
  }

  public float[] getNodeVector(final long nodeOffset) {
    float[] v = new float[DIMENSION];
    getNodeVector(nodeOffset, v);
    return v;
  }

  private static float norm(final float[] u) {
    float n = 0;
    for (float x : u)
      n += x * x;
    return (float) Math.sqrt(n);
  }

  private static float euclideanDistance(final float[] u, final float[] v) {
    float[] diff = new float[u.length];
    for (int i = 0; i < u.length; i++)
      diff[i] = u[i] - v[i];
    return norm(diff);
  }

  public static float dot(final float[] u, final float[] v) {
    double d = 0;
    for (int i = 0; i < u.length; i++)
      d += u[i] * v[i];
    return (float) d;
  }

  public static float cosineMargin(final float[] u, final float[] v) {
    return dot(u, v) / (norm(u) * norm(v));
  }

  public static float dotMargin(final float[] u, final float[] v, final float norm) {
    return dot(u, v) + norm * norm;
  }

  public static float euclideanMargin(final float[] u,
                                      final float[] v,
                                      final float bias) {
    float d = bias;
    for (int i = 0; i < u.length; i++)
      d += u[i] * v[i];
    return d;
  }

  /**
   * Closes this stream and releases any system resources associated
   * with it. If the stream is already closed then invoking this
   * method has no effect.
   *
   * <p> As noted in {@link AutoCloseable#close()}, cases where the
   * close may fail require careful attention. It is strongly advised
   * to relinquish the underlying resources and to internally
   * <em>mark</em> the {@code Closeable} as closed, prior to throwing
   * the {@code IOException}.
   *
   * @throws IOException if an I/O error occurs
   */
  @Override
  public void close() throws IOException {
    memoryMappedFile.close();
  }

  private class PQEntry implements Comparable<PQEntry> {

    PQEntry(final float margin, final long nodeOffset) {
      this.margin = margin;
      this.nodeOffset = nodeOffset;
    }

    private float margin;
    private long nodeOffset;

    @Override
    public int compareTo(final PQEntry o) {
      return Float.compare(o.margin, margin);
    }

  }

  private static boolean isZeroVec(float[] v) {
    for (int i = 0; i < v.length; i++)
      if (v[i] != 0)
        return false;
    return true;
  }

  @Override
  public final List<Integer> getNearest(final float[] queryVector,
                                        final int nResults) {

    if (queryVector.length != DIMENSION) {
      throw new RuntimeException(String.format("queryVector must be size of %d, but was %d",
              DIMENSION, queryVector.length));
    }

    PriorityQueue<PQEntry> pq = new PriorityQueue<>(
            roots.size() * FLOAT_SIZE);
    final float kMaxPriority = 1e30f;

    for (long r : roots) {
      pq.add(new PQEntry(kMaxPriority, r));
    }

    Set<Integer> nearestNeighbors = new HashSet<Integer>();
    while (nearestNeighbors.size() < roots.size() * nResults && !pq.isEmpty()) {
      PQEntry top = pq.poll();
      long topNodeOffset = top.nodeOffset;
      int nDescendants = getIntInAnnBuf(topNodeOffset);
      float[] v = getNodeVector(topNodeOffset);
      if (nDescendants == 1) {  // n_descendants
        // FIXME: does this ever happen?
        if (isZeroVec(v))
          continue;
        nearestNeighbors.add((int) (topNodeOffset / NODE_SIZE));
      } else if (nDescendants <= MIN_LEAF_SIZE) {
        for (int i = 0; i < nDescendants; i++) {
          int j = getIntInAnnBuf(topNodeOffset +
                  INDEX_TYPE_OFFSET +
                  i * INT_SIZE);
          if (isZeroVec(getNodeVector(j * NODE_SIZE)))
            continue;
          nearestNeighbors.add(j);
        }
      } else {
        float margin = (INDEX_TYPE == IndexType.ANGULAR) ? cosineMargin(v, queryVector)
                     : (INDEX_TYPE == IndexType.DOT) ? dotMargin(v, queryVector, getDotFactor(topNodeOffset))
                     : euclideanMargin(v, queryVector, getNodeBias(topNodeOffset));
        long childrenMemOffset = topNodeOffset + INDEX_TYPE_OFFSET;
        long lChild = NODE_SIZE * getIntInAnnBuf(childrenMemOffset);
        long rChild = NODE_SIZE * getIntInAnnBuf(childrenMemOffset + 4);
        pq.add(new PQEntry(-margin, lChild));
        pq.add(new PQEntry(margin, rChild));
      }
    }

    ArrayList<PQEntry> sortedNNs = new ArrayList<PQEntry>();
    for (int nn : nearestNeighbors) {
      float[] v = getItemVector(nn);
      if (!isZeroVec(v)) {
        float margin = (INDEX_TYPE == IndexType.ANGULAR) ? cosineMargin(v, queryVector)
                     : (INDEX_TYPE == IndexType.DOT) ? dot(v, queryVector)
                     : -euclideanDistance(v, queryVector);
        sortedNNs.add(new PQEntry(margin, nn));
      }
    }
    Collections.sort(sortedNNs);

    ArrayList<Integer> result = new ArrayList<>(nResults);
    for (int i = 0; i < nResults && i < sortedNNs.size(); i++) {
      result.add((int) sortedNNs.get(i).nodeOffset);
    }
    return result;
  }


  /**
   * a test query program.
   *
   * @param args tree filename, dimension, indextype ("angular" or
   *             "euclidean" and query item id.
   * @throws IOException if unable to load index
   */
  public static void main(final String[] args) throws IOException {

    String indexPath = args[0];                 // 0
    int dimension = Integer.parseInt(args[1]);  // 1
    IndexType indexType = null;                 // 2
    if (args[2].toLowerCase().equals("angular"))
      indexType = IndexType.ANGULAR;
    else if (args[2].toLowerCase().equals("dot"))
      indexType = IndexType.DOT;
    else if (args[2].toLowerCase().equals("euclidean"))
      indexType = IndexType.EUCLIDEAN;
    else throw new RuntimeException("wrong index type specified");
    int queryItem = Integer.parseInt(args[3]);  // 3

    ANNIndex annIndex = new ANNIndex(dimension, indexPath, indexType);

    // input vector
    float[] u = annIndex.getItemVector(queryItem);
    System.out.printf("vector[%d]: ", queryItem);
    for (float x : u) {
      System.out.printf("%2.2f ", x);
    }
    System.out.printf("\n");

    List<Integer> nearestNeighbors = annIndex.getNearest(u, 10);
    for (int nn : nearestNeighbors) {
      float[] v = annIndex.getItemVector(nn);
      System.out.printf("%d %d %f\n",
              queryItem, nn,
              (indexType == IndexType.ANGULAR) ?
                      cosineMargin(u, v) : euclideanDistance(u, v));
    }
  }

}