package fasttext;

import com.google.common.base.Preconditions;
import fasttext.store.InputStreamFastTextInput;
import fasttext.store.OutputStreamFastTextOutput;
import fasttext.store.OutputStreamResourceOutput;

import java.io.IOException;
import java.io.OutputStream;

/**
 * A product-quantized matrix with {@code m} rows and {@code n} columns.
 *
 * <p>Rows are stored as centroid codes ({@link QCodeArray}) decoded through a
 * {@link ProductQuantizer}. When {@code qnorm} is set, every row additionally carries one
 * quantized norm code, decoded through a separate quantizer ({@code npq}); the decoded norm is
 * handed to {@code pq.addCode} / {@code pq.mulCode}. When {@code qnorm} is not set, a norm of
 * {@code 1f} is passed instead.
 */
public class QMatrix implements ReadableQMatrix {

  /** Quantizer for the per-row norms; {@code null} unless {@link #qnorm} is set. */
  private final ProductQuantizer npq;
  /** Quantizer used to decode the row codes. */
  private final ProductQuantizer pq;
  /** Centroid codes for all rows. */
  private final QCodeArray codes;
  /** One norm code per row; {@code null} unless {@link #qnorm} is set. */
  private final QCodeArray normCodes;
  /** Whether row norms are quantized separately. */
  private final boolean qnorm;
  /** Number of rows. */
  private final int m;
  /** Number of columns. */
  private final int n;

  /** Assembles a matrix from already-decoded parts; used by {@link #load}. */
  private QMatrix(boolean qnorm,
                  int m,
                  int n,
                  QCodeArray codes,
                  ProductQuantizer pq,
                  QCodeArray normCodes,
                  ProductQuantizer npq) {
    this.qnorm = qnorm;
    this.m = m;
    this.n = n;
    this.codes = codes;
    this.pq = pq;
    this.normCodes = normCodes;
    this.npq = npq;
  }

  /**
   * Allocates storage for a quantized version of {@code mat} and quantizes it.
   *
   * <p>NOTE(review): this constructor calls the public, overridable {@link #quantize(Matrix)},
   * which as written in this class throws {@link UnsupportedOperationException}.
   *
   * @param mat the dense matrix to quantize
   * @param dsub width of each sub-vector handled by the product quantizer; must be positive
   * @param qnorm whether row norms are quantized separately
   * @throws IllegalArgumentException if {@code dsub} is not positive or the code array would
   *     not fit in an {@code int}
   */
  public QMatrix(Matrix mat, int dsub, boolean qnorm) {
    Preconditions.checkArgument(dsub > 0, "dsub must be positive: %s", dsub);
    this.qnorm = qnorm;
    this.m = mat.m();
    this.n = mat.n();
    // Each row is split into ceil(n / dsub) sub-vectors (the last may be narrower), one code
    // each. Round up explicitly: integer division inside Math.ceil would round down first.
    int subvectorsPerRow = this.n / dsub + (this.n % dsub == 0 ? 0 : 1);
    long codeSize = (long) this.m * subvectorsPerRow;
    Preconditions.checkArgument(codeSize <= Integer.MAX_VALUE,
        "code size %s exceeds Integer.MAX_VALUE", codeSize);
    this.codes = new QCodeArray((int) codeSize);
    this.pq = new ProductQuantizer(n, dsub);
    if (this.qnorm) {
      this.normCodes = new QCodeArray(this.m);
      this.npq = new ProductQuantizer(1, 1);
    } else {
      this.normCodes = null;
      this.npq = null;
    }
    quantize(mat);
  }

  /**
   * Shallow copy constructor: the new instance shares the code arrays and quantizers of
   * {@code mat}.
   */
  public QMatrix(QMatrix mat) {
    this.qnorm = mat.qnorm;
    this.m = mat.m;
    this.n = mat.n;
    this.codes = mat.codes;
    this.pq = mat.pq;
    if (mat.qnorm) {
      this.normCodes = mat.normCodes;
      this.npq = mat.npq;
    } else {
      this.normCodes = null;
      this.npq = null;
    }
  }

  /**
   * Quantizes the given row norms. Not implemented.
   *
   * @throws UnsupportedOperationException always
   */
  public void quantizeNorm(Vector norms) {
    throw new UnsupportedOperationException("Not implemented yet");
  }

  /**
   * Quantizes the rows of {@code matrix}. Not implemented.
   *
   * @throws UnsupportedOperationException always
   */
  public void quantize(Matrix matrix) {
    throw new UnsupportedOperationException("Not implemented yet");
  }

  /**
   * Decodes row {@code t} and adds it into {@code x} via {@code pq.addCode}, passing the
   * row's decoded norm (or {@code 1f} when norms are not quantized).
   *
   * @param x destination vector, modified in place by {@code pq.addCode}
   * @param t row index
   * @throws IndexOutOfBoundsException if {@code t} is not in {@code [0, m)}
   */
  public void addToVector(Vector x, int t) {
    Preconditions.checkElementIndex(t, m);
    pq.addCode(x, codes, t, rowNorm(t));
  }

  /**
   * Computes the dot product of {@code vec} with row {@code i} via {@code pq.mulCode}, passing
   * the row's decoded norm (or {@code 1f} when norms are not quantized).
   *
   * @param vec vector of size {@code n}
   * @param i row index
   * @return the value returned by {@code pq.mulCode}
   * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, m)}
   * @throws IllegalArgumentException if {@code vec.size() != n}
   */
  public float dotRow(Vector vec, int i) {
    // Element index (i < m), not position index (i <= m): i must name an existing row.
    Preconditions.checkElementIndex(i, m);
    Preconditions.checkArgument(vec.size() == n,
        "vector size %s does not match column count %s", vec.size(), n);
    return pq.mulCode(vec, codes, i, rowNorm(i));
  }

