weka.classifiers.trees.M5P Java Examples

The following examples show how to use weka.classifiers.trees.M5P. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: MultiResponseModelTrees.java    From tsml with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Trains one M5P regressor per class value of {@code data}.
 *
 * <p>A copy of the data is made in which the last attribute is removed and a
 * numeric attribute named "newClassVal" is appended and used as the class.
 * For every class value {@code c} the numeric class of each instance is set
 * to 1 when the instance's original class value equals {@code c} and to 0
 * otherwise, and a fresh M5P is built on that relabelled copy.
 *
 * <p>NOTE(review): removing the last attribute presumably drops the original
 * class attribute; this relies on the class being the last attribute of
 * {@code data} - confirm against callers.
 *
 * @param data training instances with a class attribute set
 * @throws Exception if any of the underlying M5P models fails to build
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    // Copy of the data whose last attribute is replaced by a numeric class attribute.
    numericClassInsts = new Instances(data);
    numericClassInsts.setClassIndex(0); // temporary class index while the last attribute is removed
    final int lastAttIndex = numericClassInsts.numAttributes() - 1;
    numericClassInsts.deleteAttributeAt(lastAttIndex);
    numericClassInsts.insertAttributeAt(new Attribute("newClassVal"), numericClassInsts.numAttributes());
    numericClassInsts.setClassIndex(numericClassInsts.numAttributes() - 1);

    // One regressor per original class value.
    final int classCount = data.numClasses();
    regressors = new M5P[classCount];
    final double[] originalClassVals = data.attributeToDoubleArray(data.classIndex());
    final int instanceCount = numericClassInsts.numInstances();

    for (int cls = 0; cls < classCount; cls++) {
        // Relabel: 1 for instances of the class being handled, 0 for all others.
        for (int idx = 0; idx < instanceCount; idx++) {
            double indicator;
            if (originalClassVals[idx] == cls) {
                indicator = 1;
            } else {
                indicator = 0;
            }
            numericClassInsts.instance(idx).setClassValue(indicator);
        }

        regressors[cls] = new M5P();
        regressors[cls].buildClassifier(numericClassInsts);
    }
}
 
Example #2
Source File: COMT2.java    From bestconf with Apache License 2.0 5 votes vote down vote up
/**
 * Collects, for every leaf of the given M5P tree, a {@code Branch2} holding
 * the leaf's linear model and the per-attribute value ranges implied by the
 * split decisions on the path from that leaf up to the root.
 *
 * <p>Taking the left child of a split contributes {@code Range.atMost(splitVal)},
 * taking the other child contributes {@code Range.greaterThan(splitVal)}. When
 * an attribute is split on more than once along the path, the ranges are
 * intersected.
 *
 * @param model a built M5P tree
 * @return one branch description per leaf, in the order the leaves are returned by the tree
 */
private ArrayList<Branch2> getLeavesInfoForM5P(M5P model){
	ArrayList<RuleNode> leaves = new ArrayList<RuleNode>();
	model.getM5RootNode().returnLeaves(new ArrayList[]{leaves});

	ArrayList<Branch2> branches = new ArrayList<Branch2>();
	for(RuleNode leaf : leaves){
		Branch2 branch = new Branch2();
		ArrayList<PreConstructedLinearModel> leafModels = new ArrayList<PreConstructedLinearModel>();
		leafModels.add(leaf.getModel());
		branch.setLinearModels(leafModels);

		Map<Attribute,Range<Double>> ranges = branch.getRangeMap();
		// Climb from the leaf to the root, turning each split into a range constraint.
		for(RuleNode below = leaf, above = leaf.parentNode(); above!=null; below = above, above = above.parentNode()){
			Attribute att = this.labeledInstances.attribute(above.splitAtt());

			Range<Double> constraint;
			if(above.leftNode()==below){
				constraint = Range.atMost(above.splitVal());
			}else{
				constraint = Range.greaterThan(above.splitVal());
			}

			Range<Double> earlier = ranges.put(att, constraint);
			// The attribute was already constrained closer to the leaf: keep the overlap.
			if(earlier!=null){
				ranges.put(att, constraint.intersection(earlier));
			}
		}

		branches.add(branch);
	}

	return branches;
}
 
Example #3
Source File: COMT2.java    From bestconf with Apache License 2.0 5 votes vote down vote up
/**
 * Builds an M5P model on the given instances with instance saving enabled
 * and the option string {@code "-N -L -M <numOfInstanceInLeaf>"}.
 *
 * <p>NOTE(review): per the Weka M5P option list, -N presumably requests an
 * unpruned tree, -L saves instances at the nodes and -M sets the minimum
 * number of instances per leaf - confirm against the Weka version in use.
 *
 * @param modelInstances      training data
 * @param numOfInstanceInLeaf value passed to the -M option
 * @return the built model
 * @throws Exception if option parsing or model building fails
 */
private static M5P buildModel(Instances modelInstances, int numOfInstanceInLeaf) throws Exception{
	M5P tree = new M5P();
	tree.setSaveInstances(true);

	String[] treeOptions = Utils.splitOptions("-N -L -M "+numOfInstanceInLeaf);
	tree.setOptions(treeOptions);

	tree.buildClassifier(modelInstances);
	return tree;
}
 
