Java Code Examples for org.deeplearning4j.nn.graph.ComputationGraph#getNumOutputArrays()

The following examples show how to use org.deeplearning4j.nn.graph.ComputationGraph#getNumOutputArrays(), which returns the number of output arrays the graph produces (one per network output defined in its configuration). Each example names the source file and project it comes from.
Example 1
Source File: EarlyStoppingGraphTrainer.java    From deeplearning4j with Apache License 2.0
/** Constructor for training using a {@link DataSetIterator}.
 * @param esConfig Configuration
 * @param net Network to train using early stopping
 * @param train DataSetIterator for training the network
 * @param listener Early stopping listener. May be null.
 */
public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net,
                DataSetIterator train, EarlyStoppingListener<ComputationGraph> listener) {
    super(esConfig, net, train, null, listener);
    if (net.getNumInputArrays() != 1 || net.getNumOutputArrays() != 1)
        throw new IllegalStateException(
                        "Cannot do early stopping training on ComputationGraph with DataSetIterator: graph does not have 1 input and 1 output array");
    this.net = net;
}
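The same check can be pulled into a small helper that also reports the actual counts, which the original message leaves out. This is a stdlib-only sketch so it runs without DL4J: `GraphArity` is a hypothetical stand-in for the `getNumInputArrays()`/`getNumOutputArrays()` pair on `ComputationGraph`, and `SingleIoGuard`/`requireSingleInputOutput` are illustrative names, not DL4J API.

```java
import java.util.Objects;

/**
 * Sketch of the single-input/single-output check from Example 1, factored into a helper.
 * GraphArity is a hypothetical stand-in, not a DL4J type.
 */
final class SingleIoGuard {

    /** Hypothetical stand-in exposing only the two counts the check needs. */
    interface GraphArity {
        int getNumInputArrays();
        int getNumOutputArrays();
    }

    private SingleIoGuard() {}

    /** Throws IllegalStateException unless the graph has exactly one input and one output array. */
    static void requireSingleInputOutput(GraphArity graph) {
        Objects.requireNonNull(graph, "graph");
        int in = graph.getNumInputArrays();
        int out = graph.getNumOutputArrays();
        if (in != 1 || out != 1) {
            // Report the actual counts so the caller can see which side failed.
            throw new IllegalStateException("Expected 1 input and 1 output array, found "
                    + in + " input(s) and " + out + " output(s)");
        }
    }

    /** Builds a fixed-count stand-in for the demo below. */
    static GraphArity of(int in, int out) {
        return new GraphArity() {
            @Override public int getNumInputArrays() { return in; }
            @Override public int getNumOutputArrays() { return out; }
        };
    }

    public static void main(String[] args) {
        requireSingleInputOutput(of(1, 1)); // passes silently
        try {
            requireSingleInputOutput(of(1, 2));
        } catch (IllegalStateException e) {
            // prints "Expected 1 input and 1 output array, found 1 input(s) and 2 output(s)"
            System.out.println(e.getMessage());
        }
    }
}
```

In real code the helper would take a `ComputationGraph` (or the two counts) directly; the interface only exists to keep the sketch self-contained.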
 
Example 2
Source File: ScoreUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Get the evaluation for the given model and test dataset.
 * @param model the model to get the evaluation from
 * @param testData the test data to do the evaluation on
 * @return the evaluation object with accumulated statistics for the current test data
 * @throws IllegalStateException if the model does not have exactly one output array
 */
public static Evaluation getEvaluation(ComputationGraph model, MultiDataSetIterator testData) {
    if (model.getNumOutputArrays() != 1)
        throw new IllegalStateException("GraphSetSetAccuracyScoreFunction cannot be "
                        + "applied to ComputationGraphs with more than one output. NumOutputs = "
                        + model.getNumOutputArrays());

    return model.evaluate(testData);
}
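Example 2 tests `model.getNumOutputArrays() != 1`, so the guard rejects a graph reporting zero outputs as well as one reporting several, while its message only describes the "more than one output" case (and names a score-function class rather than `getEvaluation`). A minimal sketch, assuming you want the same condition with a message that is accurate for any count; `OutputCountCheck` and `requireExactlyOneOutput` are hypothetical names, and the `int` argument stands in for the value returned by `getNumOutputArrays()`.

```java
/**
 * Hypothetical sketch: same `!= 1` condition as Example 2,
 * but the message distinguishes "no outputs" from "several outputs".
 */
final class OutputCountCheck {
    private OutputCountCheck() {}

    /** Returns normally only when numOutputs == 1. */
    static void requireExactlyOneOutput(int numOutputs) {
        if (numOutputs == 1) {
            return;
        }
        String problem = (numOutputs < 1) ? "no output arrays" : "more than one output array";
        throw new IllegalStateException(
                "Evaluation requires a ComputationGraph with exactly one output array, but this graph has "
                        + problem + " (NumOutputs = " + numOutputs + ")");
    }

    public static void main(String[] args) {
        requireExactlyOneOutput(1); // returns normally
        for (int n : new int[] {0, 3}) {
            try {
                requireExactlyOneOutput(n);
            } catch (IllegalStateException e) {
                System.out.println(e.getMessage());
            }
        }
    }
}
```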
 
Example 3
Source File: ScoreUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Get the evaluation for the given model and test dataset.
 * @param model the model to get the evaluation from
 * @param testData the test data to do the evaluation on
 * @return the evaluation object with accumulated statistics for the current test data
 * @throws IllegalStateException if the model does not have exactly one output array
 */
public static Evaluation getEvaluation(ComputationGraph model, DataSetIterator testData) {
    if (model.getNumOutputArrays() != 1)
        throw new IllegalStateException("GraphSetSetAccuracyScoreFunctionDataSet cannot be "
                        + "applied to ComputationGraphs with more than one output. NumOutputs = "
                        + model.getNumOutputArrays());

    return model.evaluate(testData);
}
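Examples 2 and 3 repeat the same guard and differ only in the iterator type passed to `evaluate`. One way to keep the two overloads in sync is a single private check that both call. In this hypothetical sketch, `Model`, `SingleIter`, `MultiIter` and `Result` are stand-ins for `ComputationGraph`, `DataSetIterator`, `MultiDataSetIterator` and `Evaluation`; they are not DL4J types, and the fake model exists only so the code runs without the library.

```java
/** Hypothetical sketch: two getEvaluation overloads sharing one output-count guard. */
final class SharedGuardSketch {
    interface SingleIter {}
    interface MultiIter {}

    static final class Result {
        final String source;
        Result(String source) { this.source = source; }
    }

    /** Stand-in for the parts of ComputationGraph used by Examples 2 and 3. */
    interface Model {
        int getNumOutputArrays();
        Result evaluate(SingleIter it);
        Result evaluate(MultiIter it);
    }

    private SharedGuardSketch() {}

    // The single place that decides whether a model can be evaluated.
    private static void checkSingleOutput(Model model) {
        int n = model.getNumOutputArrays();
        if (n != 1) {
            throw new IllegalStateException(
                    "getEvaluation requires exactly one output array; NumOutputs = " + n);
        }
    }

    static Result getEvaluation(Model model, MultiIter testData) {
        checkSingleOutput(model);
        return model.evaluate(testData);
    }

    static Result getEvaluation(Model model, SingleIter testData) {
        checkSingleOutput(model);
        return model.evaluate(testData);
    }

    /** Fake model that reports a fixed output count and tags which overload ran. */
    static Model fakeModel(int numOutputs) {
        return new Model() {
            @Override public int getNumOutputArrays() { return numOutputs; }
            @Override public Result evaluate(SingleIter it) { return new Result("single"); }
            @Override public Result evaluate(MultiIter it) { return new Result("multi"); }
        };
    }

    public static void main(String[] args) {
        Model ok = fakeModel(1);
        System.out.println(getEvaluation(ok, new SingleIter() {}).source); // prints "single"
        System.out.println(getEvaluation(ok, new MultiIter() {}).source);  // prints "multi"
        try {
            getEvaluation(fakeModel(2), new SingleIter() {});
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

With this shape, changing the rule (or its message) touches one method instead of two copies.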