/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.modelimport.keras.e2e;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasLRN;
import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasPoolHelper;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.File;
import java.net.URL;

/**
 * Test import of Keras custom layers.
 *
 * <p>Must be run manually: the test is {@code @Ignore}d because downloading and importing the
 * GoogLeNet model can take too long for regular unit tests. The combined weights-and-config
 * HDF5 file is downloaded automatically from the URL returned by
 * {@code DL4JResources.getURLString("googlenet_keras_weightsandconfig.h5")}, so no manual
 * download is required.
 *
 * @author Justin Long (crockpotveggies)
 */
@Slf4j
public class KerasCustomLayerTest extends BaseDL4JTest {

    /** Per-test scratch directory for the downloaded model and the serialized DL4J model. */
    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    /**
     * Downloads the GoogLeNet Keras model (weights and config in a single HDF5 file), imports it
     * with the custom {@code PoolHelper} and {@code LRN} layers registered, then writes the
     * imported model with {@link ModelSerializer} and restores it as a {@link ComputationGraph}.
     *
     * @throws Exception if downloading, importing, writing or restoring the model fails
     */
    // run manually, might take a long time to load (too long for unit tests)
    @Ignore
    @Test
    public void testCustomLayerImport() throws Exception {
        // file paths
        String kerasWeightsAndConfigUrl =
                DL4JResources.getURLString("googlenet_keras_weightsandconfig.h5");
        // Only build the path here: TemporaryFolder.newFile() would create an empty file, which
        // made the exists() check below always true and silently skipped the download.
        File cachedKerasFile = new File(testDir.getRoot(), "googlenet_keras_weightsandconfig.h5");
        String outputPath = testDir.newFile("googlenet_dl4j_inference.zip").getAbsolutePath();

        // Register the custom layer implementations under their Keras layer names before importing.
        KerasLayer.registerCustomLayer("PoolHelper", KerasPoolHelper.class);
        KerasLayer.registerCustomLayer("LRN", KerasLRN.class);

        // download file
        if (!cachedKerasFile.exists()) {
            log.info("Downloading model to {}", cachedKerasFile);
            FileUtils.copyURLToFile(new URL(kerasWeightsAndConfigUrl), cachedKerasFile);
            cachedKerasFile.deleteOnExit();
        }

        org.deeplearning4j.nn.api.Model importedModel =
                KerasModelImport.importKerasModelAndWeights(cachedKerasFile.getAbsolutePath());
        ModelSerializer.writeModel(importedModel, outputPath, false);

        ComputationGraph serializedModel = ModelSerializer.restoreComputationGraph(outputPath);
        log.info(serializedModel.summary());
    }
}