package com.jstarcraft.ai.model.neuralnetwork;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Future;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import com.jstarcraft.ai.environment.EnvironmentContext;
import com.jstarcraft.ai.environment.EnvironmentFactory;
import com.jstarcraft.ai.math.MathUtility;
import com.jstarcraft.ai.math.structure.MathCache;
import com.jstarcraft.ai.math.structure.Nd4jCache;
import com.jstarcraft.ai.math.structure.matrix.MathMatrix;
import com.jstarcraft.ai.math.structure.matrix.Nd4jMatrix;
import com.jstarcraft.ai.model.neuralnetwork.activation.IdentityActivationFunction;
import com.jstarcraft.ai.model.neuralnetwork.layer.EmbedLayer;
import com.jstarcraft.ai.model.neuralnetwork.layer.Layer;
import com.jstarcraft.ai.model.neuralnetwork.layer.ParameterConfigurator;
import com.jstarcraft.ai.model.neuralnetwork.layer.WeightLayer;
import com.jstarcraft.ai.model.neuralnetwork.learn.SgdLearner;
import com.jstarcraft.ai.model.neuralnetwork.loss.MSELossFunction;
import com.jstarcraft.ai.model.neuralnetwork.normalization.IgnoreNormalizer;
import com.jstarcraft.ai.model.neuralnetwork.optimization.StochasticGradientOptimizer;
import com.jstarcraft.ai.model.neuralnetwork.parameter.CopyParameterFactory;
import com.jstarcraft.ai.model.neuralnetwork.schedule.ConstantSchedule;
import com.jstarcraft.ai.model.neuralnetwork.schedule.Schedule;
import com.jstarcraft.ai.model.neuralnetwork.vertex.LayerVertex;
import com.jstarcraft.ai.model.neuralnetwork.vertex.Nd4jVertex;
import com.jstarcraft.ai.model.neuralnetwork.vertex.transformation.HorizontalAttachVertex;
import com.jstarcraft.core.utility.RandomUtility;

/**
 * Equivalence test for the project's {@code Graph} against a Deeplearning4j
 * {@link ComputationGraph}.
 *
 * <p>Both graphs are built with the same topology (two embedding layers merged
 * into a single MSE output layer), the project graph's parameters are copied
 * from the DL4J graph's initialized weights, and after identical training
 * schedules the test asserts that both produce (approximately) equal outputs.
 */
public class GraphTestCase {

    // Shared hyper-parameters applied identically to both implementations.
    private final static float learnRatio = 0.01F;

    private final static float l1Regularization = 0.01F;

    private final static float l2Regularization = 0.05F;

    /**
     * Copies an ND4J array into a freshly allocated {@link MathMatrix} of the
     * same dimensions (no transpose: the second argument to copyMatrix is false).
     */
    private static MathMatrix getMatrix(MathCache factory, INDArray array) {
        MathMatrix matrix = factory.makeMatrix(array.rows(), array.columns());
        matrix.copyMatrix(new Nd4jMatrix(array), false);
        return matrix;
    }

    /**
     * Builds weight/bias configurators for a project layer by copying the given
     * DL4J layer's initialized weight and bias parameters, so both graphs start
     * from identical values. Both parameters share the same L1/L2 settings.
     */
    private static Map<String, ParameterConfigurator> getConfigurators(MathCache factory, AbstractLayer<?> layer) {
        Map<String, ParameterConfigurator> configurators = new HashMap<>();
        CopyParameterFactory weight = new CopyParameterFactory(getMatrix(factory, layer.getParam(DefaultParamInitializer.WEIGHT_KEY)));
        configurators.put(WeightLayer.WEIGHT_KEY, new ParameterConfigurator(l1Regularization, l2Regularization, weight));
        CopyParameterFactory bias = new CopyParameterFactory(getMatrix(factory, layer.getParam(DefaultParamInitializer.BIAS_KEY)));
        configurators.put(WeightLayer.BIAS_KEY, new ParameterConfigurator(l1Regularization, l2Regularization, bias));
        return configurators;
    }

