/* * This file is part of JaTeCS. * * JaTeCS is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JaTeCS is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with JaTeCS. If not, see <http://www.gnu.org/licenses/>. * * The software has been mainly developed by (in alphabetical order): * - Andrea Esuli ([email protected]) * - Tiziano Fagni ([email protected]) * - Alejandro Moreo Fernández ([email protected]) * Other past contributors were: * - Giacomo Berardi ([email protected]) */ package it.cnr.jatecs.classification.treeboost; import gnu.trove.TIntArrayList; import gnu.trove.TShortArrayList; import it.cnr.jatecs.classification.BaseLearner; import it.cnr.jatecs.classification.interfaces.IClassifier; import it.cnr.jatecs.classification.interfaces.ILearner; import it.cnr.jatecs.classification.interfaces.ILearnerRuntimeCustomizer; import it.cnr.jatecs.indexes.DB.interfaces.ICategoryDB; import it.cnr.jatecs.indexes.DB.interfaces.IIndex; import it.cnr.jatecs.indexing.tsr.ITsr; import it.cnr.jatecs.indexing.weighting.IWeighting; import it.cnr.jatecs.utils.IOperationStatusListener; import it.cnr.jatecs.utils.JatecsLogger; import it.cnr.jatecs.utils.Os; import it.cnr.jatecs.utils.iterators.TIntArrayListIterator; import it.cnr.jatecs.utils.iterators.TShortArrayListIterator; import it.cnr.jatecs.utils.iterators.interfaces.IIntIterator; import it.cnr.jatecs.utils.iterators.interfaces.IShortIterator; import java.util.Vector; public class SingleLabelTreeBoostLearner extends BaseLearner { /** * The learner object which construct classifiers. */ protected ILearner _learner; /** * The weighting module. */ protected IWeighting _weighting; /** * The tsr module. */ protected ITsr _tsr; /** * The status listener */ protected IOperationStatusListener _status; public SingleLabelTreeBoostLearner(ILearner learner) { this(learner, null); } public SingleLabelTreeBoostLearner(ILearner learner, IOperationStatusListener status) { assert (learner != null); _customizer = new TreeBoostLearnerCustomizer(null); _learner = learner; _weighting = null; _tsr = null; _status = status; } /** * Set the weighting module to use to compute the features weight inside * documents. * * @param w The weighting module. */ public void setWeighting(IWeighting w) { _weighting = w; } /** * Set the TSR module to use to reduce the feature space of documents. * * @param tsr The TSR module. */ public void setTSR(ITsr tsr) { _tsr = tsr; } public IClassifier build(IIndex trainingIndex) { SingleLabelTreeBoostClassifier c = new SingleLabelTreeBoostClassifier(); if (_status != null) _status.operationStatus(0.0); // Construct classifiers at root level. constructClassifiersAt(Short.MIN_VALUE, trainingIndex.getCategoryDB(), trainingIndex, c); if (_status != null) _status.operationStatus(100.0); return c; } public void constructClassifiersAt(short catID, ICategoryDB catsDB, IIndex training, SingleLabelTreeBoostClassifier classifiers) { IShortIterator childs = null; if (catID == Short.MIN_VALUE) childs = training.getCategoryDB().getRootCategories(); else childs = training.getCategoryDB().getChildCategories(catID); if (!childs.hasNext()) // Nothing to do. return; JatecsLogger.status().print( "---> Start analyzing the hierarchy level owned by category <" + (catID == Short.MIN_VALUE ? "RootCategory" : training .getCategoryDB().getCategoryName(catID)) + ">." + Os.newline()); // JatecsLogger.status().println("---> Start analyzing the hierarchy // level owned by category <"+(catID == // Short.MIN_VALUE?"RootCategory":training.getCategoriesDB().getCategoryName(catID))+">."); // Create a temporary index containing only the valid categories. JatecsLogger.status().print( "Creating a temporary index containing " + training.getDocumentDB().getDocumentsCount() + " documents..."); IIndex idx = training.cloneIndex(); TShortArrayList toRemove = new TShortArrayList(); short nextCatID = childs.next(); for (short i = 0; i < training.getCategoryDB().getCategoriesCount(); i++) { if (i == nextCatID) { if (childs.hasNext()) nextCatID = childs.next(); continue; } toRemove.add(i); } idx.removeCategories(new TShortArrayListIterator(toRemove)); JatecsLogger.status().println( "done. The categories valid are " + idx.getCategoryDB().getCategoriesCount() + "."); if (_tsr != null) { JatecsLogger.status().println("Now apply TSR to index."); _tsr.computeTSR(idx); JatecsLogger.status().println("Ok. TSR applied."); } if (_weighting != null) { JatecsLogger.status().println( "Now apply weighting to document features."); idx = _weighting.computeWeights(idx); JatecsLogger.status().println("Ok. Weighting operation done."); } JatecsLogger.status().println( "The number of valid features is " + idx.getFeatureDB().getFeaturesCount() + "."); if (((TreeBoostLearnerCustomizer) _customizer).getInternalCustomizer( TreeBoostLearnerCustomizer.ALL_LEVELS, TreeBoostLearnerCustomizer.ALL_CATEGORIES) != null) _learner.setRuntimeCustomizer(((TreeBoostLearnerCustomizer) _customizer) .getInternalCustomizer( TreeBoostLearnerCustomizer.ALL_LEVELS, TreeBoostLearnerCustomizer.ALL_CATEGORIES)); // Construct the classifier for this level. IClassifier c = _learner.build(idx); // Save data at this level. short levelCatID = -1; if (catID != Short.MIN_VALUE) { String catName = training.getCategoryDB().getCategoryName(catID); short globalCatID = catsDB.getCategory(catName); classifiers._map.put(globalCatID, c); levelCatID = globalCatID; } else { classifiers._map.put(catID, c); levelCatID = catID; } // Init iterator. childs.begin(); while (childs.hasNext()) { short curCatID = childs.next(); // Get current global catID and add reference to classifier object. String catName = training.getCategoryDB().getCategoryName(curCatID); short globalCatID = catsDB.getCategory(catName); TreeBoostClassifierAddress addr = new TreeBoostClassifierAddress(); addr.level = levelCatID; addr.categoryID = idx.getCategoryDB().getCategory(catName); classifiers._mapCatLevel.put(globalCatID, addr); IShortIterator ch = training.getCategoryDB().getChildCategories( curCatID); if (!ch.hasNext()) continue; JatecsLogger .status() .println( "The child category <" + training.getCategoryDB().getCategoryName( curCatID) + "> is the owner of a subtree of categories. Analyze it."); JatecsLogger.status().print( "Select all documents positives for category " + training.getCategoryDB() .getCategoryName(curCatID) + "..."); // Select positive documents for this category. IIndex idxCur = selectPositives(curCatID, training); JatecsLogger.status().println("done."); curCatID = idxCur.getCategoryDB().getCategory(catName); // Recursion over current category. constructClassifiersAt(curCatID, catsDB, idxCur, classifiers); } } protected IIndex selectPositives(short catID, IIndex training) { // First create a new index. IIndex idx = training.cloneIndex(); IShortIterator childCats = getAllChildsCategoriesFor(idx, catID); short nextCatID = Short.MIN_VALUE; if (childCats.hasNext()) nextCatID = childCats.next(); // Remove unwanted categories. TShortArrayList toRemove = new TShortArrayList(); for (short i = 0; i < training.getCategoryDB().getCategoriesCount(); i++) { if (i == nextCatID) { if (childCats.hasNext()) nextCatID = childCats.next(); continue; } toRemove.add(i); } toRemove.sort(); idx.removeCategories(new TShortArrayListIterator(toRemove)); // Remove unwanted documents. TIntArrayList docsToRemove = new TIntArrayList(); IIntIterator docs = idx.getDocumentDB().getDocuments(); while (docs.hasNext()) { int docID = docs.next(); IShortIterator curCats = idx.getClassificationDB() .getDocumentCategories(docID); if (!curCats.hasNext()) docsToRemove.add(docID); } docsToRemove.sort(); idx.removeDocuments(new TIntArrayListIterator(docsToRemove), false); return idx; } protected IShortIterator getAllChildsCategoriesFor(IIndex idx, short catID) { TShortArrayList childs = new TShortArrayList(); IShortIterator curChilds = idx.getCategoryDB() .getChildCategories(catID); while (curChilds.hasNext()) { short curCatID = curChilds.next(); IShortIterator c = getAllChildsCategoriesFor(idx, curCatID); // Merge the results with current list. while (c.hasNext()) { short id = c.next(); if (!childs.contains(id)) childs.add(id); } } // Add this category. if (!childs.contains(catID)) childs.add(catID); childs.sort(); return new TShortArrayListIterator(childs); } @Override public ILearnerRuntimeCustomizer getRuntimeCustomizer(short catID) { return null; } @Override public IClassifier mergeClassifiers(Vector<IClassifier> classifiers) { return null; } @Override public void setRuntimeCustomizer( Vector<ILearnerRuntimeCustomizer> customizers) { } }