meka.classifiers.multilabel.Evaluation Java Examples

The following examples show how to use meka.classifiers.multilabel.Evaluation. Each example is headed by the source file and project it was taken from, together with that project's license and the number of votes the example received; see the original project for the full context of each snippet.
Example #1
Source file: TrainTestSplit.java (from meka, GNU General Public License v3.0) — 6 votes
/**
 * Splits a dataset into a training and a test part by percentage, builds a default
 * BR classifier on the training part and evaluates it on the remainder.
 * <br>
 * The split follows the dataset's existing instance order; the data is not shuffled.
 *
 * @param args exactly two arguments: {@code <dataset> <percentage>}, where the percentage
 *             (strictly between 0 and 100) is the share of leading instances used for training
 * @throws IllegalArgumentException if the argument count is wrong, the percentage is not a
 *         number in (0, 100), or it would leave the training split empty
 * @throws Exception if loading, training or evaluating fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 2)
    throw new IllegalArgumentException("Required arguments: <dataset> <percentage>");

  // Validate before loading the (potentially large) dataset; the negated form also rejects NaN.
  double percentage = Double.parseDouble(args[1]);
  if (!(percentage > 0.0 && percentage < 100.0))
    throw new IllegalArgumentException("Percentage must be greater than 0 and less than 100, got: " + args[1]);

  System.out.println("Loading data: " + args[0]);
  Instances data = DataSource.read(args[0]);
  MLUtils.prepareData(data);

  // Because percentage < 100, trainSize < numInstances, so the test split is never empty
  // for a non-empty dataset. The truncating cast can still make the training split empty.
  int trainSize = (int) (data.numInstances() * percentage / 100.0);
  if (trainSize < 1)
    throw new IllegalArgumentException("Percentage " + percentage + "% of " + data.numInstances()
      + " instances leaves no training instances");
  Instances train = new Instances(data, 0, trainSize);
  Instances test = new Instances(data, trainSize, data.numInstances() - trainSize);

  System.out.println("Build BR classifier on " + percentage + "%");
  BR classifier = new BR();
  // further configuration of classifier
  classifier.buildClassifier(train);

  System.out.println("Evaluate BR classifier on " + (100.0 - percentage) + "%");
  String top = "PCut1";
  String vop = "3";
  Result result = Evaluation.evaluateModel(classifier, train, test, top, vop);

  System.out.println(result);
}
 
Example #2
Source file: StatUtils.java (from meka, GNU General Public License v3.0) — 6 votes
/**
 * Main - command-line entry point for ad-hoc tests of the label-dependency matrices.
 * <br>
 * The dataset is loaded with {@code Evaluation.loadDataset(args)}. {@code args[2]} selects the
 * matrix: {@code "L"} runs {@code LEAD} with an SMO base classifier and an unseeded Random. The
 * optional {@code args[3]} is its {@code MDType}, defaulting to {@code "I"}. Any other value is
 * passed to {@code margDepMatrix}. The resulting matrix is printed to standard output.
 *
 * @param args the command-line arguments; at least three are required
 * @throws IllegalArgumentException if fewer than three arguments are given
 * @throws Exception if loading the dataset or computing the matrix fails
 */
public static void main(String args[]) throws Exception {
	// args[2] is always read below, so fail fast with a clear message instead of an index error.
	if (args.length < 3)
		throw new IllegalArgumentException("Required arguments: at least 3 (dataset options, then the matrix type at position 3)");

	Instances D = Evaluation.loadDataset(args);
	MLUtils.prepareData(D);

	double CD[][] = null;

	if (args[2].equals("L")) {
		// args[3] is optional: only read it when it is actually present (needs length >= 4).
		String I = args.length > 3 ? args[3] : "I";
		CD = StatUtils.LEAD(D, new SMO(), new Random(), I);
	}
	else {
		CD = StatUtils.margDepMatrix(D,args[2]);
	}
	System.out.println(MatrixUtils.toString(CD, "M" + args[2]));
}
 
