org.deeplearning4j.nn.api.Model Java Examples

The following examples show how to use org.deeplearning4j.nn.api.Model. Each example names the source file it was taken from, the project it belongs to, and that project's license; refer to the original project for the surrounding context of each snippet.
Example #1
Source File: InplaceParallelInference.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Performs a forward pass on the model assigned to the calling thread and returns the output
 * produced by the given {@code adapter}.
 *
 * <p>The model is acquired from the thread's holder, handed to the adapter, and released again
 * in a {@code finally} block, so it is returned even if the adapter throws.
 *
 * @param adapter     converts the acquired model plus inputs into the desired output
 * @param input       input arrays passed through to the adapter
 * @param inputMasks  input mask arrays passed through to the adapter
 * @param labelsMasks label mask arrays passed through to the adapter
 * @param <T>         output type produced by the adapter
 * @return the adapter's result
 * @throws RuntimeException wrapping an {@link InterruptedException} if the thread is interrupted
 *                          while acquiring a model; the thread's interrupt flag is restored first
 */
public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks, INDArray[] labelsMasks) {
    val holder = selector.getModelForThisThread();
    Model model = null;
    try {
        model = holder.acquireModel();
        return adapter.apply(model, input, inputMasks, labelsMasks);
    } catch (InterruptedException e) {
        // Restore the interrupt status so code further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    } finally {
        // model only becomes non-null once acquireModel() has returned, i.e. after a successful acquire.
        if (model != null) {
            holder.releaseModel(model);
        }
    }
}
 
Example #2
Source File: EvaluationRunner.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Runs the given evaluations on a {@link MultiLayerNetwork} or a {@link ComputationGraph}.
 *
 * @param m             the model to evaluate; must be a MultiLayerNetwork or a ComputationGraph
 * @param e             evaluations to populate
 * @param ds            DataSet iterator; used when non-null
 * @param mds           MultiDataSet iterator; used only when {@code ds} is null
 * @param evalBatchSize batch size used when wrapping the raw iterator
 * @throws IllegalArgumentException if {@code m} is null or of an unsupported model type
 */
private static void doEval(Model m, IEvaluation[] e, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int evalBatchSize){
    if (m instanceof MultiLayerNetwork) {
        MultiLayerNetwork mln = (MultiLayerNetwork) m;
        if (ds != null) {
            mln.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        } else {
            mln.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
        }
    } else if (m instanceof ComputationGraph) {
        ComputationGraph cg = (ComputationGraph) m;
        if (ds != null) {
            cg.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        } else {
            cg.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
        }
    } else {
        // Fail with a descriptive message instead of an opaque ClassCastException/NPE from a blind cast.
        throw new IllegalArgumentException("Unsupported model type for evaluation: "
                + (m == null ? "null" : m.getClass().getName()));
    }
}
 
Example #3
Source File: ModelTupleStreamTest.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Exercises {@code doTest} for every (inputs, outputs) pair in [1, 5] x [1, 5], once with a
 * MultiLayerNetwork and once with a ComputationGraph: 5 * 5 * 2 = 50 runs in total.
 */
@Test
public void test() throws Exception {
  int testsCount = 0;
  for (int inputs = 1; inputs <= 5; inputs++) {
    for (int outputs = 1; outputs <= 5; outputs++) {
      final Model mlnModel = buildMultiLayerNetworkModel(inputs, outputs);
      final Model graphModel = buildComputationGraphModel(inputs, outputs);
      final Model[] candidates = {mlnModel, graphModel};
      for (Model candidate : candidates) {
        doTest(candidate, inputs, outputs);
        testsCount++;
      }
    }
  }
  assertEquals(50, testsCount);
}
 
Example #4
Source File: EpochListener.java (from wekaDeeplearning4j, GNU General Public License v3.0; 6 votes)
/**
 * Advances the epoch counter and, on every {@code n}-th epoch, logs a progress line. When
 * intermediate evaluations are enabled, the line is followed by the training-set evaluation and,
 * if a validation iterator is configured, the validation-set evaluation.
 */
