package org.thunlp.text.classifiers;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Hashtable;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Vector;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

import de.bwaldvogel.liblinear.*;

import org.apache.commons.codec.binary.Base64;
import org.thunlp.io.TextFileWriter;
import org.thunlp.language.chinese.LangUtils;
import org.thunlp.language.chinese.WordSegment;
import org.thunlp.text.DocumentVector;
import org.thunlp.text.Lexicon;
import org.thunlp.text.Term;
import org.thunlp.text.TfIdfTermWeighter;
import org.thunlp.text.TfOnlyTermWeighter;
import org.thunlp.text.Lexicon.Word;

public abstract class LiblinearTextClassifier implements TextClassifier {
    public Lexicon lexicon; // the lexicon
    private DocumentVector trainingVectorBuilder; // builds feature vectors for training documents
    private DocumentVector testVectorBuilder; // builds feature vectors for documents to classify
    private WordSegment seg;
    //private svm_model model; // trained model
    private de.bwaldvogel.liblinear.Model lmodel;
    private int maxFeatures = 5000; // default maximum number of features
    private int nclasses; // number of classes
    private int longestDoc; // length of the longest document vector; determines the buffer size when reading the temp file
    private int ndocs; // size of the training set
    public ArrayList<Integer> labelIndex = new ArrayList<Integer>(); // class labels
    public File tsCacheFile; // training-set cache file on disk
    public DataOutputStream tsCache = null; // output stream of the training-set cache

    public int getLongestDoc() {
        return longestDoc;
    }

    public void init(int nclasses, WordSegment seg) {
        lexicon = new Lexicon();
        trainingVectorBuilder = new DocumentVector(lexicon, new TfOnlyTermWeighter());
        testVectorBuilder = null;
        //model = null;
        lmodel = null;
        this.nclasses = nclasses;
        ndocs = 0;
        this.seg = seg;
    }

    public Lexicon getLexicon() {
        return lexicon;
    }

    public void clear() {
        lexicon = null;
        trainingVectorBuilder = null;
        testVectorBuilder = null;
        lmodel = null;
        seg = null;
        labelIndex = null;
    }

    abstract protected WordSegment initWordSegment();

    public LiblinearTextClassifier(int nclasses) {
        init(nclasses, initWordSegment());
    }

    /**
     * Initializes a Chinese text classifier based on bigram features and SVM.
     * @param nclasses number of classes
     */
    public LiblinearTextClassifier(int nclasses, WordSegment seg) {
        init(nclasses, seg);
    }
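    /*
     * A minimal usage sketch (illustrative only, not part of this class). It assumes a
     * hypothetical concrete subclass, here called BasicTextClassifier, that supplies a
     * WordSegment through initWordSegment():
     *
     *   LiblinearTextClassifier classifier = new BasicTextClassifier(2); // two classes
     *   classifier.addTrainingText("positive example text", 0);
     *   classifier.addTrainingText("negative example text", 1);
     *   classifier.train();                    // feature selection + liblinear training
     *   classifier.saveModel("model_dir");     // writes the "lexicon" and "model" files
     *   ClassifyResult r = classifier.classify("some new text");
     *   System.out.println(r.label + " " + r.prob);
     */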
    /**
     * Feature selection with the Scalable Term Selection (STS) method.
     * @author sames
     * @param cacheFile the data set; term weights in it should be raw tf
     * @param featureSize total number of features in the data set; features should be numbered from 0 to featureSize
     * @param kept number of features to keep
     * @param ndocs number of documents in the training set
     * @param nclasses number of classes
     * @param longestDoc number of features in the longest document
     * @return a map from old feature ids to newly selected feature ids
     */
    public Map<Integer, Integer> selectFeatureBySTS(File cacheFile, int featureSize,
            int kept, int ndocs, int nclasses, int longestDoc) {
        // initial value of lamda
        double lamda = 0.5;
        int[][] featureStats = new int[featureSize][nclasses];
        int[] featureFreq = new int[featureSize];
        double[] prValues = new double[featureSize];
        PriorityQueue<Term> selectedFeatures;
        int[] classSize = new int[nclasses];

        // collect the counts needed for the selection statistic
        int label = 0;
        int nterms = 0;
        double sum = 0;
        Term[] terms = new Term[longestDoc + 1];
        for (int i = 0; i < terms.length; i++) {
            terms[i] = new Term();
        }
        int ndocsread = 0;
        try {
            DataInputStream dis = new DataInputStream(new BufferedInputStream(
                    new FileInputStream(cacheFile)));
            while (true) {
                int ncut = 0;
                try {
                    label = dis.readInt();
                    nterms = dis.readInt();
                    sum += nterms;
                    for (int i = 0; i < nterms; i++) {
                        terms[i].id = dis.readInt();
                        terms[i].weight = dis.readDouble();
                        if (lexicon.getWord(terms[i].id).getDocumentFrequency() == 1)
                            ncut++;
                    }
                } catch (EOFException e) {
                    break;
                }
                sum -= ncut;
                classSize[label]++;
                for (int i = 0; i < nterms; i++) {
                    Term t = terms[i];
                    featureStats[t.id][label]++;
                    featureFreq[t.id]++;
                }
            }
            dis.close();
        } catch (IOException e) {
            return null;
        }

        System.err.println("start STS calculation");
        // Keep the highest-scoring features with a bounded priority queue.
        selectedFeatures = new PriorityQueue<Term>(kept + 1, new Comparator<Term>() {
            public int compare(Term t0, Term t1) {
                return Double.compare(t0.weight, t1.weight);
            }
        });
        long A, B, C, D;
        for (int i = 0; i < featureSize; i++) {
            double pr = -1;
            double prmax = -1;
            for (int j = 0; j < nclasses; j++) {
                A = featureStats[i][j];
                B = featureFreq[i] - A;
                C = classSize[j];
                D = ndocs - C;
                double fractorBase = (double) (B * C);
                if (Double.compare(fractorBase, 0.0) == 0) {
                    pr = Double.MAX_VALUE;
                } else {
                    pr = (double) (A * D) / fractorBase;
                    if (pr > prmax) {
                        prmax = pr;
                        prValues[i] = prmax;
                    }
                }
            }
        }
        double targetAVL = Math.pow(sum / ndocs, 0.085 * Math.log(kept));
        Term[] featuresToSort = new Term[kept];
        double first = 0;
        double second = 1;
        int iteration = 1;
        while (true) {
            selectedFeatures.clear();
            for (int i = 0; i < featureSize; i++) {
                if (lexicon.getWord(i).getDocumentFrequency() == 1)
                    continue;
                Term t = new Term();
                t.id = i;
                t.weight = 1.0 / (lamda / Math.log(prValues[i])
                        + (1 - lamda) / Math.log(featureFreq[i]));
                selectedFeatures.add(t);
                if (selectedFeatures.size() > kept) {
                    selectedFeatures.poll();
                }
            }
            double AVL = 0;
            int n = 0;
            while (selectedFeatures.size() > 0) {
                Term t = selectedFeatures.poll();
                featuresToSort[n] = t;
                n++;
                AVL += featureFreq[t.id];
            }
            Arrays.sort(featuresToSort, new Term.TermIdComparator());
            AVL /= ndocs;
            System.out.println("Iteration:" + iteration + " lamda = " + lamda
                    + " AVL = " + AVL + " Target AVL = " + targetAVL);
            if (Math.abs(AVL - targetAVL) < 0.1)
                break;
            else {
                if (AVL < targetAVL) {
                    second = lamda;
                    lamda = (first + lamda) / 2;
                } else {
                    first = lamda;
                    lamda = (lamda + second) / 2;
                }
                if (Math.abs(second - first) < 1.0E-13)
                    break;
            }
            iteration++;
        }
        System.err.println("generating feature map");
        // Build the map from old feature ids to the newly selected ids.
        Map<Integer, Integer> fidmap = new Hashtable<Integer, Integer>(kept);
        for (int i = 0; i < featuresToSort.length; i++) {
            fidmap.put(featuresToSort[i].id, i);
        }
        return fidmap;
    }
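    /*
     * Note on the chi-square score used by the methods below. For a feature t and a class c,
     * the document counts form a 2x2 contingency table:
     *   A = docs of class c containing t          B = docs of other classes containing t
     *   C = docs of class c not containing t      D = docs of other classes not containing t
     * and chi^2(t, c) = N * (A*D - B*C)^2 / ((A+C)(B+D)(A+B)(C+D)).
     * The factor N (total number of documents) is the same for every feature, so the code
     * omits it; each feature is ranked by chimax, the maximum of this score over all classes.
     */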
    /**
     * Feature selection using the chi-square statistic.
     * @param cacheFile the data set; term weights in it should be raw tf
     * @param featureSize total number of features in the data set; features should be numbered from 0 to featureSize
     * @param kept number of features to keep
     * @return a map from pre-selection feature ids to post-selection ids; the selected features keep their original order
     */
    public Map<Integer, Integer> selectFeaturesByChiSquare(File cacheFile, int featureSize, int kept) {
        return selectFeaturesByChiSquare(cacheFile, featureSize, kept, ndocs, nclasses, longestDoc, null);
    }

    /**
     * The actual feature-selection routine; it can optionally record every feature's chimax
     * value for debugging.
     * @param cacheFile the data set cache file
     * @param featureSize total number of features in the data set
     * @param kept number of features to keep
     * @param chimaxValues receives the chimax value of each feature; nothing is recorded if null
     * @return a map from pre-selection feature ids to post-selection ids
     */
    public Map<Integer, Integer> selectFeaturesByChiSquare(File cacheFile, int featureSize, int kept,
            int ndocs, int nclasses, int longestDoc, double[] chimaxValues) {
        System.out.println("selectFeaturesByChiSquare : " + "featureSize = " + featureSize
                + "; kept = " + kept + "; ndocs = " + ndocs + "; nclasses = " + nclasses
                + "; longestDoc = " + longestDoc);
        int[][] featureStats = new int[featureSize][nclasses]; // occurrences of each word in each class
        int[] featureFreq = new int[featureSize]; // document frequency of each word
        PriorityQueue<Term> selectedFeatures;
        int[] classSize = new int[nclasses]; // number of documents in each class

        // collect the counts needed for chi-square
        int label = 0;
        int nterms = 0;
        Term[] terms = new Term[longestDoc + 1];
        for (int i = 0; i < terms.length; i++) {
            terms[i] = new Term();
        }
        int ndocsread = 0;
        try {
            DataInputStream dis = new DataInputStream(new BufferedInputStream(
                    new FileInputStream(cacheFile)));
            while (true) {
                try {
                    label = dis.readInt();
                    nterms = dis.readInt();
                    //System.out.println("Reading doc " + ndocsread + " : label = " + label + "; nterms = " + nterms);
                    for (int i = 0; i < nterms; i++) {
                        terms[i].id = dis.readInt();
                        terms[i].weight = dis.readDouble();
                    }
                } catch (EOFException e) {
                    break;
                }
                classSize[label]++;
                for (int i = 0; i < nterms; i++) {
                    Term t = terms[i];
                    featureStats[t.id][label]++;
                    featureFreq[t.id]++;
                }
                if (ndocsread++ % 10000 == 0) {
                    System.err.println("scanned " + ndocsread);
                }
            }
            dis.close();
        } catch (IOException e) {
            return null;
        }

        System.err.println("start chi-square calculation");
        // Compute chimax for every feature and keep the highest-scoring ones with a priority queue.
        selectedFeatures = new PriorityQueue<Term>(kept + 1, new Comparator<Term>() {
            public int compare(Term t0, Term t1) {
                return Double.compare(t0.weight, t1.weight);
            }
        });
        long A, B, C, D;
        for (int i = 0; i < featureSize; i++) {
            Word w = lexicon.getWord(i);
            if (w != null) {
                if (w.getDocumentFrequency() == 1 || w.getName().length() > 50)
                    continue;
            }
            double chisqr = -1;
            double chimax = -1;
            for (int j = 0; j < nclasses; j++) {
                A = featureStats[i][j];
                B = featureFreq[i] - A;
                C = classSize[j] - A;
                D = ndocs - A - B - C;
                //System.out.println("A:" + A + " B:" + B + " C:" + C + " D:" + D);
                double fractorBase = (double) ((A + C) * (B + D) * (A + B) * (C + D));
                if (Double.compare(fractorBase, 0.0) == 0) {
                    chisqr = 0;
                } else {
                    // ndocs is omitted because it is the same for every feature
                    //chisqr = ndocs * (A*D - B*C) * (A*D - B*C) / fractorBase;
                    chisqr = (A * D - B * C) / fractorBase * (A * D - B * C);
                }
                if (chisqr > chimax) {
                    chimax = chisqr;
                }
                // The commented-out code below computes chi^2_avg, the probability-weighted
                // average chi-square; what is actually used is chimax.
                // chisqr += (classSize[j] / (double) ndocs) *
                //     ndocs * (A*D - B*C) * (A*D - B*C)
                //     / (double) ((A+C) * (B+D) * (A+B) * (C+D));
            }
            if (chimaxValues != null) {
                chimaxValues[i] = chimax;
            }
            Term t = new Term();
            t.id = i;
            t.weight = chimax;
            selectedFeatures.add(t);
            if (selectedFeatures.size() > kept) {
                selectedFeatures.poll();
            }
        }
        outputSecletedFeatures(selectedFeatures);

        System.err.println("generating feature map");
        // Build the map from old feature ids to the newly selected ids.
        Map<Integer, Integer> fidmap = new Hashtable<Integer, Integer>(kept);
        Term[] featuresToSort = new Term[selectedFeatures.size()];
        int n = 0;
        while (selectedFeatures.size() > 0) {
            Term t = selectedFeatures.poll();
            featuresToSort[n] = t;
            n++;
        }
        Arrays.sort(featuresToSort, new Term.TermIdComparator());
        for (int i = 0; i < featuresToSort.length; i++) {
            fidmap.put(featuresToSort[i].id, i);
        }
        return fidmap;
    }
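    /**
     * Writes the selected features to a UTF-8 text file named "selectedFeatures" in the
     * working directory, one "word weight" pair per line.
     * @param features the priority queue of selected features
     */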
    public void outputSecletedFeatures(PriorityQueue<Term> features) {
        System.out.println("store features...=======================================");
        try {
            TextFileWriter tw = new TextFileWriter("selectedFeatures", "UTF-8");
            Term[] f;
            f = features.toArray(new Term[features.size()]);
            System.out.println(f.length);
            for (int i = f.length - 1; i >= 0; i--) {
                tw.writeLine(lexicon.getWord(f[i].id).getName() + " " + f[i].weight);
            }
            tw.flush();
            tw.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println("end store features...=======================================");
    }

    public void setMaxFeatures(int max) {
        maxFeatures = max;
    }

    public int getMaxFeatures() {
        return maxFeatures;
    }

    /**
     * Adds a training document. The label must be an integer smaller than the total number of
     * classes, starting from 0.
     * @param text the training text
     * @param label the class id
     * @return whether the document was added successfully; failure is usually caused by not
     *         being able to create the temporary cache file on disk
     */
    public boolean addTrainingText(String text, int label) {
        if (label >= nclasses || label < 0) {
            return false;
        }
        if (tsCache == null) {
            try {
                //tsCacheFile = File.createTempFile("tctscache", "data");
                tsCacheFile = new File(".", "tctscache" + Long.toString(System.currentTimeMillis()) + "data");
                tsCache = new DataOutputStream(new BufferedOutputStream(
                        new FileOutputStream(tsCacheFile)));
                longestDoc = 0;
            } catch (IOException e) {
                return false;
            }
        }
        text = LangUtils.removeEmptyLines(text);
        text = LangUtils.removeExtraSpaces(text);
        String[] bigrams = seg.segment(text);
        lexicon.addDocument(bigrams);
        Word[] words = lexicon.convertDocument(bigrams);
        bigrams = null;
        Term[] terms = trainingVectorBuilder.build(words, false);
        try {
            tsCache.writeInt(label);
            tsCache.writeInt(terms.length);
            if (terms.length > longestDoc) {
                longestDoc = terms.length;
            }
            for (int i = 0; i < terms.length; i++) {
                tsCache.writeInt(terms[i].id);
                tsCache.writeDouble(terms[i].weight);
            }
        } catch (IOException e) {
            return false;
        }
        if (!labelIndex.contains(label)) {
            labelIndex.add(label);
        }
        ndocs++;
        return true;
    }

    /**
     * Classifies a document.
     * @param text the document to classify
     * @return the classification result, containing the class label and its probability; for
     *         the SVM classifier the probability is meaningless
     */
    public ClassifyResult classify(String text) {
        String[] bigrams = seg.segment(text);
        Word[] words = lexicon.convertDocument(bigrams);
        bigrams = null;
        Term[] terms = testVectorBuilder.build(words, true);
        int m = terms.length;
        //svm_node[] x = new svm_node[m];
        FeatureNode[] lx = new FeatureNode[m];
        for (int j = 0; j < m; j++) {
            lx[j] = new FeatureNode(terms[j].id + 1, terms[j].weight);
        }
        ClassifyResult cr = new ClassifyResult(-1, -Double.MAX_VALUE);
        //double[] probs = new double[svm.svm_get_nr_class(model)];
        double[] probs = new double[this.lmodel.getNrClass()];
        //svm.svm_predict_probability(model, x, probs);
        //de.bwaldvogel.liblinear.Linear.predictValues(lmodel, lx, probs);
        de.bwaldvogel.liblinear.Linear.predictProbability(lmodel, lx, probs);
        for (int i = 0; i < probs.length; i++) {
            if (probs[i] > cr.prob) {
                cr.prob = probs[i];
                cr.label = i;
            }
        }
        return cr;
    }
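    /**
     * Classifies a document and returns the topN most probable classes in descending order
     * of probability.
     * @param text the document to classify
     * @param topN how many candidate classes to return; must not exceed the number of classes
     * @return the topN classification results
     */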
    public ClassifyResult[] classify(String text, int topN) {
        String[] bigrams = seg.segment(text);
        Word[] words = lexicon.convertDocument(bigrams);
        bigrams = null;
        Term[] terms = testVectorBuilder.build(words, true);
        int m = terms.length;
        //svm_node[] x = new svm_node[m];
        FeatureNode[] lx = new FeatureNode[m];
        for (int j = 0; j < m; j++) {
            lx[j] = new FeatureNode(terms[j].id + 1, terms[j].weight);
        }
        //double[] probs = new double[svm.svm_get_nr_class(model)];
        double[] probs = new double[this.lmodel.getNrClass()];
        ArrayList<ClassifyResult> cr = new ArrayList<ClassifyResult>();
        //svm.svm_predict_probability(model, x, probs);
        //de.bwaldvogel.liblinear.Linear.predictValues(lmodel, lx, probs);
        de.bwaldvogel.liblinear.Linear.predictProbability(lmodel, lx, probs);
        for (int i = 0; i < probs.length; i++) {
            cr.add(new ClassifyResult(i, probs[i]));
        }
        Comparator com = new Comparator() {
            public int compare(Object obj1, Object obj2) {
                ClassifyResult o1 = (ClassifyResult) obj1;
                ClassifyResult o2 = (ClassifyResult) obj2;
                if (o1.prob > o2.prob + 1e-20)
                    return -1;
                else if (o1.prob < o2.prob - 1e-20)
                    return 1;
                else
                    return 0;
            }
        };
        Collections.sort(cr, com);
        /*
        double totalexp = 0.0;
        for (int i = 0; i < probs.length; i++) {
            totalexp += Math.exp(probs[i]);
        }
        for (int i = 0; i < topN; i++) {
            results[i] = "" + al.get(i).index + " " + Math.exp(al.get(i).value) / totalexp;
        }
        */
        java.text.DecimalFormat dcmFmt = new DecimalFormat("0.0000");
        ClassifyResult result[] = new ClassifyResult[topN];
        for (int i = 0; i < topN; i++) {
            result[i] = new ClassifyResult(cr.get(i).label, cr.get(i).prob);
        }
        cr.clear();
        //System.out.println("" + totalexp + results[0]);
        return result;
    }
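    /**
     * Classifies a document and returns one to three candidate classes. The most probable
     * class is always returned; the second and third candidates are appended only when their
     * probabilities stand out against thresholds derived from empirically observed means and
     * standard deviations (the hard-coded mean and sd constants below). The mode parameter is
     * currently unused.
     * @param text the document to classify
     * @return the selected classification results, or null if prediction produced no results
     */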
    public ClassifyResult[] classify(String text, String mode) {
        /*
        double mean1 = 0.9587235538481998;
        double mean2 = 0.027685258283684743;
        double mean3 = 0.0047742850704747064;
        double sd1 = 0.11380367342131785;
        double sd2 = 0.07759971712653867;
        double sd3 = 0.02049807196764366;
        */
        double mean1 = 0.8976;
        double mean2 = 0.0663;
        double mean3 = 0.0106;
        double sd1 = 0.1547;
        double sd2 = 0.1165;
        double sd3 = 0.0275;
        String[] bigrams = seg.segment(text);
        Word[] words = lexicon.convertDocument(bigrams);
        bigrams = null;
        Term[] terms = testVectorBuilder.build(words, true);
        int m = terms.length;
        FeatureNode[] lx = new FeatureNode[m];
        for (int j = 0; j < m; j++) {
            lx[j] = new FeatureNode(terms[j].id + 1, terms[j].weight);
        }
        double[] probs = new double[this.lmodel.getNrClass()];
        //de.bwaldvogel.liblinear.Linear.predictValues(lmodel, lx, probs);
        de.bwaldvogel.liblinear.Linear.predictProbability(lmodel, lx, probs);
        ArrayList<ClassifyResult> al = new ArrayList<ClassifyResult>();
        for (int i = 0; i < probs.length; i++) {
            al.add(new ClassifyResult(i, probs[i]));
        }
        if (al.size() == 0) {
            System.err.println("error!result size is 0!");
            return null;
        }
        Comparator com = new Comparator() {
            public int compare(Object obj1, Object obj2) {
                ClassifyResult o1 = (ClassifyResult) obj1;
                ClassifyResult o2 = (ClassifyResult) obj2;
                if (o1.prob > o2.prob + 0.0000000001)
                    return -1;
                else if (o1.prob < o2.prob - 0.0000000001)
                    return 1;
                else
                    return 0;
            }
        };
        Collections.sort(al, com);
        int num = 0;
        ArrayList<ClassifyResult> res = new ArrayList<ClassifyResult>();
        java.text.DecimalFormat dcmFmt = new DecimalFormat("0.0000");
        res.add(new ClassifyResult(al.get(0).label, al.get(0).prob));
        if (al.size() >= 3) {
            if (al.get(1).prob - mean2 > 2 * sd2 / 3) {
                res.add(new ClassifyResult(al.get(1).label, al.get(1).prob));
                if (al.get(0).prob + al.get(1).prob < 0.99) {
                    if (al.get(2).prob - mean3 > sd3) {
                        res.add(new ClassifyResult(al.get(2).label, al.get(2).prob));
                    }
                }
            }
        }
        return res.toArray(new ClassifyResult[res.size()]);
        //System.out.println("" + totalexp + results[0]);
        //return results;
    }

    /**
     * Loads a trained model from disk.
     * @param filename the model file name (actually a directory)
     * @return whether loading succeeded
     */
    public boolean loadModel(String filename) {
        File modelPath = new File(filename);
        if (!modelPath.isDirectory())
            return false;
        File lexiconFile = new File(modelPath, "lexicon");
        File modelFile = new File(modelPath, "model");
        System.out.println(lexiconFile.getAbsolutePath());
        try {
            if (lexiconFile.exists()) {
                lexicon.loadFromFile(lexiconFile);
                System.out.println("lexicon exists!");
            } else {
                return false;
            }
            if (modelFile.exists()) {
                //this.model = svm.svm_load_model(modelFile.getAbsolutePath());
                //this.lmodel = de.bwaldvogel.liblinear.Linear.loadModel(new File(modelFile.getAbsolutePath()));
                System.out.println("model exists!");
                this.lmodel = de.bwaldvogel.liblinear.Linear.loadModel(modelFile);
            } else {
                return false;
            }
        } catch (Exception e) {
            return false;
        }
        lexicon.setLock(true);
        trainingVectorBuilder = null;
        testVectorBuilder = new DocumentVector(lexicon, new TfIdfTermWeighter(lexicon));
        return true;
    }

    /**
     * Saves the trained model to disk.
     * @param filename the file name to save to (actually a directory)
     * @return whether saving succeeded
     */
    public boolean saveModel(String filename) {
        File modelPath = new File(filename);
        if (!modelPath.exists() && !modelPath.mkdir()) {
            return false;
        }
        File lexiconFile = new File(modelPath, "lexicon");
        File modelFile = new File(modelPath, "model");
        try {
            lexicon.saveToFile(lexiconFile);
            //svm.svm_save_model(modelFile.getAbsolutePath(), model);
            de.bwaldvogel.liblinear.Linear.saveModel(new File(modelFile.getAbsolutePath()), lmodel);
        } catch (IOException e) {
            return false;
        }
        return true;
    }

    /**
     * Trains the model.
     * @return whether training succeeded; failure is usually caused by not being able to read
     *         or write the temporary cache file correctly
     */
    public boolean train() {
        try {
            tsCache.close();
        } catch (IOException e) {
            return false;
        }
        Map<Integer, Integer> selectedFeatures = selectFeaturesByChiSquare(
                tsCacheFile, lexicon.getSize(), maxFeatures);
        // The commented-out call below selects features with Li Jingyang's Scalable Term
        // Selection method; it has not yet fully passed testing!
        // Map<Integer, Integer> selectedFeatures = selectFeatureBySTS(
        //     tsCacheFile, lexicon.getSize(), maxFeatures, ndocs, nclasses,
        //     longestDoc);
        if (selectedFeatures == null) {
            return false;
        }
        System.err.println("feature selection complete");
        //svm_problem problem = createLibSVMProblem(tsCacheFile, selectedFeatures);
        de.bwaldvogel.liblinear.Problem lproblem = createLiblinearProblem(tsCacheFile, selectedFeatures);
        System.err.println("liblinear problem created");
        lexicon = lexicon.map(selectedFeatures);
        lexicon.setLock(true);
        tsCacheFile.delete();
        trainingVectorBuilder = null;
        testVectorBuilder = new DocumentVector(lexicon, new TfIdfTermWeighter(lexicon));
        de.bwaldvogel.liblinear.Parameter lparam = new Parameter(SolverType.L1R_LR, 500, 0.01);
        //de.bwaldvogel.liblinear.Parameter lparam = new Parameter(solverType, C, eps)
        de.bwaldvogel.liblinear.Model tempModel = de.bwaldvogel.liblinear.Linear.train(lproblem, lparam);
        System.err.println("TRAINING COMPLETE=========================================================================================");
        this.lmodel = tempModel;
        //this.model = (svm_model) tempModel;
        return true;
    }

    private static class DataNode implements Comparable {
        int label;
        svm_node[] nodes;

        public int compareTo(Object o) {
            DataNode other = (DataNode) o;
            return label - other.label;
        }
    }

    private static class LdataNode implements Comparable {
        int llabel;
        FeatureNode[] lnodes;

        public int compareTo(Object o) {
            LdataNode other = (LdataNode) o;
            return llabel - other.llabel;
        }
    }
    /**
     * Builds a liblinear training problem from the feature-selection results.
     * @param cacheFile the cache file holding the training set
     * @param selectedFeatures the feature-selection result
     * @return the constructed liblinear Problem
     */
    private de.bwaldvogel.liblinear.Problem createLiblinearProblem(
            File cacheFile, Map<Integer, Integer> selectedFeatures) {
        Vector<Double> vy = new Vector<Double>();
        Vector<svm_node[]> vx = new Vector<svm_node[]>();
        //DataNode[] datanodes = new DataNode[this.ndocs];
        LdataNode[] ldatanodes = new LdataNode[this.ndocs];
        //FeatureNode[][] lfeatureNodes;
        int label, nterms;
        Term[] terms = new Term[longestDoc + 1];
        for (int i = 0; i < terms.length; i++) {
            terms[i] = new Term();
        }
        int ndocsread = 0;
        int maxIndex = 0;
        try {
            DataInputStream dis = new DataInputStream(new BufferedInputStream(
                    new FileInputStream(cacheFile)));
            while (true) {
                int n = 0;
                try {
                    label = dis.readInt();
                    nterms = dis.readInt();
                    for (int i = 0; i < nterms; i++) {
                        int tid = dis.readInt();
                        double tweight = dis.readDouble();
                        Integer id = selectedFeatures.get(tid);
                        if (id != null) {
                            terms[n].id = id;
                            maxIndex = Math.max(maxIndex, id + 1);
                            Word w = lexicon.getWord(tid);
                            int df = w.getDocumentFrequency();
                            // TF-IDF weight: log(tf + 1) * log((ndocs + 1) / df)
                            terms[n].weight = Math.log(tweight + 1)
                                    * (Math.log((double) (ndocs + 1) / df));
                            n++;
                            //System.err.println("doc " + id + " " + w);
                        }
                    }
                } catch (EOFException e) {
                    break;
                }
                //lfeatureNodes = new FeatureNode[this.ndocs][n];
                // L2-normalize the document vector.
                double normalizer = 0;
                for (int i = 0; i < n; i++) {
                    normalizer += terms[i].weight * terms[i].weight;
                }
                normalizer = Math.sqrt(normalizer);
                for (int i = 0; i < n; i++) {
                    terms[i].weight /= normalizer;
                }
                //datanodes[ndocsread] = new DataNode();
                // Put the document into the problem; liblinear feature indices are 1-based.
                ldatanodes[ndocsread] = new LdataNode();
                ldatanodes[ndocsread].llabel = label;
                FeatureNode[] lx = new FeatureNode[n];
                for (int i = 0; i < n; i++) {
                    lx[i] = new FeatureNode(terms[i].id + 1, terms[i].weight);
                }
                ldatanodes[ndocsread].lnodes = lx;
                if (ndocsread++ % 10000 == 0) {
                    System.err.println("scanned " + ndocsread);
                }
            }
            dis.close();
        } catch (IOException e) {
            return null;
        }
        assert (this.ndocs == ndocsread);
        Arrays.sort(ldatanodes);
        //svm_problem prob = new svm_problem();
        de.bwaldvogel.liblinear.Problem lprob = new de.bwaldvogel.liblinear.Problem();
        /*
        prob.l = datanodes.length;
        prob.x = new svm_node[prob.l][];
        for (int i = 0; i < prob.l; i++)
            prob.x[i] = datanodes[i].nodes;
        prob.y = new double[prob.l];
        for (int i = 0; i < prob.l; i++)
            prob.y[i] = (double) datanodes[i].label;
        return prob;
        */
        System.out.println("max index: -------------------------------------: " + maxIndex);
        lprob.n = maxIndex;
        lprob.l = ldatanodes.length;
        lprob.x = new de.bwaldvogel.liblinear.FeatureNode[lprob.l][];
        for (int i = 0; i < lprob.l; i++)
            //lprob.x[i] = datanodes[i].nodes;
            lprob.x[i] = ldatanodes[i].lnodes;
        lprob.y = new int[lprob.l];
        for (int i = 0; i < lprob.l; i++)
            lprob.y[i] = ldatanodes[i].llabel;
        return lprob;
    }
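    /*
     * String (de)serialization of the classifier: saveToString() Java-serializes the lexicon
     * and the liblinear model into a byte stream and Base64-encodes it; loadFromString()
     * reverses the process. A minimal round-trip sketch (illustrative only):
     *
     *   String encoded = classifier.saveToString();   // after train() or loadModel()
     *   otherClassifier.loadFromString(encoded);      // ready to classify()
     */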
    public String saveToString() {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try {
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(this.lexicon);
            //oos.writeObject(this.model);
            oos.writeObject(this.lmodel);
            oos.close();
        } catch (IOException e) {
            e.printStackTrace();
            return ""; // Failed to serialize the model.
        }
        String base64 = new String(Base64.encodeBase64(baos.toByteArray()));
        return base64;
    }

    public void loadFromString(String model) {
        ByteArrayInputStream bais = new ByteArrayInputStream(Base64.decodeBase64(model.getBytes()));
        ObjectInputStream ois;
        try {
            ois = new ObjectInputStream(bais);
            this.lexicon = (Lexicon) ois.readObject();
            //this.model = (svm_model) ois.readObject();
            this.lmodel = (de.bwaldvogel.liblinear.Model) ois.readObject();
            ois.close();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        testVectorBuilder = new DocumentVector(lexicon, new TfIdfTermWeighter(lexicon));
    }
}