Example #3
Source file: MicroCurve.java (from meka, GNU General Public License v3.0) — 6 votes
/**
 * Cross-validates (10 folds) a default BR classifier on the given dataset and shows
 * the evaluation's micro curve data ({@code CURVE_DATA_MICRO}) in a window.
 *
 * @param args exactly one argument: the path of the dataset to load
 * @throws IllegalArgumentException if not exactly one argument is given
 * @throws Exception if loading the data or cross-validating fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 1)
    throw new IllegalArgumentException("Required arguments: <dataset>");

  System.out.println("Loading data: " + args[0]);
  Instances data = DataSource.read(args[0]);
  MLUtils.prepareData(data);

  System.out.println("Cross-validate BR classifier");
  BR classifier = new BR();
  // further configuration of classifier
  String top = "PCut1";
  String vop = "3";
  Result result = Evaluation.cvModel(classifier, data, 10, top, vop);

  final Instances performance = (Instances) result.getMeasurement(CURVE_DATA_MICRO);

  // Swing components must only be created and shown on the event dispatch thread.
  javax.swing.SwingUtilities.invokeLater(new Runnable() {
    @Override
    public void run() {
      JFrame frame = new JFrame("Micro curve");
      frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
      frame.getContentPane().setLayout(new BorderLayout());
      try {
        VisualizePanel panel = createPanel(performance);
        frame.getContentPane().add(panel, BorderLayout.CENTER);
      }
      catch (Exception ex) {
        // Still show the (empty) frame so the user sees that plotting failed.
        System.err.println("Failed to create plot!");
        ex.printStackTrace();
      }
      frame.setSize(800, 600);
      frame.setLocationRelativeTo(null);
      frame.setVisible(true);
    }
  });
}
 
Example #4
Source file: CrossValidate.java (from meka, GNU General Public License v3.0) — 6 votes
/**
 * Cross-validates a default BR classifier on a single dataset and prints the result.
 *
 * @param args exactly one argument: the path of the dataset to load
 * @throws IllegalArgumentException if not exactly one argument is given
 * @throws Exception if loading the data or cross-validating fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 1) {
    throw new IllegalArgumentException("Required arguments: <dataset>");
  }

  final String datasetPath = args[0];
  System.out.println("Loading data: " + datasetPath);
  final Instances dataset = DataSource.read(datasetPath);
  MLUtils.prepareData(dataset);

  final int folds = 10;
  System.out.println("Cross-validate BR classifier using " + folds + " folds");
  final BR br = new BR();
  // further configuration of classifier

  // Threshold option "PCut1" and verbosity "3", as passed to cvModel.
  final Result cvResult = Evaluation.cvModel(br, dataset, folds, "PCut1", "3");
  System.out.println(cvResult);
}
 
Example #5
Source file: MekaClassifierTest.java (from AILibs, GNU Affero General Public License v3.0) — 6 votes
/**
 * Checks that wrapping BR in a {@code MekaClassifier} gives the same prediction matrix
 * as training and testing BR directly through MEKA on the same holdout split.
 */
@Test
public void testFitAndPredictWithHoldoutSplitter() throws Exception {
	// Reference: plain MEKA BR trained on split 0 and tested on split 1.
	BR br = new BR();
	br.buildClassifier(splitterSplit.get(0).getInstances());
	Result res = Evaluation.testClassifier(br, splitterSplit.get(1).getInstances());
	double[][] mekaPredictions = res.allPredictions();

	// Candidate: the same learner behind the jaicore wrapper, on the same split.
	MekaClassifier classifier = new MekaClassifier(new BR());
	classifier.fit(splitterSplit.get(0));
	IMultiLabelClassificationPredictionBatch pred = classifier.predict(splitterSplit.get(1));

	assertEquals("Number of predictions is not consistent.", splitterSplit.get(1).size(), pred.getNumPredictions());

	double[][] jaicorePredictions = pred.getPredictionMatrix();
	assertEquals("Length of prediction matrices is not consistent.", mekaPredictions.length, jaicorePredictions.length);

	for (int i = 0; i < mekaPredictions.length; i++) {
		// Compare widths per row: this avoids indexing row 0 of an empty matrix and
		// also catches rows of differing length.
		assertEquals("Width of prediction matrices is not consistent for instance " + i + ".", mekaPredictions[i].length, jaicorePredictions[i].length);
		for (int j = 0; j < mekaPredictions[i].length; j++) {
			assertEquals("The prediction for instance " + i + " and label " + j + " is not consistent.", mekaPredictions[i][j], jaicorePredictions[i][j], 1E-8);
		}
	}
}
 
Example #6
Source file: ARAMNetwork.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Runs a MEKA experiment with an {@code ARAMNetwork} configured from the command line.
 * Exceptions are not propagated: their stack trace and message go to standard error.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final MultiLabelClassifier classifier = (MultiLabelClassifier) new ARAMNetwork();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
 