@Override
public void onEpochEnd(Model model) {
  currentEpoch++;

  final boolean isEvaluationEpoch = currentEpoch % n == 0;
  if (!isEvaluationEpoch) {
    return;
  }

  final StringBuilder report =
      new StringBuilder("Epoch [" + currentEpoch + "/" + numEpochs + "]\n");
  if (isIntermediateEvaluationsEnabled) {
    report.append("Train Set:      \n")
        .append(evaluateDataSetIterator(model, trainIterator, true));
    if (validationIterator != null) {
      report.append("Validation Set: \n")
          .append(evaluateDataSetIterator(model, validationIterator, false));
    }
  }

  log(report.toString());
}
 
Example #5
Source File: ModelGuesserTest.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Writes a network together with a fitted min-max normalizer into one file, then checks that
 * {@link ModelGuesser} restores both the model and the normalizer unchanged.
 */
@Test
public void testNormalizerInPlace() throws Exception {
    MultiLayerNetwork net = getNetwork();

    File tempFile = testDir.newFile("testNormalizerInPlace.bin");

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.writeModel(net, tempFile, true, normalizer);

    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath());
    // JUnit convention: expected value first, actual second, so failure messages read correctly.
    assertEquals(net, model);
    assertEquals(normalizer, normalizer1);
}
 
Example #6
Source File: IntegrationTestRunner.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Returns a copy of every parameter of every {@link FrozenLayer} in the model.
 *
 * <p>Keys are {@code <prefix>_<paramName>}. The prefix is the layer index for a
 * MultiLayerNetwork and the layer name otherwise. Values are {@code dup()} copies, so later
 * training cannot mutate them. Any model that is not a MultiLayerNetwork is cast to
 * ComputationGraph.
 */
private static Map<String,INDArray> getFrozenLayerParamCopies(Model m){
    final boolean isMultiLayer = m instanceof MultiLayerNetwork;
    final org.deeplearning4j.nn.api.Layer[] layers = isMultiLayer
            ? ((MultiLayerNetwork) m).getLayers()
            : ((ComputationGraph) m).getLayers();

    final Map<String,INDArray> copies = new LinkedHashMap<>();
    for (org.deeplearning4j.nn.api.Layer layer : layers) {
        if (!(layer instanceof FrozenLayer)) {
            continue;
        }
        final String prefix = isMultiLayer
                ? layer.getIndex() + "_"
                : layer.conf().getLayer().getLayerName() + "_";
        for (Map.Entry<String,INDArray> entry : layer.paramTable().entrySet()) {
            copies.put(prefix + entry.getKey(), entry.getValue().dup());
        }
    }
    return copies;
}
 
Example #7
Source File: BaseEarlyStoppingTrainer.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Stores {@code epochNum} as the model's epoch count and notifies its training listeners of an
 * epoch start ({@code epochStart == true}) or an epoch end.
 *
 * <p>Only MultiLayerNetwork and ComputationGraph are handled. Other model types are ignored.
 */
protected void triggerEpochListeners(boolean epochStart, Model model, int epochNum){
    final Collection<TrainingListener> listeners;
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) model;
        listeners = network.getListeners();
        network.setEpochCount(epochNum);
    } else if (model instanceof ComputationGraph) {
        final ComputationGraph graph = (ComputationGraph) model;
        listeners = graph.getListeners();
        graph.getConfiguration().setEpochCount(epochNum);
    } else {
        return;
    }

    if (listeners == null || listeners.isEmpty()) {
        return;
    }
    for (TrainingListener listener : listeners) {
        if (epochStart) {
            listener.onEpochStart(model);
        } else {
            listener.onEpochEnd(model);
        }
    }
}
 
Example #8
Source File: BaseStatsListener.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Collects gradient statistics (histograms, mean, standard deviation, mean magnitude) when
 * stats are computed from gradients and the current iteration is a reporting iteration.
 *
 * <p>Reporting iterations are iteration 0 and every multiple of the configured reporting
 * frequency. A non-positive frequency disables collection entirely.
 */
