/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package meka.classifiers.multitarget; /** * CR.java - The Class-Relevance Method. * (The generalised, multi-target version of the Binary Relevance (BR) method). * @see BR.java * @version Jan 2012 * @author Jesse Read ([email protected]) */ import weka.classifiers.AbstractClassifier; import meka.classifiers.multilabel.ProblemTransformationMethod; import weka.core.Instance; import weka.core.Instances; import weka.classifiers.trees.J48; import meka.core.MLUtils; import weka.core.RevisionUtils; import weka.core.Utils; public class CR extends meka.classifiers.multilabel.BR implements MultiTargetClassifier { /** for serialization. */ private static final long serialVersionUID = 1627371180786293843L; protected Instances m_Templates[] = null; // TEMPLATES public CR() { // default classifier for GUI this.m_Classifier = new J48(); } @Override protected String defaultClassifierString() { // default classifier for CLI return "weka.classifiers.trees.J48"; } /** * Description to display in the GUI. * * @return the description */ @Override public String globalInfo() { return "The Class-Relevance Method.\n" + "(The generalised, multi-target version of the Binary Relevance (BR) method)."; } @Override public void buildClassifier(Instances D) throws Exception { testCapabilities(D); int L = D.classIndex(); if(getDebug()) System.out.print("Creating "+L+" models ("+m_Classifier.getClass().getName()+"): "); m_MultiClassifiers = AbstractClassifier.makeCopies(m_Classifier,L); m_Templates = new Instances[L]; for(int j = 0; j < L; j++) { //Select only class attribute 'j' m_Templates[j] = MLUtils.keepAttributesAt(new Instances(D),new int[]{j},L); m_Templates[j].setClassIndex(0); //Build the classifier for that class m_MultiClassifiers[j].buildClassifier(m_Templates[j]); if(getDebug()) System.out.print(" " + (m_Templates[j].classAttribute().name())); m_Templates[j] = new Instances(m_Templates[j], 0); } } @Override public double[] distributionForInstance(Instance x) throws Exception { int L = x.classIndex(); double y[] = new double[L*2]; for (int j = 0; j < L; j++) { Instance x_j = (Instance)x.copy(); x_j.setDataset(null); x_j = MLUtils.keepAttributesAt(x_j,new int[]{j},L); x_j.setDataset(m_Templates[j]); double w[] = m_MultiClassifiers[j].distributionForInstance(x_j); // e.g. [0.1, 0.8, 0.1] y[j] = Utils.maxIndex(w); // e.g. 1 y[L+j] = w[(int)y[j]]; // e.g. 0.8 } return y; } @Override public String getRevision() { return RevisionUtils.extract("$Revision: 9117 $"); } public static void main(String args[]) { ProblemTransformationMethod.evaluation(new CR(), args); } }