    /**
     * Builds the reference DL4J graph: two 5->5 embedding layers ("leftEmbed",
     * "rightEmbed") merged into a 10->1 MSE output layer, trained with plain
     * SGD. All activations are identity so the comparison is purely linear.
     */
    private ComputationGraph getOldFunction() {
        NeuralNetConfiguration.Builder netBuilder = new NeuralNetConfiguration.Builder();
        // Set the random seed (fixed so weight initialization is reproducible).
        netBuilder.seed(6);
        netBuilder.setL1(l1Regularization);
        netBuilder.setL1Bias(l1Regularization);
        netBuilder.setL2(l2Regularization);
        netBuilder.setL2Bias(l2Regularization);
        netBuilder.weightInit(WeightInit.XAVIER_UNIFORM);
        netBuilder.updater(new Sgd(learnRatio)).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        GraphBuilder graphBuilder = netBuilder.graphBuilder();
        graphBuilder.addInputs("leftInput", "rightInput");
        graphBuilder.addLayer("leftEmbed", new EmbeddingLayer.Builder().nIn(5).nOut(5).hasBias(true).activation(Activation.IDENTITY).build(), "leftInput");
        graphBuilder.addLayer("rightEmbed", new EmbeddingLayer.Builder().nIn(5).nOut(5).hasBias(true).activation(Activation.IDENTITY).build(), "rightInput");
        graphBuilder.addVertex("embed", new MergeVertex(), "leftEmbed", "rightEmbed");
        graphBuilder.addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(1).build(), "embed");
        graphBuilder.setOutputs("output");
        ComputationGraphConfiguration configuration = graphBuilder.build();
        ComputationGraph graph = new ComputationGraph(configuration);
        graph.init();
        return graph;
    }

    /**
     * Builds the project's equivalent graph, mirroring the DL4J topology vertex
     * by vertex and copying the DL4J graph's initialized parameters via
     * {@link #getConfigurators} so both start from the same weights.
     *
     * <p>NOTE(review): the extra "nd4j" vertex between "embed" and "output" has
     * no DL4J counterpart — presumably it adapts the merged matrix to the ND4J
     * backing expected by the output layer; confirm against Nd4jVertex's contract.
     */
    private Graph getNewFunction(MathCache factory, ComputationGraph computationGraph) {
        Schedule schedule = new ConstantSchedule(learnRatio);
        GraphConfigurator configurator = new GraphConfigurator();
        Layer leftEmbed = new EmbedLayer(5, 5, factory, getConfigurators(factory, (AbstractLayer<?>) computationGraph.getLayer("leftEmbed")), new IdentityActivationFunction());
        Layer rightEmbed = new EmbedLayer(5, 5, factory, getConfigurators(factory, (AbstractLayer<?>) computationGraph.getLayer("rightEmbed")), new IdentityActivationFunction());
        configurator.connect(new LayerVertex("leftEmbed", factory, leftEmbed, new SgdLearner(schedule), new IgnoreNormalizer()));
        configurator.connect(new LayerVertex("rightEmbed", factory, rightEmbed, new SgdLearner(schedule), new IgnoreNormalizer()));
        configurator.connect(new HorizontalAttachVertex("embed", factory), "leftEmbed", "rightEmbed");
        configurator.connect(new Nd4jVertex("nd4j", factory, true), "embed");
        Layer weightLayer = new WeightLayer(10, 1, factory, getConfigurators(factory, (AbstractLayer<?>) computationGraph.getLayer("output")), new IdentityActivationFunction());
        configurator.connect(new LayerVertex("output", factory, weightLayer, new SgdLearner(schedule), new IgnoreNormalizer()), "nd4j");
        Graph graph = new Graph(configurator, new StochasticGradientOptimizer(), new MSELossFunction());
        return graph;
    }

    /**
     * Element-wise comparison of a {@link MathMatrix} against an ND4J array
     * using the project's float tolerance ({@link MathUtility#equal}).
     * Dimensions are taken from the matrix; the array is assumed to be at
     * least as large.
     */
    private static boolean equalMatrix(MathMatrix matrix, INDArray array) {
        for (int row = 0; row < matrix.getRowSize(); row++) {
            for (int column = 0; column < matrix.getColumnSize(); column++) {
                if (!MathUtility.equal(matrix.getValue(row, column), array.getFloat(row, column))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Trains both graphs on the same random index inputs (labels fixed at 5)
     * for 50 outer rounds, then asserts that forward prediction of the project
     * graph matches the DL4J graph's output within tolerance.
     *
     * @throws Exception if the environment task fails or is interrupted
     */
    @Test
    public void testPropagate() throws Exception {
        MathCache factory = new Nd4jCache();
        EnvironmentContext context = EnvironmentFactory.getContext();
        // Run inside the environment context so the project's math backend is bound correctly.
        Future<?> task = context.doTask(() -> {
            ComputationGraph oldGraph = getOldFunction();
            Graph graph = getNewFunction(factory, oldGraph);
            int size = 5;
            // Inputs are embedding row indices in [0, 5); labels are the constant 5.
            INDArray oldLeftInputs = Nd4j.zeros(size, 1);
            INDArray oldRightInputs = Nd4j.zeros(size, 1);
            INDArray oldMarks = Nd4j.zeros(size, 1).assign(5);
            for (int point = 0; point < 5; point++) {
                oldLeftInputs.put(point, 0, RandomUtility.randomInteger(5));
                oldRightInputs.put(point, 0, RandomUtility.randomInteger(5));
            }
            for (int index = 0; index < 50; index++) {
                oldGraph.setInputs(oldLeftInputs, oldRightInputs);
                oldGraph.setLabels(oldMarks);
                // Number of iterations for the fit process (matches practice(2, ...) below).
                for (int iteration = 0; iteration < 2; iteration++) {
                    oldGraph.fit();
                    double oldScore = oldGraph.score();
                    System.out.println(oldScore);
                }
            }
            INDArray oldOutputs = oldGraph.outputSingle(oldLeftInputs, oldRightInputs);
            System.out.println(oldOutputs);
            // NOTE(review): attaching the thread to device 0 presumably keeps ND4J
            // buffers on one device before copying — confirm against AffinityManager docs.
            AffinityManager manager = Nd4j.getAffinityManager();
            manager.attachThreadToDevice(Thread.currentThread(), 0);
            MathMatrix newLeftInputs = getMatrix(factory, oldLeftInputs);
            MathMatrix newRightInputs = getMatrix(factory, oldRightInputs);
            MathMatrix newMarks = getMatrix(factory, oldMarks);
            // Initialized from oldOutputs only to get a correctly-sized buffer; predict overwrites it.
            MathMatrix newOutputs = getMatrix(factory, oldOutputs);
            for (int index = 0; index < 50; index++) {
                double newScore = graph.practice(2, new MathMatrix[] { newLeftInputs, newRightInputs }, new MathMatrix[] { newMarks });
                System.out.println(newScore);
            }
            graph.predict(new MathMatrix[] { newLeftInputs, newRightInputs }, new MathMatrix[] { newOutputs });
            System.out.println(newOutputs);
            // The two implementations must agree element-wise within float tolerance.
            Assert.assertTrue(equalMatrix(newOutputs, oldOutputs));
        });
        task.get();
    }

}