@Override
public void onGradientCalculation(Model model) {
    final int iterCount = getModelInfo(model).iterCount;
    final boolean isReportingIteration = calcFromGradients()
            && updateConfig.reportingFrequency() > 0
            && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0);
    if (!isReportingIteration) {
        return;
    }

    final Gradient gradient = model.gradient();
    if (updateConfig.collectHistograms(StatsType.Gradients)) {
        gradientHistograms = getHistograms(gradient.gradientForVariable(),
                updateConfig.numHistogramBins(StatsType.Gradients));
    }
    if (updateConfig.collectMean(StatsType.Gradients)) {
        meanGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.Mean);
    }
    if (updateConfig.collectStdev(StatsType.Gradients)) {
        stdevGradient = calculateSummaryStats(gradient.gradientForVariable(), StatType.Stdev);
    }
    if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
        meanMagGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.MeanMagnitude);
    }
}
 
Example #9
Source File: TrainUtil.java (from FancyBing, GNU General Public License v3.0; 5 votes)
/**
 * Serializes {@code model}, including its updater state, to
 * {@code <user.dir>/model/<name>_idx_<index>_<accuracy>.zip}. The {@code model} directory is
 * created if it does not exist yet.
 *
 * @param name     base name of the file
 * @param model    model to save
 * @param index    index appended to the file name
 * @param accuracy accuracy value appended to the file name
 * @return the file name (not the full path) that was written
 * @throws IOException if the directory cannot be created or writing fails; the exception is
 *                     reported on stderr and then rethrown
 */
public static String saveModel(String name, Model model, int index, int accuracy) throws Exception {
    System.err.println("Saving model, don't shutdown...");
    try {
        String fn = name + "_idx_" + index + "_" + accuracy + ".zip";
        File modelDir = new File(System.getProperty("user.dir"), "model");
        // Without this, writing fails when the model directory has not been created yet.
        // The isDirectory() re-check tolerates the directory appearing concurrently.
        if (!modelDir.mkdirs() && !modelDir.isDirectory()) {
            throw new IOException("Could not create model directory: " + modelDir.getAbsolutePath());
        }
        File locationToSave = new File(modelDir, fn);
        // Keep the updater state (Momentum, RMSProp, Adagrad, ...) so training can be resumed later.
        boolean saveUpdater = true;
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
        System.err.println("Model saved");
        return fn;
    } catch (IOException e) {
        System.err.println("Save model failed");
        e.printStackTrace();
        throw e;
    }
}
 
Example #10
Source File: BaseOptimizer.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Returns the epoch count stored in the model's configuration.
 *
 * <p>The counter lives in a different configuration object for each model family:
 * MultiLayerNetwork uses its layer-wise configuration, ComputationGraph uses its graph
 * configuration, and any other model uses {@code conf()}.
 */
public static int getEpochCount(Model model){
    return model instanceof MultiLayerNetwork
            ? ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount()
            : model instanceof ComputationGraph
                    ? ((ComputationGraph) model).getConfiguration().getEpochCount()
                    : model.conf().getEpochCount();
}
 
Example #11
Source File: IntegrationTestRunner.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Asserts that every layer of the model reports the expected epoch and iteration counts.
 *
 * <p>Any model that is not a MultiLayerNetwork is cast to ComputationGraph.
 */
private static void validateLayerIterCounts(Model m, int expEpoch, int expIter){
    final org.deeplearning4j.nn.api.Layer[] layers = (m instanceof MultiLayerNetwork)
            ? ((MultiLayerNetwork) m).getLayers()
            : ((ComputationGraph) m).getLayers();

    for (int i = 0; i < layers.length; i++) {
        final org.deeplearning4j.nn.api.Layer layer = layers[i];
        assertEquals("Epoch count", expEpoch, layer.getEpochCount());
        assertEquals("Iteration count", expIter, layer.getIterationCount());
    }
}
 