Example #7
Source file: StatUtils.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * LEAD with an explicit matrix type.
 * <br>
 * Shuffles a copy of {@code D} with {@code r}, then trains BR (base classifier {@code h}) on
 * the first 60% of the instances and evaluates it on the rest. Finally it computes LEAD on
 * that test part with the given {@code MDType}.
 *
 * @param D the dataset
 * @param h the base classifier for BR
 * @param r the random source used for shuffling
 * @param MDType the matrix type passed on to {@code LEAD(Instances, Result, String)}
 * @return the dependency matrix computed from the test-part evaluation
 * @throws Exception if training or evaluation fails
 */
public static double[][] LEAD(Instances D, Classifier h, Random r, String MDType)  throws Exception {
	// Randomize a copy rather than D itself.
	Instances shuffled = new Instances(D);
	shuffled.randomize(r);

	// Leading 60% (integer division, rounded down) for training, the remainder for testing.
	int total = shuffled.numInstances();
	int nTrain = total * 60 / 100;
	Instances D_train = new Instances(shuffled, 0, nTrain);
	Instances D_test = new Instances(shuffled, nTrain, total - nTrain);

	BR br = new BR();
	br.setClassifier(h);
	Result result = Evaluation.evaluateModel((MultiLabelClassifier) br, D_train, D_test, "PCut1", "1");
	return LEAD(D_test, result, MDType);
}
 
Example #8
Source file: StatUtils.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * LEAD - Performs LEAD on dataset 'D', using BR with base classifier 'h', under random seed 'r'.
 * <br>
 * A shuffled copy of {@code D} is split 60/40 (rounded down) into train and test parts. BR is
 * trained on the first part and evaluated on the second, and the result is passed to {@code LEAD2}.
 * <br>
 * WARNING: changing this method will affect the performance of e.g., BCC -- on the other hand the original BCC paper did not use LEAD, so don't worry.
 *
 * @param D the dataset
 * @param h the base classifier for BR
 * @param r the random source used for shuffling
 * @return the matrix produced by {@code LEAD2} for the test part
 * @throws Exception if training or evaluation fails
 */
public static double[][] LEAD(Instances D, Classifier h, Random r)  throws Exception {
	// Randomize a copy rather than D itself.
	Instances shuffled = new Instances(D);
	shuffled.randomize(r);

	int total = shuffled.numInstances();
	int nTrain = total * 60 / 100;
	Instances D_train = new Instances(shuffled, 0, nTrain);
	Instances D_test = new Instances(shuffled, nTrain, total - nTrain);

	BR br = new BR();
	br.setClassifier(h);
	Result result = Evaluation.evaluateModel((MultiLabelClassifier) br, D_train, D_test, "PCut1", "1");
	return LEAD2(D_test, result);
}
 
Example #9
Source file: WARAM.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Runs a MEKA experiment with a {@code WARAM} classifier configured from the command line.
 * Exceptions are not propagated: their stack trace and message go to standard error.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final MultiLabelClassifier classifier = (MultiLabelClassifier) new WARAM();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
 
Example #10
Source file: ARAMNetworkSparseHT.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Runs a MEKA experiment with a {@code WvARAM} classifier configured from the command line.
 * Exceptions are not propagated: their stack trace and message go to standard error.
 * <br>
 * NOTE(review): this runs WvARAM, not this file's own class. Presumably intentional if this
 * network is only used inside WvARAM; confirm.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final MultiLabelClassifier classifier = (MultiLabelClassifier) new WvARAM();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
 
Example #11
Source file: ARAMNetworkSparseHT_Strange.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Runs a MEKA experiment with a {@code WvARAM} classifier configured from the command line.
 * Exceptions are not propagated: their stack trace and message go to standard error.
 * <br>
 * NOTE(review): this runs WvARAM, not this file's own class. Presumably intentional if this
 * network is only used inside WvARAM; confirm.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final WvARAM classifier = new WvARAM();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
 
Example #12
Source file: ARAMNetworkSparseV.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Runs a MEKA experiment with a {@code WvARAM} classifier configured from the command line.
 * Exceptions are not propagated: their stack trace and message go to standard error.
 * <br>
 * NOTE(review): this runs WvARAM, not this file's own class. Presumably intentional if this
 * network is only used inside WvARAM; confirm.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final MultiLabelClassifier classifier = (MultiLabelClassifier) new WvARAM();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
 
