package edu.uw.easysrl.syntax.tagger; import java.io.File; import java.io.FilenameFilter; import java.io.IOException; import java.nio.file.Files; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.regex.Pattern; import java.util.stream.Collectors; import no.uib.cipr.matrix.DenseMatrix; import no.uib.cipr.matrix.DenseVector; import no.uib.cipr.matrix.Matrix; import no.uib.cipr.matrix.Vector; import com.google.common.collect.ImmutableMap; import com.google.common.io.PatternFilenameFilter; import edu.uw.easysrl.main.InputReader.InputWord; import edu.uw.easysrl.syntax.grammar.Category; import edu.uw.easysrl.syntax.model.CutoffsDictionaryInterface; import edu.uw.easysrl.util.Util; public class TaggerEmbeddings extends Tagger { private final Matrix weightMatrix; private final Vector bias; private final Map<String, double[]> discreteFeatures; private final Map<String, double[]> embeddingsFeatures; private final int totalFeatures; /** * Number of words forward/backward to use as context (so a value of 3 means the tagger looks at 3+3+1=7 words). */ private final int contextWindow = 3; // Special words used in the embeddings tables. private final static String leftPad = "*left_pad*"; private final static String rightPad = "*right_pad*"; private final static String unknownLower = "*unknown_lower*"; private final static String unknownUpper = "*unknown_upper*"; private final static String unknownSpecial = "*unknown_special*"; private static final String capsLower = "*lower_case*"; private static final String capsUpper = "*upper_case*"; private final static String capitalizedPad = "*caps_pad*"; private final static String suffixPad = "*suffix_pad*"; private final static String unknownSuffix = "*unknown_suffix*"; /** * Indices for POS-tags, if using them as features. */ private final Map<String, Integer> posFeatures; /** * Indices for specific words, if using them as features. */ private final Map<String, Integer> lexicalFeatures; private final List<Vector> weightMatrixRows; private final Map<Category, Integer> categoryToIndex; public TaggerEmbeddings(final File modelFolder, final double beta, final int maxTagsPerWord, final CutoffsDictionaryInterface cutoffs) throws IOException { super(cutoffs, beta, loadCategories(new File(modelFolder, "categories")), maxTagsPerWord); try { final FilenameFilter embeddingsFileFilter = new PatternFilenameFilter("embeddings.*"); // If we're using POS tags or lexical features, load l. this.posFeatures = loadSparseFeatures(new File(modelFolder + "/postags")); this.lexicalFeatures = loadSparseFeatures(new File(modelFolder + "/frequentwords")); // Load word embeddings. embeddingsFeatures = loadEmbeddings(true, modelFolder.listFiles(embeddingsFileFilter)); // Load embeddings for capitalization and suffix features. discreteFeatures = new HashMap<>(); discreteFeatures.putAll(loadEmbeddings(false, new File(modelFolder, "capitals"))); discreteFeatures.putAll(loadEmbeddings(false, new File(modelFolder, "suffix"))); totalFeatures = (embeddingsFeatures.get(unknownLower).length + discreteFeatures.get(unknownSuffix).length + discreteFeatures.get(capsLower).length + posFeatures.size() + lexicalFeatures.size()) * (2 * contextWindow + 1); // Load the list of categories used by the model. categoryToIndex = new HashMap<>(); for (int i = 0; i < lexicalCategories.size(); i++) { categoryToIndex.put(lexicalCategories.get(i), i); } // Load the weight matrix used by the classifier. weightMatrix = new DenseMatrix(lexicalCategories.size(), totalFeatures); loadMatrix(weightMatrix, new File(modelFolder, "classifier")); weightMatrixRows = new ArrayList<>(lexicalCategories.size()); for (int i = 0; i < lexicalCategories.size(); i++) { final Vector row = new DenseVector(totalFeatures); for (int j = 0; j < totalFeatures; j++) { row.set(j, weightMatrix.get(i, j)); } weightMatrixRows.add(row); } bias = new DenseVector(lexicalCategories.size()); loadVector(bias, new File(modelFolder, "bias")); } catch (final Exception e) { throw new RuntimeException(e); } } private Map<String, Integer> loadSparseFeatures(final File posTagFeaturesFile) throws IOException { Map<String, Integer> posFeatures; if (posTagFeaturesFile.exists()) { posFeatures = new HashMap<>(); for (final String line : Util.readFile(posTagFeaturesFile)) { posFeatures.put(line, posFeatures.size()); } posFeatures = ImmutableMap.copyOf(posFeatures); } else { posFeatures = Collections.emptyMap(); } return posFeatures; } /** * Loads the neural network weight matrix. */ private void loadMatrix(final Matrix matrix, final File file) throws IOException { final Iterator<String> lines = Util.readFileLineByLine(file); int row = 0; while (lines.hasNext()) { final String line = lines.next(); final String[] fields = line.split(" "); for (int i = 0; i < fields.length; i++) { matrix.set(row, i, Double.valueOf(fields[i])); } row++; } } private void loadVector(final Vector vector, final File file) throws IOException { final Iterator<String> lines = Util.readFileLineByLine(file); int row = 0; while (lines.hasNext()) { final String data = lines.next(); vector.set(row, Double.valueOf(data)); row++; } } public static List<Category> loadCategories(final File catFile) throws IOException { return Files.lines(catFile.toPath()).map(Category::valueOf).collect(Collectors.toList()); } /* * (non-Javadoc) * * @see uk.ac.ed.easyccg.syntax.Tagger#tag(java.util.List) */ @Override public List<List<ScoredCategory>> tag(final List<InputWord> words) { final List<List<ScoredCategory>> result = new ArrayList<>(words.size()); for (int wordIndex = 0; wordIndex < words.size(); wordIndex++) { result.add(getTagsForWord(getVectorForWord(words, wordIndex), words.get(wordIndex))); } return result; } private Vector getVectorForWord(final List<InputWord> words, final int wordIndex) { final double[] vector = new double[totalFeatures]; int vectorIndex = 0; for (int sentencePosition = wordIndex - contextWindow; sentencePosition <= wordIndex + contextWindow; sentencePosition++) { vectorIndex = addToFeatureVector(vectorIndex, vector, sentencePosition, words); // If using lexical features, update the vector. if (lexicalFeatures.size() > 0) { if (sentencePosition >= 0 && sentencePosition < words.size()) { final Integer index = lexicalFeatures.get(words.get(sentencePosition).word); if (index != null) { vector[vectorIndex + index] = 1; } } vectorIndex = vectorIndex + lexicalFeatures.size(); } // If using POS-tag features, update the vector. if (posFeatures.size() > 0) { if (sentencePosition >= 0 && sentencePosition < words.size()) { vector[vectorIndex + posFeatures.get(words.get(sentencePosition).pos)] = 1; } vectorIndex = vectorIndex + posFeatures.size(); } } // System.out.println(words.get(wordIndex).word+ " " + // Doubles.asList(vector)); return new DenseVector(vector); } /** * Adds the features for the word in the specified position to the vector, and returns the next empty index in the * vector. */ private int addToFeatureVector(int vectorIndex, final double[] vector, final int sentencePosition, final List<InputWord> words) { final double[] embedding = getEmbedding(words, sentencePosition); vectorIndex = addToVector(vectorIndex, vector, embedding); final double[] suffix = getSuffix(words, sentencePosition); vectorIndex = addToVector(vectorIndex, vector, suffix); final double[] caps = getCapitalization(words, sentencePosition); vectorIndex = addToVector(vectorIndex, vector, caps); return vectorIndex; } private int addToVector(int index, final double[] vector, final double[] embedding) { System.arraycopy(embedding, 0, vector, index, embedding.length); index = index + embedding.length; return index; } /** * * @param normalize * If true, words are lower-cased with numbers replaced * @param embeddingsFiles * @return * @throws IOException */ private Map<String, double[]> loadEmbeddings(final boolean normalize, final File... embeddingsFiles) throws IOException { final Map<String, double[]> embeddingsMap = new HashMap<>(); // Allow sharded input, by allowing the embeddings to be split across // multiple files. for (final File embeddingsFile : embeddingsFiles) { final Iterator<String> lines = Util.readFileLineByLine(embeddingsFile); while (lines.hasNext()) { final String line = lines.next(); // Lines have the format: word dim1 dim2 dim3 ... String word = line.substring(0, line.indexOf(" ")); if (normalize) { word = normalize(word); } if (!embeddingsMap.containsKey(word)) { final String[] fields = line.split(" "); final double[] embeddings = new double[fields.length - 1]; for (int i = 1; i < fields.length; i++) { embeddings[i - 1] = Double.valueOf(fields[i]); } embeddingsMap.put(word, embeddings); } } } return embeddingsMap; } /** * Normalizes words by lower-casing and replacing numbers with '#'/ */ private final static Pattern numbers = Pattern.compile("[0-9]"); private String normalize(String word) { word = numbers.matcher(word.toLowerCase()).replaceAll("#"); return word; } /** * Loads the embedding for the word at the specified index in the sentence. The index is allowed to be outside the * sentence range, in which case the appropriate 'padding' embedding is returned. */ private double[] getEmbedding(final List<InputWord> words, final int index) { if (index < 0) { return embeddingsFeatures.get(leftPad); } if (index >= words.size()) { return embeddingsFeatures.get(rightPad); } String word = words.get(index).word; word = translateBrackets(word); final double[] result = embeddingsFeatures.get(normalize(word)); if (result == null) { final char firstCharacter = word.charAt(0); final boolean isLower = 'a' <= firstCharacter && firstCharacter <= 'z'; final boolean isUpper = 'A' <= firstCharacter && firstCharacter <= 'Z'; if (isLower) { return embeddingsFeatures.get(unknownLower); } else if (isUpper) { return embeddingsFeatures.get(unknownUpper); } else { return embeddingsFeatures.get(unknownSpecial); } } return result; } /** * Loads the embedding for a word's 2-character suffix. The index is allowed to be outside the sentence range, in * which case the appropriate 'padding' embedding is returned. */ private double[] getSuffix(final List<InputWord> words, final int index) { String suffix = null; if (index < 0 || index >= words.size()) { suffix = suffixPad; } else { String word = words.get(index).word; word = translateBrackets(word); if (word.length() > 1) { suffix = (word.substring(word.length() - 2, word.length())); } else { // Padding for words of length 1. suffix = ("_" + word.substring(0, 1)); } } double[] result = discreteFeatures.get(suffix.toLowerCase()); if (result == null) { result = discreteFeatures.get(unknownSuffix); } return result; } /** * Loads the embedding for a word's capitalization. The index is allowed to be outside the sentence range, in which * case the appropriate 'padding' embedding is returned. */ private double[] getCapitalization(final List<InputWord> words, final int index) { String key; if (index < 0 || index >= words.size()) { key = capitalizedPad; } else { final String word = words.get(index).word; final char c = word.charAt(0); if ('A' <= c && c <= 'Z') { key = capsUpper; } else { key = capsLower; } } return discreteFeatures.get(key); } /** * weights(cat1) ... weights(cat2) ... ... bias(cat1) bias(cat2) */ public double[] getWeightVector() { final double[] result = new double[(totalFeatures + 1) * lexicalCategories.size()]; int index = 0; for (final Vector vector : weightMatrixRows) { for (int i = 0; i < vector.size(); i++) { result[index] = vector.get(i); index++; } } for (int i = 0; i < bias.size(); i++) { result[index] = bias.get(i); index++; } return result; } /** * Returns a list of @SyntaxTreeNode for this word, sorted by their probability. * * @param vector * A vector * @param word * The word itself. * @param wordIndex * The position of the word in the sentence. * @return * @return */ private List<ScoredCategory> getTagsForWord(final Vector vector, final InputWord word) { // If we're using a tag dictionary, consider those tags --- otherwise, // try all tags. Collection<Integer> possibleCategories = tagDict.get(word.word); if (possibleCategories == null) { possibleCategories = tagDict.get(TagDict.OTHER_WORDS); } return getTagsForWord(vector, possibleCategories); } private List<ScoredCategory> getTagsForWord(final Vector vector, final Collection<Integer> possibleCategories) { final int size = Math.min(maxTagsPerWord, possibleCategories.size()); double bestScore = 0.0; List<ScoredCategory> result = new ArrayList<>(possibleCategories.size()); for (final Integer cat : possibleCategories) { final double score = weightMatrixRows.get(cat).dot(vector) + bias.get(cat); result.add(new ScoredCategory(lexicalCategories.get(cat), score)); bestScore = Math.max(bestScore, score); } Collections.sort(result); if (result.size() > size) { result = result.subList(0, size); } final double threshold = beta * Math.exp(bestScore); for (int i = 2; i < result.size(); i++) { // TODO binary search if (Math.exp(result.get(i).getScore()) < threshold) { result = result.subList(0, i); break; } } return result; } @Override public Map<Category, Double> getCategoryScores(final List<InputWord> sentence, final int wordIndex, final double weight, final Collection<Category> categories) { final List<ScoredCategory> scoredCats = getTagsForWord(getVectorForWord(sentence, wordIndex), categories.stream().map(x -> categoryToIndex.get(x)).collect(Collectors.toList())); return scoredCats.stream().collect(Collectors.toMap(ScoredCategory::getCategory, x -> x.getScore() * weight)); } }