Example #12
Source File: SystemInfoFilePrintListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Writes a "forward pass" entry to the target file when forward-pass printing is enabled and a
 * target file is configured.
 *
 * <p>This callback is gated by {@code printOnForwardPass}, not {@code printOnBackwardPass}.
 */
@Override
public void onForwardPass(Model model, List<INDArray> activations) {
    if (!printOnForwardPass || printFileTarget == null)
        return;

    writeFileWithMessage("forward pass");
}
 
Example #13
Source File: ModelTupleStreamIntegrationTest.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Builds a 3-input, 2-output network with a single identity-activation output layer (MSE loss)
 * and fixed, hand-picked parameters.
 */
private static Model buildModel() throws Exception {
    final int numInputs = 3;
    final int numOutputs = 2;

    final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list(new OutputLayer.Builder()
                .nIn(numInputs)
                .nOut(numOutputs)
                .activation(Activation.IDENTITY)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build())
        .build();

    final MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    // Positive weights for the first output, negative weights for the second output, no biases.
    final float[] parameterValues = {+1, +1, +1, -1, -1, -1, 0, 0};
    // One weight per input plus one bias, for each output.
    assertEquals((numInputs + 1) * numOutputs, parameterValues.length);

    network.setParams(Nd4j.create(parameterValues));
    return network;
}
 
Example #14
Source File: SleepyTrainingListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Sleeps according to {@code timerIteration}, based on the previously recorded timestamp, then
 * records the current time as the latest iteration timestamp.
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    final AtomicLong previous = lastIteration.get();
    sleep(previous, timerIteration);

    if (previous != null) {
        previous.set(System.currentTimeMillis());
    } else {
        // First iteration seen here: start tracking with a fresh timestamp holder.
        lastIteration.set(new AtomicLong(System.currentTimeMillis()));
    }
}
 
Example #15
Source File: SystemInfoFilePrintListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Writes a "backward pass" entry to the target file when backward-pass printing is enabled and
 * a target file is configured.
 */
@Override
public void onBackwardPass(Model model) {
    final boolean shouldWrite = printOnBackwardPass && printFileTarget != null;
    if (shouldWrite) {
        writeFileWithMessage("backward pass");
    }
}
 
Example #16
Source File: DL4JArbiterStatusReportingListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Forwards an iteration-done event for the current candidate to every registered status
 * listener. Does nothing when no listener collection is set.
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    if (statusListeners != null) {
        for (StatusListener listener : statusListeners) {
            listener.onCandidateIteration(candidateInfo, model, iteration);
        }
    }
}
 
Example #17
Source File: UpdaterCreator.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Creates the updater matching the model type.
 *
 * <p>MultiLayerNetwork gets a MultiLayerUpdater and ComputationGraph gets a
 * ComputationGraphUpdater. Anything else is cast to {@code Layer} and gets a LayerUpdater.
 */
public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
    if (layer instanceof MultiLayerNetwork) {
        return new MultiLayerUpdater((MultiLayerNetwork) layer);
    }
    if (layer instanceof ComputationGraph) {
        return new ComputationGraphUpdater((ComputationGraph) layer);
    }
    return new LayerUpdater((Layer) layer);
}
 
Example #18
Source File: ModelGuesser.java (from konduit-serving, Apache License 2.0; 5 votes)
/**
 * Loads a DL4J zip file as either a computation graph or a multi layer network.
 *
 * <p>If the archive contains both the coefficients and the configuration entries, the
 * configuration JSON is inspected. A top-level {@code "vertexInputs"} key selects
 * {@link ComputationGraph}. In every other case (including missing entries) the file is restored
 * as a {@link MultiLayerNetwork}.
 *
 * @param path the path to the file to load
 * @return the loaded DL4J model, or {@code null} if {@code path} is not a zip file
 * @throws Exception if reading the archive or restoring the model fails
 */
