package fasttext.mmap; import com.google.common.base.Preconditions; import fasttext.ReadableMatrix; import fasttext.Vector; import fasttext.store.MMapFile; import fasttext.store.ResourceInput; import java.io.IOException; import java.io.OutputStream; /** Memory-mapped {@link ReadableMatrix}. Only supports read-only operations. */ public class MMapMatrix implements ReadableMatrix { private final int m; private final int n; private final MMapFile mmapFile; private ResourceInput in; private MMapMatrix(MMapFile mmapFile, ResourceInput in, int m, int n) { this.mmapFile = mmapFile; this.in = in; this.m = m; this.n = n; } private float readAt(int i, int j) { try { in.seek( 16L + (long) (i * n + j) * 4); return in.readFloat(); } catch (IOException ex) { throw new IllegalArgumentException("Could not read float from matrix at i=" + i + " j=" + j); } } private float[] readRow(int i) { float[] r = new float[n]; try { in.seek(16L + (long) (i * n) * 4); for (int j = 0; j < n; j++) { r[j] = in.readFloat(); } } catch (IOException ex) { throw new IllegalArgumentException("Could not read row " + i + " from matrix"); } return r; } public float[] atRow(int i) { return readRow(i); } public float at(int i, int j) { return readAt(i, j); } public float dotRow(final Vector vec, int i) { Preconditions.checkPositionIndex(i, m); Preconditions.checkArgument(vec.size() == n); float d = 0.0f; float[] r = atRow(i); for (int j = 0; j < n; j++) { d += r[j] * vec.at(j); } return d; } public float l2NormRow(int i) { float norm = 0.0f; float[] r = atRow(i); for (int j = 0; j < n; j++) { float v = r[j]; norm += v * v; } return (float) Math.sqrt(norm); } public Vector l2NormRow(Vector norms) { Preconditions.checkArgument(norms.size() == m); for (int i = 0; i < m; i++) { norms.set(i, l2NormRow(i)); } return norms; } public int m() { return m; } public int n() { return n; } @Override public String toString() { StringBuilder builder = new StringBuilder(); builder.append("Matrix(m="); builder.append(m); builder.append(", n="); builder.append(n); builder.append(", mmap=MMapFile("); builder.append(mmapFile.getPath().toString()); builder.append("))"); return builder.toString(); } public static MMapMatrix load(MMapFile mmap) throws IOException { ResourceInput in = mmap.openInput(); int m = (int) in.readLong(); int n = (int) in.readLong(); return new MMapMatrix(mmap, in, m, n); } @Override public MMapMatrix clone() throws CloneNotSupportedException { MMapMatrix m = (MMapMatrix) super.clone(); m.in = in.clone(); return m; } public void close() throws IOException { in.close(); } public void saveToMMap(OutputStream os) throws IOException { throw new UnsupportedOperationException("Not implemented yet"); } }