Example #13
Source file: ARAMNetworkSparse.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Runs a MEKA experiment with a {@code WvARAM} classifier configured from the command line.
 * Exceptions are not propagated: their stack trace and message go to standard error.
 * <br>
 * NOTE(review): this runs WvARAM, not this file's own class. Presumably intentional if this
 * network is only used inside WvARAM; confirm.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final MultiLabelClassifier classifier = (MultiLabelClassifier) new WvARAM();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
 
Example #14
Source file: TrainTestSet.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Trains a default BR classifier on one dataset and evaluates it on a second one
 * that must have a compatible header. The evaluation result is printed.
 *
 * @param args exactly two arguments: {@code <train> <test>} dataset paths
 * @throws IllegalArgumentException if the argument count is not two
 * @throws IllegalStateException if the two datasets' headers are not compatible
 * @throws Exception if loading, training or evaluating fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 2) {
    throw new IllegalArgumentException("Required arguments: <train> <test>");
  }
  final String trainPath = args[0];
  final String testPath = args[1];

  System.out.println("Loading train: " + trainPath);
  final Instances trainData = DataSource.read(trainPath);
  MLUtils.prepareData(trainData);

  System.out.println("Loading test: " + testPath);
  final Instances testData = DataSource.read(testPath);
  MLUtils.prepareData(testData);

  // A non-null message describes why the two headers differ.
  final String headerMismatch = trainData.equalHeadersMsg(testData);
  if (headerMismatch != null) {
    throw new IllegalStateException(headerMismatch);
  }

  System.out.println("Build BR classifier on " + trainPath);
  final BR br = new BR();
  // further configuration of classifier
  br.buildClassifier(trainData);

  System.out.println("Evaluate BR classifier on " + testPath);
  final Result evaluation = Evaluation.evaluateModel(br, trainData, testData, "PCut1", "3");
  System.out.println(evaluation);
}
 
Example #15
Source file: WvARAM.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Main method for testing this class: runs a MEKA experiment with a {@code WvARAM}
 * classifier configured from the given options.
 * <br>
 * Exceptions are not propagated: their stack trace and message go to standard error.
 * "Done" is printed afterwards in either case.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final WvARAM classifier = new WvARAM();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
  System.out.println("Done");
}
 
Example #16
Source file: ARAMNetworkSparseH.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Runs a MEKA experiment with a {@code WvARAM} classifier configured from the command line.
 * Exceptions are not propagated: their stack trace and message go to standard error.
 * <br>
 * NOTE(review): this runs WvARAM, not this file's own class. Presumably intentional if this
 * network is only used inside WvARAM; confirm.
 *
 * @param argv the options handed to {@code Evaluation.runExperiment}
 */