Example #4
Source File: COMT2.java    From bestconf with Apache License 2.0 5 votes vote down vote up
/**
 * Routes an instance down the model tree and returns the training set stored
 * at the leaf it ends up in.
 *
 * <p>At each inner node the instance goes to the left child when its value
 * for the node's split attribute is less than or equal to the split value,
 * and to the right child otherwise.
 *
 * @param modelTree a built M5P tree
 * @param ins       the instance to route
 * @return the training instances kept at the reached leaf
 */
private static Instances getSiblings(M5P modelTree, Instance ins){
	RuleNode current = modelTree.getM5RootNode();

	while(true){
		if(current.isLeaf()){
			return current.zyqGetTrainingSet();
		}
		boolean goesLeft = ins.value(current.splitAtt())<=current.splitVal();
		current = goesLeft ? current.leftNode() : current.rightNode();
	}
}
 
Example #5
Source File: COMT2.java    From bestconf with Apache License 2.0 5 votes vote down vote up
/**
 * Sums, over all instances in {@code omega}, the squared error of
 * {@code model} minus the squared error of {@code modelPi}, where the error
 * is measured against the instance's class value.
 *
 * <p>A positive result means {@code modelPi} has the lower total squared
 * error on {@code omega}.
 *
 * @param model   the reference model
 * @param modelPi the candidate model compared against the reference
 * @param omega   the instances to evaluate on; their class index is used as the target
 * @return the accumulated difference of squared errors
 * @throws Exception if either model fails to classify an instance
 */
private static double computeOmegaDelta(M5P model, M5P modelPi, Instances omega) throws Exception{
	final int classIdx = omega.classIndex();
	double total = 0.;

	for(Enumeration<Instance> it = omega.enumerateInstances(); it.hasMoreElements();){
		Instance current = it.nextElement();
		double actual = current.value(classIdx);
		double errModel = Math.pow(actual-model.classifyInstance(current), 2);
		double errModelPi = Math.pow(actual-modelPi.classifyInstance(current), 2);
		total += errModel-errModelPi;
	}

	return total;
}
 
Example #6
Source File: COMT2.java    From bestconf with Apache License 2.0 4 votes vote down vote up
/**
 * Trains {@code ModelNum} M5P models and refines them for up to
 * {@code comtIterations} rounds using the unlabeled pool.
 *
 * <p>Each model {@code i} is first built from {@code labeledInstances} with
 * {@code M[i]} passed to {@code buildModel}. In every round, each model scans
 * the unlabeled pool, labels each candidate with its own prediction, and
 * picks the candidate whose addition gives the largest positive value of
 * {@code computeOmegaDelta} on the candidate's leaf training set. Picked
 * candidates are removed from the pool, added to the picking model's root
 * training set, and the model is rebuilt. The loop stops early when no model
 * picked a candidate in a round.
 *
 * <p>Side effects visible in this method: the class value of pool instances
 * is overwritten in place, {@code unlabeledInstances} shrinks and is
 * replenished via {@code retrieveMore}, and {@code indexOfClass} is lazily
 * initialised from {@code labeledInstances}.
 *
 * @throws Exception if building or applying any model fails
 */
private void train() throws Exception{
	// initial models: one per entry of M, all built from the labeled data
	models = new M5P[ModelNum];
	for(int i=0;i<ModelNum;i++){
		models[i] = buildModel(labeledInstances, M[i]);
	}
	
	for(int i=0;i<this.comtIterations;i++){
		// per-model list of the instances picked in this round
		ArrayList<Instance>[] InstancePiSet = new ArrayList[ModelNum];
		for(int j=0;j<ModelNum;j++)
			InstancePiSet[j] = new ArrayList<Instance>();
		
		for(int m=0;m<ModelNum;m++){
			// best candidate so far for model m; only strictly positive deltas qualify
			double maxDelta = 0;
			Instance maxDeltaXY = null;
			Enumeration<Instance> enu = this.unlabeledInstances.enumerateInstances();
			
			while(enu.hasMoreElements()){
				Instance ulIns = enu.nextElement();
				// training instances stored at the leaf this candidate falls into
				Instances omega = getSiblings(models[m], ulIns);
				double y = models[m].classifyInstance(ulIns);
				// resolve the class index once, on first use
				if(indexOfClass==-1)
					indexOfClass = labeledInstances.classIndex();
				// writes the predicted label into the pool instance itself (in-place mutation)
				ulIns.setValue(indexOfClass, y);
				
				// tentative model: copy of the root training set plus the newly labeled candidate
				Instances instancesPi = new Instances(models[m].getM5RootNode().zyqGetTrainingSet());
				instancesPi.add(ulIns);
				M5P modelPi = buildModel(instancesPi, M[m]);
				// squared-error difference between current and tentative model on the leaf's instances
				double delta = computeOmegaDelta(models[m],modelPi,omega);
				if(maxDelta<delta){
					maxDelta = delta;
					maxDeltaXY = ulIns;
				}
			}
			
			//now check facts about delta
			// a positive delta means some candidate qualified: record it and take it out of the pool
			if(maxDelta>0){
				InstancePiSet[m].add(maxDeltaXY);
				this.unlabeledInstances.delete(this.unlabeledInstances.indexOf(maxDeltaXY));
			}
		}//check for both model
		
		// stop when no model picked anything in this round
		boolean toExit = true;
		for(int m=0;m<ModelNum;m++){
			if(InstancePiSet[m].size()>0){
				toExit = false;
				break;
			}
		}
		
		if(toExit)
			break;
		else{
			//update the models
			// toGen counts how many instances left the pool in this round
			int toGen = 0;
			for(int m=0;m<ModelNum;m++){
				// NOTE(review): this is the set returned by the root node, not a copy made here,
				// so the additions below modify whatever object the node hands out - confirm intended
				Instances set = models[m].getM5RootNode().zyqGetTrainingSet();
				toGen += InstancePiSet[m].size();
				for(Instance ins : InstancePiSet[m])
					set.add(ins);
				
				// each model is rebuilt from its own training set plus its own picks
				models[m] = buildModel(set, M[m]);
			}
			
			//Replenish pool U' to size p
			Instances toAdd = retrieveMore(toGen);
			unlabeledInstances.addAll(toAdd);
		}//we will go to another round of iteration
	}//iterate for a number of rounds or break out on empty InstancesPiSets
	
	//now we have the model as y = 0.5*sum(models[m].predict(x))
}
 
