/*
 * Copyright (c) 2015-2016 The University Of Sheffield.
 *
 * This file is part of gateplugin-LearningFramework
 * (see https://github.com/GateNLP/gateplugin-LearningFramework).
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 2.1 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this software. If not, see <http://www.gnu.org/licenses/>.
 */
package gate.plugin.learningframework.engines;

import cc.mallet.classify.BalancedWinnowTrainer;
import cc.mallet.classify.C45Trainer;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.DecisionTreeTrainer;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.WinnowTrainer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.InstanceList.CrossValidationIterator;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import gate.Annotation;
import gate.AnnotationSet;
import gate.plugin.learningframework.EvaluationMethod;
import gate.plugin.learningframework.ModelApplication;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.GateRuntimeException;
import static gate.plugin.learningframework.LFUtils.newURL;
import static gate.plugin.learningframework.engines.Engine.FILENAME_MODEL;
import java.io.File;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;
import org.apache.log4j.Logger;

/**
 * Engine for training and applying Mallet classification algorithms.
 *
 * @author Johann Petrak
 */
public class EngineMBMalletClass extends EngineMBMallet {

  private static final Logger LOGGER = Logger.getLogger(EngineMBMalletClass.class);

  public EngineMBMalletClass() { }

  @Override
  public void trainModel(File dataDirectory, String instanceType, String parms) {
    //System.err.println("EngineMalletClass.trainModel: trainer="+trainer);
    //System.err.println("EngineMalletClass.trainModel: CR="+corpusRepresentation);
    model = ((ClassifierTrainer) trainer).train(corpusRepresentation.getRepresentationMallet());
    updateInfo();
    SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
    info.modelWhenTrained = sdf.format(new Date());
    info.algorithmParameters = parms;
    info.save(dataDirectory);
    featureInfo.save(dataDirectory);
  }

  @Override
  public List<ModelApplication> applyModel(
          AnnotationSet instanceAS, AnnotationSet inputAS, AnnotationSet sequenceAS, String parms) {
    // NOTE: the corpus representation must be a CorpusRepresentationMalletTarget for this to work!
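    // The trained classifier can only consume instances that went through the
    // same feature pipeline as the training data, so each instance extracted
    // below is passed through the stored LFPipe before it is classified.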
    if (!(corpusRepresentation instanceof CorpusRepresentationMalletTarget)) {
      throw new GateRuntimeException(
              "Cannot perform classification with data from " + corpusRepresentation.getClass());
    }
    CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget) corpusRepresentation;
    data.stopGrowth();
    List<ModelApplication> gcs = new ArrayList<>();
    LFPipe pipe = (LFPipe) data.getRepresentationMallet().getPipe();
    Classifier classifier = (Classifier) model;
    // iterate over the instance annotations and create mallet instances
    for (Annotation instAnn : instanceAS.inDocumentOrder()) {
      Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
      inst = pipe.instanceFrom(inst);
      Classification classification = classifier.classify(inst);
      Labeling labeling = classification.getLabeling();
      LabelVector labelvec = labeling.toLabelVector();
      List<String> classes = new ArrayList<>(labelvec.numLocations());
      List<Double> confidences = new ArrayList<>(labelvec.numLocations());
      for (int i = 0; i < labelvec.numLocations(); i++) {
        classes.add(labelvec.getLabelAtRank(i).toString());
        confidences.add(labelvec.getValueAtRank(i));
      }
      ModelApplication gc = new ModelApplication(
              instAnn, labeling.getBestLabel().toString(), labeling.getBestValue(),
              classes, confidences);
      //System.err.println("ADDING GC "+gc);
      // now save the class in our special class feature on the instance as well
      instAnn.getFeatures().put("gate.LF.target", labeling.getBestLabel().toString());
      gcs.add(gc);
    }
    data.startGrowth();
    return gcs;
  }

  @Override
  public void initializeAlgorithm(Algorithm algorithm, String parms) {
    // If this is one of the algorithms where we need to deal with parameters in
    // some way, use the non-empty constructor, otherwise just instantiate the
    // trainer class. But only bother if we have a parameter at all.
    if (parms == null || parms.trim().isEmpty()) {
      // no parameters, just instantiate the class
      Class<?> trainerClass = algorithm.getTrainerClass();
      try {
        Constructor<?> tmpc = trainerClass.getDeclaredConstructor();
        trainer = tmpc.newInstance();
      } catch (IllegalAccessException | IllegalArgumentException | InstantiationException
              | NoSuchMethodException | SecurityException | InvocationTargetException ex) {
        throw new GateRuntimeException("Could not create trainer instance for " + trainerClass, ex);
      }
    } else {
      // There are parameters, so if it is one of the algorithms where we support
      // setting a parameter, do so.
      if (algorithm.equals(AlgorithmClassification.MalletC45_CL_MR)) {
        Parms ps = new Parms(parms, "m:maxDepth:i", "p:prune:B", "n:minNumInsts:i");
        int maxDepth = (int) ps.getValueOrElse("maxDepth", 0);
        int minNumInsts = (int) ps.getValueOrElse("minNumInsts", 2);
        boolean prune = (boolean) ps.getValueOrElse("prune", true);
        C45Trainer c45trainer;
        if (maxDepth > 0) {
          c45trainer = new C45Trainer(maxDepth, prune);
        } else {
          c45trainer = new C45Trainer(prune);
        }
        c45trainer.setMinNumInsts(minNumInsts);
        trainer = c45trainer;
      } else if (algorithm.equals(AlgorithmClassification.MalletDecisionTree_CL_MR)) {
        DecisionTreeTrainer dtTrainer = new DecisionTreeTrainer();
        Parms ps = new Parms(parms, "m:maxDepth:i", "i:minInfoGainSplit:d");
        int maxDepth = (int) ps.getValueOrElse("maxDepth", DecisionTreeTrainer.DEFAULT_MAX_DEPTH);
        double minIGS = (double) ps.getValueOrElse("minInfoGainSplit",
                DecisionTreeTrainer.DEFAULT_MIN_INFO_GAIN_SPLIT);
        dtTrainer.setMaxDepth(maxDepth);
        dtTrainer.setMinInfoGainSplit(minIGS);
        trainer = dtTrainer;
      }
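      // The Parms specs above follow the pattern shortFlag:name:type with
      // i = int, d = double, B = boolean; a C45 parameter string such as
      // "-m 10 -p false -n 5" (an illustrative example, not taken from the
      // original sources) would set maxDepth=10, prune=false and minNumInsts=5.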
      else if (algorithm.equals(AlgorithmClassification.MalletMaxEnt_CL_MR)) {
        MaxEntTrainer tr = new MaxEntTrainer();
        Parms ps = new Parms(parms,
                "v:gaussianPriorVariance:d", "l:l1Weight:d", "i:numIterations:i");
        // TODO: the default values cannot be taken from MaxEntTrainer because
        // they are not public there
        double gaussianPriorVariance = (double) ps.getValueOrElse("gaussianPriorVariance", 1.0);
        tr.setGaussianPriorVariance(gaussianPriorVariance);
        double l1Weight = (double) ps.getValueOrElse("l1Weight", 0.0);
        tr.setL1Weight(l1Weight);
        int iters = (int) ps.getValueOrElse("numIterations", Integer.MAX_VALUE);
        tr.setNumIterations(iters);
        trainer = tr;
        // NOTE: for AdaBoost, use this method recursively to first initialize
        // the trainer with the base trainer. The parameters should contain
        // something like -A ALGNAME -N numRounds -a -b ...
        // where ALGNAME is an AlgorithmClassification constant, N is the
        // numRounds parameter for AdaBoost[M2], and all the other parameters
        // are for the base algorithm initialization.
      } else if (algorithm.equals(AlgorithmClassification.MalletBalancedWinnow_CL_MR)) {
        Parms ps = new Parms(parms,
                "e:epsilon:d", "d:delta:d", "i:maxIterations:i", "c:coolingRate:d");
        double epsilon = (double) ps.getValueOrElse("epsilon", BalancedWinnowTrainer.DEFAULT_EPSILON);
        double delta = (double) ps.getValueOrElse("delta", BalancedWinnowTrainer.DEFAULT_DELTA);
        int iters = (int) ps.getValueOrElse("maxIterations", BalancedWinnowTrainer.DEFAULT_MAX_ITERATIONS);
        double cr = (double) ps.getValueOrElse("coolingRate", BalancedWinnowTrainer.DEFAULT_COOLING_RATE);
        trainer = new BalancedWinnowTrainer(epsilon, delta, iters, cr);
      } else if (algorithm.equals(AlgorithmClassification.MalletWinnow_CL_MR)) {
        Parms ps = new Parms(parms, "a:alpha:d", "b:beta:d", "n:nfact:d");
        double alpha = (double) ps.getValueOrElse("alpha", 2.0);
        double beta = (double) ps.getValueOrElse("beta", 2.0);
        double nfact = (double) ps.getValueOrElse("nfact", 0.5);
        trainer = new WinnowTrainer(alpha, beta, nfact);
      } else {
        // All other algorithms are still just instantiated from the class name;
        // we ignore the parameters.
        LOGGER.warn("IMPORTANT: parameters ignored when creating Mallet trainer "
                + algorithm.getTrainerClass());
        Class<?> trainerClass = algorithm.getTrainerClass();
        try {
          Constructor<?> tmpc = trainerClass.getDeclaredConstructor();
          trainer = tmpc.newInstance();
        } catch (IllegalAccessException | IllegalArgumentException | InstantiationException
                | NoSuchMethodException | SecurityException | InvocationTargetException ex) {
          throw new GateRuntimeException("Could not create trainer instance for " + trainerClass, ex);
        }
      }
    }
  }

  @Override
  protected void loadModel(URL directory, String parms) {
    URL modelFile = newURL(directory, FILENAME_MODEL);
    Classifier classifier;
    try (InputStream is = modelFile.openStream();
            ObjectInputStream ois = new ObjectInputStream(is)) {
      classifier = (Classifier) ois.readObject();
      model = classifier;
    } catch (Exception ex) {
      throw new GateRuntimeException("Could not load Mallet model", ex);
    }
  }

  @Override
  public EvaluationResult evaluate(String algorithmParameters, EvaluationMethod evaluationMethod,
          int numberOfFolds, double trainingFraction, int numberOfRepeats) {
    EvaluationResult ret = null;
    Parms parms = new Parms(algorithmParameters, "s:seed:i");
    int seed = (Integer) parms.getValueOrElse("seed", 1);
    if (evaluationMethod == EvaluationMethod.CROSSVALIDATION) {
      CrossValidationIterator cvi = corpusRepresentation.getRepresentationMallet()
              .crossValidationIterator(numberOfFolds, seed);
      if (algorithm instanceof AlgorithmClassification) {
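        // The estimate is the mean of the per-fold accuracies: each split trains
        // a fresh classifier on the training fold and scores it on the held-out fold.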
        double sumOfAccs = 0.0;
        while (cvi.hasNext()) {
          InstanceList[] il = cvi.nextSplit();
          InstanceList trainSet = il[0];
          InstanceList testSet = il[1];
          Classifier cl = ((ClassifierTrainer) trainer).train(trainSet);
          sumOfAccs += cl.getAccuracy(testSet);
        }
        EvaluationResultClXval e = new EvaluationResultClXval();
        //e.internalEvaluationResult = null;
        e.accuracyEstimate = sumOfAccs / numberOfFolds;
        e.nrFolds = numberOfFolds;
        ret = e;
      } else {
        throw new GateRuntimeException("Mallet evaluation: not available for regression!");
      }
    } else {
      if (algorithm instanceof AlgorithmClassification) {
        Random rnd = new Random(seed);
        double sumOfAccs = 0.0;
        for (int i = 0; i < numberOfRepeats; i++) {
          InstanceList[] sets = corpusRepresentation.getRepresentationMallet()
                  .split(rnd, new double[]{trainingFraction, 1 - trainingFraction});
          Classifier cl = ((ClassifierTrainer) trainer).train(sets[0]);
          sumOfAccs += cl.getAccuracy(sets[1]);
        }
        EvaluationResultClHO e = new EvaluationResultClHO();
        //e.internalEvaluationResult = null;
        e.accuracyEstimate = sumOfAccs / numberOfRepeats;
        e.trainingFraction = trainingFraction;
        e.nrRepeats = numberOfRepeats;
        ret = e;
      } else {
        throw new GateRuntimeException("Mallet evaluation: not available for regression!");
      }
    }
    return ret;
  }

}