package tsml.classifiers.distance_based.elastic_ensemble;

import com.google.common.collect.ImmutableList;
import evaluation.storage.ClassifierResults;
import java.util.function.Consumer;

import experiments.data.DatasetLoading;
import machine_learning.classifiers.ensembles.AbstractEnsemble;
import machine_learning.classifiers.ensembles.voting.MajorityVote;
import machine_learning.classifiers.ensembles.voting.ModuleVotingScheme;
import machine_learning.classifiers.ensembles.weightings.ModuleWeightingScheme;
import machine_learning.classifiers.ensembles.weightings.TrainAcc;
import tsml.classifiers.*;
import tsml.classifiers.distance_based.knn.KNNLOOCV;
import tsml.classifiers.distance_based.knn.strategies.RLTunedKNNSetup;
import tsml.classifiers.distance_based.tuned.RLTunedClassifier;
import tsml.classifiers.distance_based.utils.MemoryWatchable;
import tsml.classifiers.distance_based.utils.checkpointing.CheckpointUtils;
import tsml.classifiers.distance_based.utils.classifier_mixins.BaseClassifier;
import tsml.classifiers.distance_based.utils.classifier_mixins.TrainEstimateable;
import tsml.classifiers.distance_based.utils.memory.GcMemoryWatchable;
import tsml.classifiers.distance_based.utils.memory.MemoryWatcher;
import tsml.classifiers.distance_based.utils.stopwatch.Stated;
import tsml.classifiers.distance_based.utils.stopwatch.StopWatch;
import tsml.classifiers.distance_based.utils.stopwatch.StopWatchTrainTimeable;
import tsml.classifiers.distance_based.utils.StrUtils;
import tsml.classifiers.distance_based.utils.classifier_building.CompileTimeClassifierBuilderFactory;
import utilities.*;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;

// todo this has likeness to RLTuner, perhaps need to unify somewhere / make this a RLTuner?