Example #7
Source File: RegressionTask.java    From Machine-Learning-in-Java with MIT License 4 votes vote down vote up
/**
 * Loads the ENB2012 data set from {@code data/ENB2012_data.csv}, uses the
 * second-to-last attribute (Y1, heating load) as the regression target,
 * drops the last attribute (Y2), and then trains and 10-fold cross-validates
 * a LinearRegression model and an M5P model tree, printing each model and
 * its evaluation summary to standard output.
 *
 * <p>Each model is evaluated with its own {@code Evaluation} instance so that
 * the printed summaries are not mixed: an {@code Evaluation} keeps
 * accumulating statistics across successive {@code crossValidateModel} calls.
 *
 * @param args command-line arguments (not used by the active code)
 * @throws Exception if the data cannot be loaded or a model fails to build or evaluate
 */
public static void main(String[] args) throws Exception {

		/*
		 * Load data
		 */
		CSVLoader loader = new CSVLoader();
		loader.setFieldSeparator(",");
		loader.setSource(new File("data/ENB2012_data.csv"));
		Instances data = loader.getDataSet();

		// System.out.println(data);

		/*
		 * Build regression models
		 */
		// set class index to Y1 (heating load)
		data.setClassIndex(data.numAttributes() - 2);
		// remove last attribute Y2 (the -R range is 1-based, so numAttributes() addresses the last one)
		Remove remove = new Remove();
		remove.setOptions(new String[] { "-R", data.numAttributes() + "" });
		remove.setInputFormat(data);
		data = Filter.useFilter(data, remove);

		// build a regression model
		LinearRegression model = new LinearRegression();
		model.buildClassifier(data);
		System.out.println(model);

		// 10-fold cross-validation
		Evaluation eval = new Evaluation(data);
		eval.crossValidateModel(model, data, 10, new Random(1), new String[] {});
		System.out.println(eval.toSummaryString());
		System.out.println();

		// build a regression tree model

		M5P m5p = new M5P();
		// no specific options requested
		m5p.setOptions(new String[] { "" });
		m5p.buildClassifier(data);
		System.out.println(m5p);

		// 10-fold cross-validation; start from a fresh Evaluation so the summary
		// reflects only the M5P predictions and not the earlier linear-regression run
		eval = new Evaluation(data);
		eval.crossValidateModel(m5p, data, 10, new Random(1), new String[] {});
		System.out.println(eval.toSummaryString());
		System.out.println();
		
		
		
		
		/*
		 * Bonus: Build additional models 
		 */
		
		// ZeroR modelZero = new ZeroR();
		//
		//
		//
		//
		//
		// REPTree modelTree = new REPTree();
		// modelTree.buildClassifier(data);
		// System.out.println(modelTree);
		// eval = new Evaluation(data);
		// eval.crossValidateModel(modelTree, data, 10, new Random(1), new
		// String[]{});
		// System.out.println(eval.toSummaryString());
		//
		// SMOreg modelSVM = new SMOreg();
		//
		// MultilayerPerceptron modelPerc = new MultilayerPerceptron();
		//
		// GaussianProcesses modelGP = new GaussianProcesses();
		// modelGP.buildClassifier(data);
		// System.out.println(modelGP);
		// eval = new Evaluation(data);
		// eval.crossValidateModel(modelGP, data, 10, new Random(1), new
		// String[]{});
		// System.out.println(eval.toSummaryString());

		/*
		 * Bonus: Save ARFF
		 */
		// ArffSaver saver = new ArffSaver();
		// saver.setInstances(data);
		// saver.setFile(new File(args[1]));
		// saver.setDestination(new File(args[1]));
		// saver.writeBatch();

	}