package me.xiaosheng.word2vec; import java.io.File; import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeSet; import com.ansj.vec.Learn; import com.ansj.vec.Word2VEC; import com.ansj.vec.domain.WordEntry; public class Word2Vec { private Word2VEC vec; private boolean loadModel; //是否已经加载模型 public Word2Vec() { vec = new Word2VEC(); loadModel = false; } /** * 加载Google版Word2Vec模型(C语言训练) * @param modelPath 模型文件路径 * @throws IOException */ public void loadGoogleModel(String modelPath) throws IOException { vec.loadGoogleModel(modelPath); loadModel = true; } /** * 加载Java版Word2Vec模型(java语言训练) * @param modelPath 模型文件路径 * @throws IOException */ public void loadJavaModel(String modelPath) throws IOException { vec.loadJavaModel(modelPath); loadModel = true; } /** * 训练Java版Word2Vec模型 * @param trainFilePath 训练文件路径 * @param modelFilePath 模型文件路径 * @throws IOException */ public static void trainJavaModel(String trainFilePath, String modelFilePath) throws IOException { Learn learn = new Learn(); long start = System.currentTimeMillis(); learn.learnFile(new File(trainFilePath)); System.out.println("use time " + (System.currentTimeMillis() - start)); learn.saveModel(new File(modelFilePath)); } /** * 获得词向量 * @param word * @return */ public float[] getWordVector(String word) { if (loadModel == false) { return null; } return vec.getWordVector(word); } /** * 计算向量内积 * @param vec1 * @param vec2 * @return */ private float calDist(float[] vec1, float[] vec2) { float dist = 0; for (int i = 0; i < vec1.length; i++) { dist += vec1[i] * vec2[i]; } return dist; } /** * 向量求和 * @param sum 和向量 * @param vec 添加向量 */ private void calSum(float[] sum, float[] vec) { for (int i = 0; i < sum.length; i++) { sum[i] += vec[i]; } } /** * 计算词相似度 * @param word1 * @param word2 * @return */ public float wordSimilarity(String word1, String word2) { if (loadModel == false) { return 0; } float[] word1Vec = getWordVector(word1); float[] word2Vec = getWordVector(word2); if(word1Vec == null || word2Vec == null) { return 0; } return calDist(word1Vec, word2Vec); } /** * 获取相似词语 * @param word * @param maxReturnNum * @return */ public Set<WordEntry> getSimilarWords(String word, int maxReturnNum) { if (loadModel == false) return null; float[] center = getWordVector(word); if (center == null) { return Collections.emptySet(); } int resultSize = vec.getWords() < maxReturnNum ? vec.getWords() : maxReturnNum; TreeSet<WordEntry> result = new TreeSet<WordEntry>(); double min = Double.MIN_VALUE; for (Map.Entry<String, float[]> entry : vec.getWordMap().entrySet()) { float[] vector = entry.getValue(); float dist = calDist(center, vector); if (result.size() <= resultSize) { result.add(new WordEntry(entry.getKey(), dist)); min = result.last().score; } else { if (dist > min) { result.add(new WordEntry(entry.getKey(), dist)); result.pollLast(); min = result.last().score; } } } result.pollFirst(); return result; } /** * 计算词语与词语列表中所有词语的最大相似度 * (最小返回0) * @param centerWord 词语 * @param wordList 词语列表 * @return */ private float calMaxSimilarity(String centerWord, List<String> wordList) { float max = -1; if (wordList.contains(centerWord)) { return 1; } else { for (String word : wordList) { float temp = wordSimilarity(centerWord, word); if (temp == 0) continue; if (temp > max) { max = temp; } } } if (max == -1) return 0; return max; } /** * 快速计算句子相似度 * @param sentence1Words 句子1词语列表 * @param sentence2Words 句子2词语列表 * @return 两个句子的相似度 */ public float fastSentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words) { if (loadModel == false) { return 0; } if (sentence1Words.isEmpty() || sentence2Words.isEmpty()) { return 0; } float[] sen1vector = new float[vec.getSize()]; float[] sen2vector = new float[vec.getSize()]; double len1 = 0; double len2 = 0; for (int i = 0; i < sentence1Words.size(); i++) { float[] tmp = getWordVector(sentence1Words.get(i)); if (tmp != null) calSum(sen1vector, tmp); } for (int i = 0; i < sentence2Words.size(); i++) { float[] tmp = getWordVector(sentence2Words.get(i)); if (tmp != null) calSum(sen2vector, tmp); } for (int i = 0; i < vec.getSize(); i++) { len1 += sen1vector[i] * sen1vector[i]; len2 += sen2vector[i] * sen2vector[i]; } return (float) (calDist(sen1vector, sen2vector) / Math.sqrt(len1 * len2)); } /** * 计算句子相似度 * 所有词语权值设为1 * @param sentence1Words 句子1词语列表 * @param sentence2Words 句子2词语列表 * @return 两个句子的相似度 */ public float sentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words) { if (loadModel == false) { return 0; } if (sentence1Words.isEmpty() || sentence2Words.isEmpty()) { return 0; } float sum1 = 0; float sum2 = 0; int count1 = 0; int count2 = 0; for (int i = 0; i < sentence1Words.size(); i++) { if (getWordVector(sentence1Words.get(i)) != null) { count1++; sum1 += calMaxSimilarity(sentence1Words.get(i), sentence2Words); } } for (int i = 0; i < sentence2Words.size(); i++) { if (getWordVector(sentence2Words.get(i)) != null) { count2++; sum2 += calMaxSimilarity(sentence2Words.get(i), sentence1Words); } } return (sum1 + sum2) / (count1 + count2); } /** * 计算句子相似度(带权值) * 每一个词语都有一个对应的权值 * @param sentence1Words 句子1词语列表 * @param sentence2Words 句子2词语列表 * @param weightVector1 句子1权值向量 * @param weightVector2 句子2权值向量 * @return 两个句子的相似度 * @throws Exception 词语列表和权值向量长度不同 */ public float sentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words, float[] weightVector1, float[] weightVector2) throws Exception { if (loadModel == false) { return 0; } if (sentence1Words.isEmpty() || sentence2Words.isEmpty()) { return 0; } if (sentence1Words.size() != weightVector1.length || sentence2Words.size() != weightVector2.length) { throw new Exception("length of word list and weight vector is different"); } float sum1 = 0; float sum2 = 0; float divide1 = 0; float divide2 = 0; for (int i = 0; i < sentence1Words.size(); i++) { if (getWordVector(sentence1Words.get(i)) != null) { float wordMaxSimi = calMaxSimilarity(sentence1Words.get(i), sentence2Words); sum1 += wordMaxSimi * weightVector1[i]; divide1 += weightVector1[i]; } } for (int i = 0; i < sentence2Words.size(); i++) { if (getWordVector(sentence2Words.get(i)) != null) { float wordMaxSimi = calMaxSimilarity(sentence2Words.get(i), sentence1Words); sum2 += wordMaxSimi * weightVector2[i]; divide2 += weightVector2[i]; } } return (sum1 + sum2) / (divide1 + divide2); } }