package se.lth.cs.srl.util; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.BufferedReader; import java.io.DataInput; import java.io.DataInputStream; import java.io.DataOutput; import java.io.DataOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import java.util.regex.Pattern; import java.util.zip.GZIPInputStream; public class WordEmbedding implements Serializable { private static final long serialVersionUID = 1L; public static final int DEF_DIMENSIONALITY = 100; private final Map<String, EmbeddingEntry> map; public WordEmbedding(DataInput input) throws IOException { map = new HashMap<String, EmbeddingEntry>(); read(input); } public WordEmbedding(File dataFile) throws IOException { this(dataFile, DEF_DIMENSIONALITY); } public WordEmbedding(File dataFile, int dim) throws IOException { map = new HashMap<String, EmbeddingEntry>(); if (dataFile != null) { InputStream is = new FileInputStream(dataFile); if (dataFile.toString().endsWith(".gz")) is = new GZIPInputStream(is); BufferedReader reader = new BufferedReader(new InputStreamReader( is, "UTF8")); populateEmbedding(reader, dim); reader.close(); } } private void populateEmbedding(BufferedReader reader, int dim) throws IOException { // int sInd=1; // int lInd=1; // Map<Integer,Integer> sM=new HashMap<Integer,Integer>(); // Map<Integer,Integer> lM=new HashMap<Integer,Integer>(); String line; Pattern tab = Pattern.compile(" "); int lineCount = 0; int saveCount = 0; while ((line = reader.readLine()) != null) { lineCount++; String[] cols = tab.split(line); // int count=Integer.parseInt(cols[2]); // if(count<threshold) // continue; saveCount++; // Ok, we store it. Now calculate the short and long bit strings as // short values: double[] entry = new double[cols.length - 1]; for (int i = 0; i < entry.length; i++) entry[i] = Double.parseDouble(cols[i + 1]); // Integer _long=Integer.parseInt(cols[0], 2); // Integer // _short=Integer.parseInt(cols[0].length()>shortLen?cols[0].substring(0, // shortLen):cols[0], 2); // Integer _s=sM.get(_short); // if(_s==null){ // _s=sInd++; // sM.put(_short, _s); // } // Integer _l=lM.get(_long); // if(_l==null){ // _l=lInd++; // lM.put(_long, _l); // } map.put(cols[0], new EmbeddingEntry(entry)); } System.out.println("Initiated word embedding. Read " + lineCount + " lines, saved " + saveCount); } static final class EmbeddingEntry implements Serializable { private static final long serialVersionUID = 1L; final double[] entry; public EmbeddingEntry(double[] entry) { this.entry = new double[entry.length]; for (int i = 0; i < entry.length; i++) this.entry[i] = entry[i]; } } // Helpers for the feature functions: public Double getValue(String s, int dim) { EmbeddingEntry ee = map.get(s); if (ee == null) return 0.0; if (ee.entry.length <= dim) return 0.0; return ee.entry[dim]; } // For reading and writing manually (not really needed if serializable is // used) private static final String MAGIC_STRING = "EMBEDDING-MAGIC-KEY"; private void read(DataInput input) throws IOException { String foo = input.readUTF(); if (!foo.equals(MAGIC_STRING)) throw new Error( "Error reading word embedding. Magic string not found."); int entries = input.readInt(); for (int i = 0; i < entries; ++i) { String str = input.readUTF(); double[] entry = new double[DEF_DIMENSIONALITY]; for (int j = 0; j < entry.length; j++) entry[j] = input.readDouble(); map.put(str, new EmbeddingEntry(entry)); } } public void write(DataOutput output) throws IOException { output.writeUTF(MAGIC_STRING); output.writeInt(map.size()); for (Entry<String, EmbeddingEntry> e : map.entrySet()) { EmbeddingEntry ce = e.getValue(); output.writeUTF(e.getKey()); for (int i = 0; i < ce.entry.length; i++) output.writeDouble(ce.entry[i]); } } // To test public static void main(String[] args) throws IOException { File input = new File( "/afs/inf.ed.ac.uk/user/m/mroth/s-case/mate/embeddings/CW_embeddings_by_turian_50dims_scaled.txt"); WordEmbedding c = new WordEmbedding(input); String[] examples = { "believe", "hello", "hi", "bye", "banana", "apple", "pepsi", "beer", "wine", "water", "asasfasfaf", "drink", "eat", "ate", "drank", "drunk", "eaten", "devour" }; // Write out some stuff printExamples(examples, c); // Save the cluster DataOutputStream dos = new DataOutputStream(new BufferedOutputStream( new FileOutputStream("foobar"))); c.write(dos); dos.close(); // Reread it and write out the same stuff DataInputStream dis = new DataInputStream(new BufferedInputStream( new FileInputStream("foobar"))); WordEmbedding c2 = new WordEmbedding(dis); printExamples(examples, c2); } private static void printExamples(String[] examples, WordEmbedding c) { System.out.printf("%12s | %12s | %12s\n", "Form", "Dim 0", "Dim 1"); for (String e : examples) { EmbeddingEntry ce = c.map.get(e); Object[] o; if (ce == null) o = new Object[] { e, "null", "null" }; else o = new Object[] { e, new Double(ce.entry[0]), new Double(ce.entry[1]) }; System.out.printf("%-12s | %12s | %12s\n", o); } System.out.println(); } }