public static Model loadDl4jGuess(String path) throws Exception {
    final File file = new File(path);
    if (!isZipFile(file)) {
        return null;
    }

    log.debug("Loading file {}", path);
    boolean compGraph = false;
    try (ZipFile zipFile = new ZipFile(path)) {
        List<String> entryNames = zipFile.stream()
                .map(ZipEntry::getName)
                .collect(Collectors.toList());
        log.debug("Entries {}", entryNames);
        if (entryNames.contains(ModelSerializer.COEFFICIENTS_BIN)
                && entryNames.contains(ModelSerializer.CONFIGURATION_JSON)) {
            ZipEntry entry = zipFile.getEntry(ModelSerializer.CONFIGURATION_JSON);
            try (InputStream is = zipFile.getInputStream(entry)) {
                String configJson = IOUtils.toString(is, StandardCharsets.UTF_8);
                log.debug("Loaded configuration");
                // The presence of "vertexInputs" is taken to mean a ComputationGraph configuration.
                compGraph = new JSONObject(configJson).has("vertexInputs");
                if (compGraph) {
                    log.debug("Loading computation graph.");
                } else {
                    log.debug("Loading multi layer network.");
                }
            }
        }
    }

    if (compGraph) {
        return ModelSerializer.restoreComputationGraph(file);
    } else {
        return ModelSerializer.restoreMultiLayerNetwork(file);
    }
}
 
Example #19
Source File: SystemInfoPrintListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Logs the {@code SYSTEM_INFO} header followed by a pretty-printed system information snapshot
 * when forward-pass printing is enabled.
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    if (printOnForwardPass) {
        final SystemInfo snapshot = new SystemInfo();
        log.info(SYSTEM_INFO);
        log.info(snapshot.toPrettyJSON());
    }
}
 
Example #20
Source File: BaseOptimizer.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Creates an optimizer for {@code model} and sets up its backtracking line search.
 *
 * @param conf              configuration; supplies the maximum number of line search iterations
 * @param stepFunction      step function to use, or {@code null} to use the default step
 *                          function for this optimizer class
 * @param trainingListeners listeners to notify, or {@code null} for an empty list
 * @param model             the model being optimized
 */
public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction,
                Collection<TrainingListener> trainingListeners, Model model) {
    this.conf = conf;
    if (stepFunction == null) {
        this.stepFunction = getDefaultStepFunctionForOptimizer(this.getClass());
    } else {
        this.stepFunction = stepFunction;
    }
    if (trainingListeners == null) {
        this.trainingListeners = new ArrayList<TrainingListener>();
    } else {
        this.trainingListeners = trainingListeners;
    }
    this.model = model;

    lineMaximizer = new BackTrackLineSearch(model, this.stepFunction, this);
    lineMaximizer.setStepMax(stepMax);
    lineMaximizer.setMaxIterations(conf.getMaxNumLineSearchIterations());
}
 
Example #21
Source File: SleepyTrainingListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Sleeps according to {@code timerBP}, based on the previously recorded backward-pass
 * timestamp, then records the current time as the latest backward-pass timestamp.
 */
@Override
public void onBackwardPass(Model model) {
    final AtomicLong previousBackwardPass = lastBP.get();
    sleep(previousBackwardPass, timerBP);

    if (previousBackwardPass == null) {
        lastBP.set(new AtomicLong(System.currentTimeMillis()));
        return;
    }
    previousBackwardPass.set(System.currentTimeMillis());
}
 
Example #22
Source File: CheckpointListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Returns a short label for the model's type: "MultiLayerNetwork", "ComputationGraph" or
 * "Model".
 *
 * <p>The comparison is on the exact runtime class, so subclasses of either network type are
 * labelled "Model".
 */
protected static String getModelType(Model model){
    final Class<?> type = model.getClass();
    if (type == MultiLayerNetwork.class) {
        return "MultiLayerNetwork";
    }
    if (type == ComputationGraph.class) {
        return "ComputationGraph";
    }
    return "Model";
}
 