public static void main(String [] argv) {
  try {
    final WvARAM classifier = new WvARAM();
    Evaluation.runExperiment(classifier, argv);
  }
  catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
 
Example #17
Source file: SCC.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Test classifier h, on dataset D, under super-class partition 'partition'.
 * <br>
 * Trains {@code m_Classifier} on {@code D_train} under {@code partition}, then evaluates {@code h}
 * on {@code D_test}. The returned Result is annotated with the result type, the PCut1 threshold
 * (ML case only), train/test sizes, label cardinalities, and the classifier and dataset names.
 * <br>
 * NOTE(review): the model trained here is {@code m_Classifier}, but the model tested is {@code h}.
 * The evaluation only reflects this training if callers pass {@code m_Classifier} (or something
 * sharing its state) as {@code h}; confirm against the callers.
 * <br>
 * TODO should be able to use something out of meka.classifiers.Evaluation instead of all this ...
 *
 * @param h the classifier to evaluate; it is cast to ProblemTransformationMethod, so other
 *          types fail with a ClassCastException
 * @param D_train the training data, also used for the threshold and the *_train statistics
 * @param D_test the test data
 * @param partition the super-class partition used when training {@code m_Classifier}
 * @return the annotated result, whose {@code output} is set from {@code Result.getStats(result, "1")}
 * @throws Exception if training or testing fails
 */
public Result testClassifier(Classifier h, Instances D_train, Instances D_test, int partition[][]) throws Exception {

	trainClassifier(m_Classifier,D_train,partition);

	Result result = Evaluation.testClassifier((ProblemTransformationMethod)h, D_test);

	// Multi-target if either the classifier or the test data says so. Since h was already cast
	// to ProblemTransformationMethod above, the else-if branch is taken whenever this check fails.
	if (h instanceof MultiTargetClassifier || Evaluation.isMT(D_test)) {
		result.setInfo("Type","MT");
	}
	else if (h instanceof ProblemTransformationMethod) {
		// Threshold is derived from the training data using the "PCut1" option.
		result.setInfo("Threshold", MLEvalUtils.getThreshold(result.predictions, D_train, "PCut1"));
		result.setInfo("Type","ML");
	}

	result.setValue("N_train",D_train.numInstances());
	result.setValue("N_test",D_test.numInstances());
	result.setValue("LCard_train",MLUtils.labelCardinality(D_train));
	result.setValue("LCard_test",MLUtils.labelCardinality(D_test));

	// Timing values are not recorded: this method takes no build/test timestamps.
	//result.setValue("Build_time",(after - before)/1000.0);
	//result.setValue("Test_time",(after_test - before_test)/1000.0);
	//result.setValue("Total_time",(after_test - before)/1000.0);

	result.setInfo("Classifier_name",h.getClass().getName());
	//result.setInfo("Classifier_ops", Arrays.toString(h.getOptions()));
	result.setInfo("Classifier_info",h.toString());
	result.setInfo("Dataset_name",MLUtils.getDatasetName(D_test));

	result.output = Result.getStats(result,"1");
	return result;
}
 
Example #18
Source file: PrecisionRecall.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Cross-validates (10 folds) a default BR classifier on the given dataset and shows the
 * per-label precision-recall curve data ({@code CURVE_DATA}) in a window, one tab per label.
 *
 * @param args exactly one argument: the path of the dataset to load
 * @throws IllegalArgumentException if not exactly one argument is given
 * @throws Exception if loading the data or cross-validating fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 1)
    throw new IllegalArgumentException("Required arguments: <dataset>");

  System.out.println("Loading data: " + args[0]);
  Instances data = DataSource.read(args[0]);
  MLUtils.prepareData(data);

  System.out.println("Cross-validate BR classifier");
  BR classifier = new BR();
  // further configuration of classifier
  String top = "PCut1";
  String vop = "3";
  Result result = Evaluation.cvModel(classifier, data, 10, top, vop);

  // Index i of this array is shown as "Label i".
  final Instances[] curves = (Instances[]) result.getMeasurement(CURVE_DATA);

  // Swing components must only be created and shown on the event dispatch thread.
  javax.swing.SwingUtilities.invokeLater(new Runnable() {
    @Override
    public void run() {
      JFrame frame = new JFrame("Precision-recall");
      frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
      frame.getContentPane().setLayout(new BorderLayout());
      JTabbedPane tabbed = new JTabbedPane();
      frame.getContentPane().add(tabbed, BorderLayout.CENTER);
      for (int i = 0; i < curves.length; i++) {
        try {
          ThresholdVisualizePanel panel = createPanel(curves[i], "Label " + i);
          tabbed.addTab("" + i, panel);
        }
        catch (Exception ex) {
          // A failing label is reported and skipped; the other tabs are still shown.
          System.err.println("Failed to create plot for label " + i);
          ex.printStackTrace();
        }
      }
      frame.setSize(800, 600);
      frame.setLocationRelativeTo(null);
      frame.setVisible(true);
    }
  });
}
 
Example #19
Source file: MacroCurve.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Cross-validates (10 folds) a default BR classifier on the given dataset and shows
 * the evaluation's macro curve data ({@code CURVE_DATA_MACRO}) in a window.
 *
 * @param args exactly one argument: the path of the dataset to load
 * @throws IllegalArgumentException if not exactly one argument is given
 * @throws Exception if loading the data or cross-validating fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 1)
    throw new IllegalArgumentException("Required arguments: <dataset>");

  System.out.println("Loading data: " + args[0]);
  Instances data = DataSource.read(args[0]);
  MLUtils.prepareData(data);

  System.out.println("Cross-validate BR classifier");
  BR classifier = new BR();
  // further configuration of classifier
  String top = "PCut1";
  String vop = "3";
  Result result = Evaluation.cvModel(classifier, data, 10, top, vop);

  final Instances performance = (Instances) result.getMeasurement(CURVE_DATA_MACRO);

  // Swing components must only be created and shown on the event dispatch thread.
  javax.swing.SwingUtilities.invokeLater(new Runnable() {
    @Override
    public void run() {
      JFrame frame = new JFrame("Macro curve");
      frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
      frame.getContentPane().setLayout(new BorderLayout());
      try {
        VisualizePanel panel = createPanel(performance);
        frame.getContentPane().add(panel, BorderLayout.CENTER);
      }
      catch (Exception ex) {
        // Still show the (empty) frame so the user sees that plotting failed.
        System.err.println("Failed to create plot!");
        ex.printStackTrace();
      }
      frame.setSize(800, 600);
      frame.setLocationRelativeTo(null);
      frame.setVisible(true);
    }
  });
}
 
