/* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package ai.djl.examples.training; import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.training.TrainingResult; import ai.djl.translate.TranslateException; import java.io.IOException; import org.apache.commons.cli.ParseException; import org.testng.Assert; import org.testng.SkipException; import org.testng.annotations.Test; public class TrainPikachuTest { @Test public void testDetection() throws IOException, MalformedModelException, TranslateException, ParseException { // this is nightly test if (!Boolean.getBoolean("nightly")) { throw new SkipException("Nightly only"); } String[] args; float expectedLoss = 0; int expectedMinNumber = 0; int expectedMaxNumber = 0; if (Device.getGpuCount() > 0) { args = new String[] {"-e", "20", "-b", "32", "-g", "1"}; expectedLoss = 2.5e-3f; expectedMaxNumber = 15; expectedMinNumber = 6; } else { // test train 1 epoch and predict workflow works on CPU args = new String[] {"-e", "1", "-m", "1", "-b", "32"}; } // test train TrainingResult result = TrainPikachu.runExample(args); if (expectedLoss > 0) { Assert.assertTrue(result.getValidateLoss() < expectedLoss); } // test predict int numberOfPikachus = TrainPikachu.predict("build/model", "src/test/resources/pikachu.jpg"); if (expectedMinNumber > 0) { Assert.assertTrue(numberOfPikachus >= expectedMinNumber); Assert.assertTrue(numberOfPikachus <= expectedMaxNumber); } } }