Example #23
Source File: ModelTupleStream.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Restores a model from {@code inputStream} using
 * {@link ModelGuesser#loadModelGuess(InputStream, File)}, passing the Solr resource loader's
 * instance directory as the second argument.
 *
 * @param inputStream stream containing the serialized model
 * @return the restored model
 * @throws IOException wrapping any failure, with {@code serializedModelFileName} in the message
 */
protected Model restoreModel(InputStream inputStream) throws IOException {
  final File instanceDir = solrResourceLoader.getInstancePath().toFile();
  final Model restored;
  try {
    restored = ModelGuesser.loadModelGuess(inputStream, instanceDir);
  } catch (Exception e) {
    final String message = "Failed to restore model from given file (" + serializedModelFileName + ")";
    throw new IOException(message, e);
  }
  return restored;
}
 
Example #24
Source File: SystemInfoFilePrintListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Writes a "forward pass" entry to the target file when forward-pass printing is enabled and a
 * target file is configured.
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    if (printOnForwardPass && printFileTarget != null) {
        writeFileWithMessage("forward pass");
    }
}
 
Example #25
Source File: CheckpointListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Returns the epoch count stored in the model's configuration.
 *
 * <p>MultiLayerNetwork uses its layer-wise configuration, ComputationGraph uses its graph
 * configuration, and any other model uses {@code conf()}.
 */
protected static int getEpoch(Model model) {
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) model;
        return network.getLayerWiseConfigurations().getEpochCount();
    }
    if (model instanceof ComputationGraph) {
        final ComputationGraph graph = (ComputationGraph) model;
        return graph.getConfiguration().getEpochCount();
    }
    return model.conf().getEpochCount();
}
 
Example #26
Source File: SystemInfoPrintListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Logs the {@code SYSTEM_INFO} header followed by a pretty-printed system information snapshot
 * when gradient-calculation printing is enabled.
 */
@Override
public void onGradientCalculation(Model model) {
    if (printOnGradientCalculation) {
        final SystemInfo info = new SystemInfo();
        log.info(SYSTEM_INFO);
        log.info(info.toPrettyJSON());
    }
}
 
Example #27
Source File: CheckpointListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Saves a checkpoint when epoch-based saving is configured ({@code saveEveryNEpochs} non-null)
 * and the number of completed epochs is a multiple of it.
 *
 * <p>Other save conditions are not checked here; they are handled in {@code iterationDone}.
 */
@Override
public void onEpochEnd(Model model) {
    // NOTE(review): assumes getEpoch is zero-based, so +1 is the number of completed epochs.
    final int completedEpochs = getEpoch(model) + 1;
    final boolean epochSaveDue = saveEveryNEpochs != null
            && completedEpochs > 0
            && completedEpochs % saveEveryNEpochs == 0;
    if (epochSaveDue) {
        saveCheckpoint(model);
    }
}
 
Example #28
Source File: FailureTestingListener.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Returns {@code true} if any of the configured triggers fires for this call.
 *
 * <p>Every trigger is evaluated, even after one has already returned {@code true}. There is no
 * short-circuit.
 */
@Override
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
    boolean anyTriggered = false;
    for (FailureTrigger trigger : triggers) {
        if (trigger.triggerFailure(callType, iteration, epoch, model)) {
            anyTriggered = true;
        }
    }
    return anyTriggered;
}
 
Example #29
Source File: InplaceParallelInference.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Returns the source model of every holder, in the holders' iteration order.
 */
@Override
protected synchronized Model[] getCurrentModelsFromWorkers() {
    final Model[] result = new Model[holders.size()];
    int index = 0;
    for (val holder : holders) {
        result[index] = holder.sourceModel;
        index++;
    }
    return result;
}
 
Example #30
Source File: FailureTestingListener.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Reports a {@code FORWARD_PASS} call for {@code model}. The activations are not inspected.
 */
@Override
public void onForwardPass(Model model, List<INDArray> activations) {
    call(CallType.FORWARD_PASS, model);
}