Example #20
Source file: ROC.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Cross-validates (10 folds) a default BR classifier on the given dataset and shows the
 * per-label ROC curve data ({@code CURVE_DATA}) in a window, one tab per label.
 *
 * @param args exactly one argument: the path of the dataset to load
 * @throws IllegalArgumentException if not exactly one argument is given
 * @throws Exception if loading the data or cross-validating fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 1)
    throw new IllegalArgumentException("Required arguments: <dataset>");

  System.out.println("Loading data: " + args[0]);
  Instances data = DataSource.read(args[0]);
  MLUtils.prepareData(data);

  System.out.println("Cross-validate BR classifier");
  BR classifier = new BR();
  // further configuration of classifier
  String top = "PCut1";
  String vop = "3";
  Result result = Evaluation.cvModel(classifier, data, 10, top, vop);

  // Index i of this array is shown as "Label i".
  final Instances[] curves = (Instances[]) result.getMeasurement(CURVE_DATA);

  // Swing components must only be created and shown on the event dispatch thread.
  javax.swing.SwingUtilities.invokeLater(new Runnable() {
    @Override
    public void run() {
      JFrame frame = new JFrame("ROC");
      frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
      frame.getContentPane().setLayout(new BorderLayout());
      JTabbedPane tabbed = new JTabbedPane();
      frame.getContentPane().add(tabbed, BorderLayout.CENTER);
      for (int i = 0; i < curves.length; i++) {
        try {
          ThresholdVisualizePanel panel = createPanel(curves[i], "Label " + i);
          tabbed.addTab("" + i, panel);
        }
        catch (Exception ex) {
          // A failing label is reported and skipped; the other tabs are still shown.
          System.err.println("Failed to create plot for label " + i);
          ex.printStackTrace();
        }
      }
      frame.setSize(800, 600);
      frame.setLocationRelativeTo(null);
      frame.setVisible(true);
    }
  });
}
 
Example #21
Source file: ExportPredictionsOnTestSet.java (from meka, GNU General Public License v3.0) — 5 votes
/**
 * Trains a default BR classifier on a training set and evaluates it on a header-compatible
 * test set. Prints the result and writes the predictions, converted to Instances via
 * {@code Result.getPredictionsAsInstances}, to the output file.
 *
 * @param args exactly three arguments: {@code <train> <test> <output>}
 * @throws IllegalArgumentException if the argument count is not three
 * @throws IllegalStateException if the train and test headers are not compatible
 * @throws Exception if loading, training, evaluating or writing fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 3) {
    throw new IllegalArgumentException("Required arguments: <train> <test> <output>");
  }
  final String trainPath = args[0];
  final String testPath = args[1];
  final String outputPath = args[2];

  System.out.println("Loading train: " + trainPath);
  final Instances trainData = DataSource.read(trainPath);
  MLUtils.prepareData(trainData);

  System.out.println("Loading test: " + testPath);
  final Instances testData = DataSource.read(testPath);
  MLUtils.prepareData(testData);

  // A non-null message describes why the two headers differ.
  final String headerMismatch = trainData.equalHeadersMsg(testData);
  if (headerMismatch != null) {
    throw new IllegalStateException(headerMismatch);
  }

  System.out.println("Build BR classifier on " + trainPath);
  final BR br = new BR();
  // further configuration of classifier
  br.buildClassifier(trainData);

  System.out.println("Evaluate BR classifier on " + testPath);
  final Result evaluation = Evaluation.evaluateModel(br, trainData, testData, "PCut1", "3");
  System.out.println(evaluation);

  System.out.println("Saving predictions test set to " + outputPath);
  DataSink.write(outputPath, Result.getPredictionsAsInstances(evaluation));
}