public class ElasticEnsemble extends BaseClassifier implements TrainTimeContractable, Checkpointable,
    GcMemoryWatchable, StopWatchTrainTimeable {

    public static void main(String[] args) throws Exception {
        int seed = 0;
        Instances[] data = DatasetLoading.sampleGunPoint(seed);
        ElasticEnsemble classifier = FACTORY.EE_V2.build();
        classifier.setEstimateOwnPerformance(true);
        classifier.setSeed(0);
        classifier.getLogger().setLevel(Level.ALL);
        ClassifierResults results = ClassifierTools.trainAndTest(data, classifier);
        results.setDetails(classifier, data[1]);
        ClassifierResults trainResults = ((TrainEstimateable) classifier).getTrainResults();
        trainResults.setDetails(classifier, data[0]);
        System.out.println(trainResults.writeSummaryResultsToString());
        System.out.println(results.writeSummaryResultsToString());
    }

    public static final Factory FACTORY = new Factory();

    /**
     * get whether the train estimate will be regenerated
     * @return
     */
    public boolean isRegenerateTrainEstimate() {
        return regenerateTrainEstimate;
    }

    /**
     * set whether the train estimate will be regenerated
     * @param regenerateTrainEstimate
     * @return
     */
    protected ElasticEnsemble setRegenerateTrainEstimate(boolean regenerateTrainEstimate) {
        this.regenerateTrainEstimate = regenerateTrainEstimate;
        return this;
    }

    public static class Factory extends CompileTimeClassifierBuilderFactory<ElasticEnsemble> {
        public final ClassifierBuilder<? extends ElasticEnsemble> EE_V1 = add(new SuppliedClassifierBuilder<>("EE_V1",
            Factory::buildEeV1));
        public final ClassifierBuilder<? extends ElasticEnsemble> EE_V2 = add(new SuppliedClassifierBuilder<>("EE_V2",
            Factory::buildEeV2));
        public final ClassifierBuilder<? extends ElasticEnsemble> LEE = add(new SuppliedClassifierBuilder<>("LEE",
            Factory::buildLee));


        public static ImmutableList<Classifier> buildV1Constituents() {
            return ImmutableList.of(
                KNNLOOCV.FACTORY.ED_1NN_V1.build(),
                KNNLOOCV.FACTORY.DTW_1NN_V1.build(),
                KNNLOOCV.FACTORY.DDTW_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_DTW_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_DDTW_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_WDTW_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_WDDTW_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_ERP_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_MSM_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_LCSS_1NN_V1.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_TWED_1NN_V1.build()
            );
        }

        public static ImmutableList<Classifier> buildV2Constituents() {
            return ImmutableList.of(
                KNNLOOCV.FACTORY.ED_1NN_V2.build(),
                KNNLOOCV.FACTORY.DTW_1NN_V2.build(),
                KNNLOOCV.FACTORY.DDTW_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_DTW_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_DDTW_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_WDTW_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_WDDTW_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_ERP_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_MSM_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_LCSS_1NN_V2.build(),
                KNNLOOCV.TUNED_FACTORY.TUNED_TWED_1NN_V2.build()
            );
        }

        public static ElasticEnsemble buildEeV1() {
            ElasticEnsemble elasticEnsemble = new ElasticEnsemble();
            elasticEnsemble.setConstituents(buildV1Constituents());
            setTrainSelectedBenchmarksFully(elasticEnsemble,false);
            return elasticEnsemble;
        }

        public static ElasticEnsemble buildEeV2() {
            ElasticEnsemble elasticEnsemble = new ElasticEnsemble();
            elasticEnsemble.setConstituents(buildV2Constituents());
            setTrainSelectedBenchmarksFully(elasticEnsemble,false);
            return elasticEnsemble;
        }

        private static ElasticEnsemble forEachTunedConstituent(ElasticEnsemble elasticEnsemble, Consumer<RLTunedKNNSetup> consumer) {
            for(Classifier classifier : elasticEnsemble.getConstituents()) {
                if(!(classifier instanceof RLTunedClassifier)) {
                    continue;
                }
                RLTunedClassifier tuner = (RLTunedClassifier) classifier;
                RLTunedKNNSetup config = (RLTunedKNNSetup) tuner.getTrainSetupFunction();
                consumer.accept(config);
            }
            return elasticEnsemble;
        }

        public static ElasticEnsemble setLimitedParameters(ElasticEnsemble elasticEnsemble, int limit) {
            return forEachTunedConstituent(elasticEnsemble, RLTunedKNNSetup -> RLTunedKNNSetup.setParamSpaceSizeLimit(limit));
        }

        public static ElasticEnsemble setLimitedParametersPercentage(ElasticEnsemble elasticEnsemble, double limit) {
            return forEachTunedConstituent(elasticEnsemble, RLTunedKNNSetup -> RLTunedKNNSetup.setParamSpaceSizeLimitPercentage(limit));
        }

        public static ElasticEnsemble setLimitedNeighbours(ElasticEnsemble elasticEnsemble, int limit) {
            return forEachTunedConstituent(elasticEnsemble, RLTunedKNNSetup -> RLTunedKNNSetup.setNeighbourhoodSizeLimit(limit));
        }

        public static ElasticEnsemble setLimitedNeighboursPercentage(ElasticEnsemble elasticEnsemble, double limit) { // todo params from cmdline in experiment + append to cls name
            return forEachTunedConstituent(elasticEnsemble, RLTunedKNNSetup -> RLTunedKNNSetup.setParamSpaceSizeLimitPercentage(limit));
        }

        public static ElasticEnsemble setTrainSelectedBenchmarksFully(ElasticEnsemble elasticEnsemble, boolean state) { // todo params from cmdline in experiment + append to cls name
            return forEachTunedConstituent(elasticEnsemble, RLTunedKNNSetup -> RLTunedKNNSetup.setTrainSelectedBenchmarksFully(state));
        }

        private static ElasticEnsemble buildLee() {
            ElasticEnsemble elasticEnsemble = new ElasticEnsemble();
            ImmutableList<Classifier> constituents = buildV2Constituents();
            elasticEnsemble.setConstituents(constituents);
            setLimitedNeighboursPercentage(elasticEnsemble, 0.1);
            setLimitedParametersPercentage(elasticEnsemble, 0.5);
            setTrainSelectedBenchmarksFully(elasticEnsemble,true);
            return elasticEnsemble;
        }
    }

    /**
     * get the constituents in this ensemble
     * @return
     */
    public ImmutableList<EnhancedAbstractClassifier> getConstituents() {
        return constituents;
    }

    /**
     * set the constituents
     * @param constituents
     */
    public void setConstituents(final Iterable<? extends Classifier> constituents) {
        List<EnhancedAbstractClassifier> list = new ArrayList<>();
        for(Classifier constituent : constituents) {
            if(constituent instanceof EnhancedAbstractClassifier) {
                list.add((EnhancedAbstractClassifier) constituent);
            } else {
                throw new IllegalArgumentException("constituents have to be EAC"); // todo some kind of wrapper
                // around ones which aren't EAC for generic'ness, not important right now
            }
        }
        this.constituents = ImmutableList.copyOf(list);
    }

    /**
     * build default v1 EE (the traditional version)
     */
    public ElasticEnsemble() {
        super(true);
        setConstituents(Factory.buildV1Constituents());
    }

    // the constituents
    private ImmutableList<EnhancedAbstractClassifier> constituents = ImmutableList.of();
    // the constituents which we're currently looking at as part of a batch of training. For example, if we chose to
    // train each constituent for 5 mins then this list contains all of the constituents which have *NOT* had their 5
    // mins of training yet
    private List<EnhancedAbstractClassifier> constituentsBatch = new ArrayList<>();
    // constituents which still have work remaining. For example, this would be any constituent which is not done after
    // the 5 mins above
    private List<EnhancedAbstractClassifier> nextConstituentsBatch = new ArrayList<>();
    // constituents which have been fully built
    private List<EnhancedAbstractClassifier> trainedConstituents = new ArrayList<>();
    // track the train time
    private StopWatch trainTimer = new StopWatch();
    // track the train estimate time
    private StopWatch trainEstimateTimer = new StopWatch();
    // how to combine the constituent votes
    private ModuleVotingScheme votingScheme = new MajorityVote();
    // how to weight each constituent
    private ModuleWeightingScheme weightingScheme = new TrainAcc();
    // final module array of constituents
    private AbstractEnsemble.EnsembleModule[] modules;
    // the amount of train time remaining for each constituent. In our above example this would be 5 mins
    private long remainingTrainTimeNanosPerConstituent;
    // whether we've done a first pass of all constituents
    private boolean firstBatchDone;
    // record the memory usage
    private final MemoryWatcher memoryWatcher = new MemoryWatcher();
    // store the train data
    private transient Instances trainData;
    // switch to regenerate the train estimate
    private boolean regenerateTrainEstimate = true;
    // train time limit
    private transient long trainTimeLimitNanos = -1;
    // version id for serialisation
    protected transient long trainContractTimeNanos = -1;
    //TODO George to integrate the boolean into the classifier logic
    private boolean trainTimeContract = false;

    private static final long serialVersionUID = 0;
    // minimum checkpoint interval
    private transient long minCheckpointIntervalNanos = Checkpointable.DEFAULT_MIN_CHECKPOINT_INTERVAL;
    // timestamp of last checkpoint
    private transient long lastCheckpointTimeStamp = 0;
    // save path for checkpoints
    private transient String savePath = null;
    // load path for checkpoints
    private transient String loadPath = null;
    // whether to skip the final checkpoint
    private transient boolean skipFinalCheckpoint = false;

    @Override
    public boolean isSkipFinalCheckpoint() {
        return skipFinalCheckpoint;
    }

    @Override
    public void setSkipFinalCheckpoint(boolean skipFinalCheckpoint) {
        this.skipFinalCheckpoint = skipFinalCheckpoint;
    }

    @Override
    public String getSavePath() {
        return savePath;
    }

    @Override
    public boolean setCheckpointPath(String path) {
        boolean result = Checkpointable.super.createDirectories(path);
        if(result) {
            savePath = StrUtils.asDirPath(path);
        } else {
            savePath = null;
        }
        return result;
    }

    @Override public String getLoadPath() {
        return loadPath;
    }

    @Override public boolean setLoadPath(final String path) {
        boolean result = Checkpointable.super.setLoadPath(path);
        if(result) {
            loadPath = StrUtils.asDirPath(path);
        } else {
            loadPath = null;
        }
        return result;
    }

    public StopWatch getTrainTimer() {
        return trainTimer;
    }

    public Instances getTrainData() { // todo is this needed?
        return trainData;
    }

    public long getLastCheckpointTimeStamp() {
        return lastCheckpointTimeStamp;
    }

    public boolean saveToCheckpoint() throws Exception {
        trainTimer.suspend();
        trainEstimateTimer.suspend();
        memoryWatcher.suspend();
        boolean result = CheckpointUtils.saveToSingleCheckpoint(this, getLogger(), isBuilt() && !skipFinalCheckpoint);
        memoryWatcher.unsuspend();
        trainEstimateTimer.unsuspend();
        trainTimer.unsuspend();
        return result;
    }

    public boolean loadFromCheckpoint() {
        trainTimer.suspend(); // todo better way of handling this
        trainEstimateTimer.suspend();
        memoryWatcher.suspend();
        boolean result = CheckpointUtils.loadFromSingleCheckpoint(this, getLogger());
        lastCheckpointTimeStamp = System.nanoTime();
        memoryWatcher.unsuspend();
        trainEstimateTimer.unsuspend();
        trainTimer.unsuspend();
        return result;
    }

    public void setMinCheckpointIntervalNanos(final long nanos) {
        if(minCheckpointIntervalNanos < 0) {
            throw new IllegalArgumentException("cannot be less than 0: " + nanos);
        }
        minCheckpointIntervalNanos = nanos;
    }

    public long getMinCheckpointIntervalNanos() {
        return minCheckpointIntervalNanos;
    }

    @Override public MemoryWatcher getMemoryWatcher() {
        return memoryWatcher;
    }

    @Override public void setLastCheckpointTimeStamp(final long lastCheckpointTimeStamp) {
        this.lastCheckpointTimeStamp = lastCheckpointTimeStamp;
    }

    public StopWatch getTrainEstimateTimer() {
        return trainEstimateTimer;
    }

    @Override
    public void setTrainTimeLimit(long nanos) {
        trainTimeLimitNanos = nanos;
    }

    @Override public long predictNextTrainTimeNanos() { // todo this may be better in its own interface
        long result = 0;
        // if we've got no more constituents to look at then we're done
        if(!nextConstituentsBatch.isEmpty()) {
            // otherwise get the next constituent
            EnhancedAbstractClassifier classifier = nextConstituentsBatch.get(0);
            // if it's able to predict its next amount of time then use that
            if(classifier instanceof TrainTimeContractable) {
                result = ((TrainTimeContractable) classifier).predictNextTrainTimeNanos();
            }
        }
        return result;
    }

    @Override public long getTrainContractTimeNanos() {
        return trainContractTimeNanos;
    }

    private void setRemainingTrainTimeNanosPerConstituent() {
        // if we've got no train time limit then the constituents can take as long as they like
        // if we've got no constituents in the batch then there's no remaining time
        if(!hasTrainTimeLimit() || constituentsBatch.isEmpty()) {
            remainingTrainTimeNanosPerConstituent = -1;
        } else {
            remainingTrainTimeNanosPerConstituent = getRemainingTrainTimeNanos() / constituentsBatch.size();
        }
    }

    @Override public void buildClassifier(final Instances trainData) throws Exception {
        // first lets load from a checkpoint if there is one
        loadFromCheckpoint();
        // enable the resource trackers
        trainTimer.enable();
        memoryWatcher.enable();
        trainEstimateTimer.checkDisabled();
        final Logger logger = getLogger();
        // find whether we're rebuilding
        final boolean rebuild = isRebuild();
        // if we're rebuilding
        if(rebuild) {
            // reset the resource trackers
            trainTimer.resetAndEnable();
            memoryWatcher.resetAndEnable();
            trainEstimateTimer.resetAndDisable();
        }
        // let super build
        super.buildClassifier(trainData);
        // hold the train data
        this.trainData = trainData;
        if(rebuild) {
            // if we're rebuilding then setup
            if(constituents == null || constituents.isEmpty()) {
                throw new IllegalStateException("empty constituents");
            }
            // initialise
            firstBatchDone = false;
            constituentsBatch = new ArrayList<>(constituents);
            trainedConstituents = new ArrayList<>();
            // for each constituent
            for(EnhancedAbstractClassifier constituent : constituents) {
                // set their seed to match ours - better reproducibility if we run a constituent individually with
                // same seed
                constituent.setSeed(seed);
                // tell them to find a train estimate for weighting
                constituent.setEstimateOwnPerformance(true);
                // if the constituent can do checkpointing
                if(constituent instanceof Checkpointable) {
                    // setup all the checkpointing details
                    if(isCheckpointLoadingEnabled()) { // todo paths need to be appended with constituent name
                        ((Checkpointable) constituent).setLoadPath(loadPath);
                    }
                    if(isCheckpointSavingEnabled()) {
                        ((Checkpointable) constituent).setCheckpointPath(savePath);
                    }
                    ((Checkpointable) constituent).setMinCheckpointIntervalNanos(minCheckpointIntervalNanos);
                    ((Checkpointable) constituent).setSkipFinalCheckpoint(skipFinalCheckpoint);
                }
            }
            nextConstituentsBatch = new ArrayList<>();
            // find how much train time remains and split between the constituents
            trainTimer.lap();
            setRemainingTrainTimeNanosPerConstituent();
        }
        // switch resource monitors if not already
        trainTimer.enableAnyway();
        trainEstimateTimer.disableAnyway();
        // while there's constituents to process and time left
        while(hasNextBuildTick()) {
            // process another constituent
            nextBuildTick();
            // save this to checkpoint
            saveToCheckpoint();
        }
        // if we're estimating our train
        if(regenerateTrainEstimate) {
            logger.fine("generating train estimate");
            modules = new AbstractEnsemble.EnsembleModule[constituents.size()];
            int i = 0;
            // translate constituents to modules
            for(EnhancedAbstractClassifier constituent : constituents) {
                trainEstimateTimer.add(constituent.getTrainResults().getBuildPlusEstimateTime());
                modules[i] = new AbstractEnsemble.EnsembleModule();
                modules[i].setClassifier(constituent);
                modules[i].trainResults = constituent.getTrainResults();
                i++;
            }
            // weight constituents
            logger.fine("weighting constituents");
            weightingScheme.defineWeightings(modules, trainData.numClasses());
            votingScheme.trainVotingScheme(modules, trainData.numClasses());
            trainResults = new ClassifierResults();
            // vote constituents
            for(i = 0; i < trainData.size(); i++) {
                StopWatch predictionTimer = new StopWatch(Stated.State.ENABLED);
                double[] distribution = votingScheme.distributionForTrainInstance(modules, i);
                int prediction = ArrayUtilities.argMax(distribution);
                predictionTimer.disable();
                double trueClassValue = trainData.get(i).classValue();
                trainResults.addPrediction(trueClassValue, distribution, prediction, predictionTimer.getTimeNanos(), null);
            }
        }
        // have regenerated train estimate so disable
        regenerateTrainEstimate = false;
        // disable resource monitors
        memoryWatcher.disableAnyway();
        trainEstimateTimer.disableAnyway();
        trainTimer.disableAnyway();
        // set train results details
        trainResults.setDetails(this, trainData);
        // free up train data
        this.trainData = null;
        // we're built by here
        setBuilt(true);
        logger.info("build finished");
        saveToCheckpoint();
    }

    private boolean hasTimeRemainingPerConstituent() {
        return remainingTrainTimeNanosPerConstituent >= 0;
    }

    /**
     * further iterations training a singular constituent by the remaining train time per constituent. Updates the
     * constituent records afterwards reflecting whether the constituent is done or has training remaining. If this
     * handles the last untrained constituent for this batch then we repopulate the batch from the constituents which
     * are not finished training, distributing the train time between them again. E.g. if we have 3 classifiers, A, B
     * and C. If we have a train time of 15 mins we would do 3 executions of this function with 5 mins for each
     * classifier. Suppose classifier B and C finished in the full 5 mins and classifier A finished in 3 mins then
     * there is a remaining 2 mins left of the 15 mins total train contract. We then repopulate the batch of
     * classifiers with B and C (the unfinished classifiers) and split the remaining train time between them (2 mins
     * --> 1 min each). Repeat until all classifiers are trained or train time is depleted to zero.
     * @throws Exception
     */
    private void nextBuildTick() throws Exception {
        final Logger logger = getLogger();
        // get the next constituent
        EnhancedAbstractClassifier constituent = constituentsBatch.remove(0);
        if(constituent == null) {
            throw new IllegalStateException("something has gone wrong, constituent should not be null");
        }
        // set the train time limit if possible
        if(constituent instanceof TrainTimeContractable && hasTimeRemainingPerConstituent()) {
            ((TrainTimeContractable) constituent).setTrainTimeLimitNanos(remainingTrainTimeNanosPerConstituent);
        }
        // track the train time of the constituent
        StopWatch constituentTrainTimer = new StopWatch();
        // disable our train timer as the constituent train timer will take it from here
        trainTimer.disable();
        if(constituent instanceof TrainTimeable) {
            constituentTrainTimer.disableAnyway();
        } else {
            constituentTrainTimer.enableAnyway();
        }
        // track the memory of the constituent
        MemoryWatcher constituentMemoryWatcher = new MemoryWatcher();
        // disable our memory watcher as the constituent memory watcher will take it from here
        memoryWatcher.disable();
        if(constituent instanceof MemoryWatchable) {
            constituentMemoryWatcher.disableAnyway();
        } else {
            constituentMemoryWatcher.enableAnyway();
        }
        logger.fine(() -> "running constituent {id: "+
                   (constituents.size() - constituentsBatch.size())+
                   " " +
                   constituent.getClassifierName()+
                   " }");
        constituent.buildClassifier(trainData);
        logger.fine(() -> "ran constituent {id: "+
                   (constituents.size() - constituentsBatch.size())+
                   " acc: "+
                   constituent.getTrainResults().getAcc()+
                   " "+
                   constituent.getClassifierName()+
                   " }");
        // disable resource monitors for the constituent and re-enable ours
        constituentTrainTimer.disableAnyway();
        constituentMemoryWatcher.disableAnyway();
        memoryWatcher.enable();
        trainTimer.enable();
        // sanity check the train estimate timer is disabled
        trainEstimateTimer.checkDisabled();
        // add the constituent's train time onto ours
        if(constituent instanceof TrainTimeable) { // todo these can probs be a util method as similar elsewhere
            // (RLTune)
            trainTimer.add(((TrainTimeable) constituent).getTrainTimeNanos());
        } else {
            trainTimer.add(constituentTrainTimer);
        }
        // add the constituent's train estimate time onto ours
        if(constituent instanceof TrainEstimateTimeable) {
            // the classifier tracked its time internally
            this.trainEstimateTimer.add(((TrainTimeable) constituent).getTrainTimeNanos());
        } else {
            // we already tracked this as part of the train time
        }
        // add the constituents memory usage onto ours
        if(constituent instanceof MemoryWatchable) {
            memoryWatcher.add((MemoryWatchable) constituent);
        } else {
            memoryWatcher.add(constituentMemoryWatcher);
        }
        // if the constituent is contracting train time AND there's time remaining for each constituent AND the
        // constituent has remaining work to do
        if(constituent instanceof TrainTimeContractable && hasTimeRemainingPerConstituent() &&
                ((TrainTimeContractable) constituent).hasRemainingTraining()) {
            // add it to the next batch of constituents
            nextConstituentsBatch.add(constituent);
        }
        // if there's no more constituents to process
        if(constituentsBatch.isEmpty()) {
            // we have definitely seen all constituents here
            firstBatchDone = true;
            // add all of the next constituent batch to the current batch
            constituentsBatch.addAll(nextConstituentsBatch);
            // clear out the next constituent batch as they've all been added to the current batch
            nextConstituentsBatch.clear();
            // recalculate the remaining time for each constituent
            setRemainingTrainTimeNanosPerConstituent();
        }
        // we've adjusted one of the constituents therefore we need to regenerate the train estimate
        setRegenerateTrainEstimate(true);
    }

    /**
     * whether further build steps remain
     * @return
     * @throws Exception
     */
    public boolean hasNextBuildTick() throws Exception {
        // must do a first pass of all constituents, therefore if the first batch hasn't been completed this should
        // always return true
        // otherwise, it's dependent on whether there's further training remaining
        return !firstBatchDone || (hasRemainingTrainTime() && !constituentsBatch.isEmpty());
    }

    @Override public double[] distributionForInstance(final Instance instance) throws Exception {
        return votingScheme.distributionForInstance(modules, instance);
    }

    @Override public double classifyInstance(final Instance instance) throws Exception {
        return Utilities.argMax(distributionForInstance(instance), getRandom());
    }

    public ModuleVotingScheme getVotingScheme() {
        return votingScheme;
    }

    public void setVotingScheme(final ModuleVotingScheme votingScheme) {
        this.votingScheme = votingScheme;
    }

    public ModuleWeightingScheme getWeightingScheme() {
        return weightingScheme;
    }

    public void setWeightingScheme(final ModuleWeightingScheme weightingScheme) {
        this.weightingScheme = weightingScheme;
    }
}