  /** Returns the decoded norm of {@code row}, or {@code 1f} when norms are not quantized. */
  private float rowNorm(int row) {
    if (!qnorm) {
      return 1f;
    }
    int cPosition = npq.getCentroidsPosition(0, normCodes.get(row));
    return npq.getCentroid(cPosition);
  }

  /** Returns the number of rows. */
  public int m() {
    return m;
  }

  /** Returns the number of columns. */
  public int n() {
    return n;
  }

  /** Returns a debug representation including the dimensions and every stored code. */
  @Override
  public String toString() {
    StringBuilder builder = new StringBuilder();
    builder.append("Matrix(m=");
    builder.append(m);
    builder.append(", n=");
    builder.append(n);
    builder.append(", codeSize=");
    builder.append(codes.size());
    builder.append(", codes=");
    builder.append(codes.toString());
    builder.append(", qnorm=");
    builder.append(qnorm);
    builder.append(", normCodes=");
    if (normCodes != null) {
      builder.append(normCodes.toString());
    } else {
      builder.append("null");
    }
    builder.append(")");
    return builder.toString();
  }

  /**
   * Reads a matrix in the layout written by {@link #save}: qnorm flag, row count (long), column
   * count (long), code count (int), one byte per code, the row quantizer, and — when qnorm is
   * set — one byte per row norm code followed by the norm quantizer.
   *
   * @throws IOException on read failure, or if a stored size is negative or exceeds
   *     {@code Integer.MAX_VALUE} (corrupted input)
   */
  public static QMatrix load(InputStreamFastTextInput is) throws IOException {
    boolean qnorm = is.readBoolean();
    int m = checkedSize(is.readLong(), "row count");
    int n = checkedSize(is.readLong(), "column count");
    int codeSize = checkedSize(is.readInt(), "code size");
    QCodeArray codes = new QCodeArray(readByteCodes(is, codeSize));
    ProductQuantizer pq = ProductQuantizer.load(is);
    QCodeArray normCodes = null;
    ProductQuantizer npq = null;
    if (qnorm) {
      normCodes = new QCodeArray(readByteCodes(is, m));
      npq = ProductQuantizer.load(is);
    }
    return new QMatrix(qnorm, m, n, codes, pq, normCodes, npq);
  }

  /** Validates that a size read from the stream is a non-negative {@code int}. */
  private static int checkedSize(long value, String what) throws IOException {
    if (value < 0 || value > Integer.MAX_VALUE) {
      throw new IOException("Invalid QMatrix " + what + ": " + value);
    }
    return (int) value;
  }

  /** Reads {@code count} codes, one byte each, into a new array. */
  private static int[] readByteCodes(InputStreamFastTextInput is, int count) throws IOException {
    int[] raw = new int[count];
    for (int i = 0; i < count; i++) {
      raw[i] = is.readByteAsInt();
    }
    return raw;
  }

  /** Writes this matrix in the layout read by {@link #load}. */
  public void save(OutputStreamFastTextOutput os) throws IOException {
    os.writeBoolean(qnorm);
    os.writeLong(m);
    os.writeLong(n);
    os.writeInt(codes.size());
    for (int i = 0; i < codes.size(); i++) {
      os.writeIntAsByte(codes.get(i));
    }
    pq.save(os);
    if (qnorm) {
      for (int i = 0; i < m; i++) {
        os.writeIntAsByte(normCodes.get(i));
      }
      npq.save(os);
    }
  }

  /**
   * Writes the header and codes in the same order as {@link #save}, but writes each quantizer
   * field by field (dim, nsubq, dsub, lastdsub, then every centroid).
   *
   * <p>{@code os} is wrapped in an {@code OutputStreamResourceOutput} that is closed before
   * returning; presumably that also closes {@code os} — confirm against that class.
   */
  public void saveToMMap(OutputStream os) throws IOException {
    try (OutputStreamResourceOutput fos = new OutputStreamResourceOutput("qmatrix", os)) {
      fos.writeBoolean(qnorm);
      fos.writeLong(m);
      fos.writeLong(n);
      fos.writeInt(codes.size());
      for (int i = 0; i < codes.size(); i++) {
        fos.writeIntAsByte(codes.get(i));
      }
      writeQuantizer(fos, pq);
      if (qnorm) {
        for (int i = 0; i < m; i++) {
          fos.writeIntAsByte(normCodes.get(i));
        }
        writeQuantizer(fos, npq);
      }
    }
  }

  /** Writes a quantizer's dimensions followed by all of its centroids. */
  private static void writeQuantizer(OutputStreamResourceOutput fos, ProductQuantizer quantizer)
      throws IOException {
    fos.writeInt(quantizer.dim());
    fos.writeInt(quantizer.nsubq());
    fos.writeInt(quantizer.dsub());
    fos.writeInt(quantizer.lastdsub());
    for (int i = 0; i < quantizer.centroids().length; i++) {
      fos.writeFloat(quantizer.getCentroid(i));
    }
  }

  /** No resources are held by this class, so there is nothing to release. */
  public void close() {}

  /**
   * Returns a shallow copy (shares code arrays and quantizers).
   *
   * <p>NOTE(review): {@code Object.clone} throws {@link CloneNotSupportedException} unless this
   * class (or {@code ReadableQMatrix}) implements {@link Cloneable}, which is not visible here —
   * confirm; {@link #QMatrix(QMatrix)} is an equivalent alternative.
   */
  @Override
  public QMatrix clone() throws CloneNotSupportedException {
    return (QMatrix) super.clone();
  }

}