Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#execAndEndResult()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#execAndEndResult(). They are drawn from open source projects; the source file, project, and license are noted above each example.
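In brief, execAndEndResult() executes the whole SameDiff graph and returns the array of its final output variable. Below is a minimal sketch of that pattern, assuming the 1.0.0-alpha era API these examples target (the method was removed in later releases):

// Minimal sketch, assuming the 1.0.0-alpha era API in which execAndEndResult() exists
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.rand(3, 4));            //variable backed by an initial array
SDVariable out = sd.standardDeviation("out", in, true);   //the final op in the graph
INDArray result = sd.execAndEndResult();                  //run the graph, return the last output's array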
Example 1
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0
@Test
public void testReshapeGradient() {
    int[] origShape = new int[]{3, 4, 5};

    for (int[] toShape : new int[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, origShape)) {
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable reshape = sd.reshape(in, toShape);
            //Using stdev here: mean/sum would backprop the same gradient for each input...
            SDVariable stdev = sd.standardDeviation("out", reshape, true);

            INDArray out = sd.execAndEndResult();
            INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(expOut, out);

            String msg = "toShape=" + Arrays.toString(toShape) + ", source=" + p.getSecond();
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
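A note on the pattern shared by the gradient-check examples in this section: each test builds a small graph ending in a scalar output, runs the forward pass with execAndEndResult(), asserts that output against an independently computed expectation, and then calls GradCheckUtil.checkGradients(sd), which validates the backpropagated gradients numerically (a finite-difference style check, judging by the epsilon and tolerance arguments visible in Example 11) and returns true when they match.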
 
Example 2
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0
@Test
public void testPermuteGradient() {
    int[] origShape = new int[]{3, 4, 5};

    for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, origShape)) {
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable permute = sd.f().permute(in, perm);
            //Using stdev here: mean/sum would backprop the same gradient for each input...
            SDVariable stdev = sd.standardDeviation("out", permute, true);

            INDArray out = sd.execAndEndResult();
            INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(expOut, out);

            String msg = "permute=" + Arrays.toString(perm) + ", source=" + p.getSecond();
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
 
Example 3
Source File: LoadTensorFlowMNISTMLP.java    From dl4j-tutorials with MIT License
public static void main(String[] args) throws Exception {
    final String FROZEN_MLP = new ClassPathResource(BASE_DIR + "/frozen_model.pb").getFile().getPath();

    //Load placeholder inputs and corresponding predictions generated from tensorflow
    Map<String, INDArray> inputsPredictions = readPlaceholdersAndPredictions();

    //Load the graph into samediff
    SameDiff graph = TFGraphMapper.getInstance().importGraph(new File(FROZEN_MLP));
    //libnd4j executor
    //running with input_a array expecting to get prediction_a
    graph.associateArrayWithVariable(inputsPredictions.get("input_a"), graph.variableMap().get("input"));
    NativeGraphExecutioner executioner = new NativeGraphExecutioner();
    INDArray[] results = executioner.executeGraph(graph); //returns an array of the outputs
    INDArray libnd4jPred = results[0];
    System.out.println("LIBND4J exec prediction for input_a:\n" + libnd4jPred);
    if (libnd4jPred.equals(inputsPredictions.get("prediction_a"))) {
        //this branch is expected: the libnd4j prediction matches the one saved from tensorflow
        System.out.println("Predictions are equal to tensorflow");
    } else {
        throw new RuntimeException("Predictions don't match!");
    }

    //Now to run with the samediff executor, with input_b array expecting to get prediction_b
    SameDiff graphSD = TFGraphMapper.getInstance().importGraph(new File(FROZEN_MLP)); //Reimport graph here, necessary for the 1.0 alpha release
    graphSD.associateArrayWithVariable(inputsPredictions.get("input_b"), graphSD.variableMap().get("input"));
    INDArray samediffPred = graphSD.execAndEndResult();
    System.out.println("SameDiff exec prediction for input_b:\n" + samediffPred);
    if (samediffPred.equals(inputsPredictions.get("prediction_b"))) {
        //this branch is expected: the SameDiff prediction matches the one saved from tensorflow
        System.out.println("Predictions are equal to tensorflow");
    }
    //Add a new op to the already-imported graph to demonstrate define-by-run ("PyTorch-like") capability
    System.out.println("Adding new op to graph..");
    SDVariable linspaceConstant = graphSD.var("linspace", Nd4j.linspace(1, 10, 10));
    SDVariable totalOutput = graphSD.getVariable("output").add(linspaceConstant);
    INDArray totalOutputArr = totalOutput.eval();
    System.out.println(totalOutputArr);

}
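This example contrasts two ways of executing the same imported graph: NativeGraphExecutioner.executeGraph(...) runs it through the native libnd4j executor and returns every output as an INDArray[], while SameDiff#execAndEndResult() runs it in SameDiff and returns only the final output. The closing lines then extend the already-imported graph with an extra op and evaluate it via SDVariable.eval(), the define-by-run capability the comment alludes to.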
 
Example 4
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0
@Test
public void testExpandDimsGradient() {
    val origShape = new long[]{3, 4};

    boolean first = true;
    for (int i = 0; i < 3; i++) {

        long[] expExpandShape;
        switch (i) {
            case 0:
                expExpandShape = new long[]{1, 3, 4};
                break;
            case 1:
                expExpandShape = new long[]{3, 1, 4};
                break;
            case 2:
                expExpandShape = new long[]{3, 4, 1};
                break;
            default:
                throw new RuntimeException();
        }

        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345)) {
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable expand = sd.f().expandDims(in, i);
            //Using stdev here: mean/sum would backprop the same gradient for each input...
            SDVariable stdev = sd.standardDeviation("out", expand, true);

            INDArray out = sd.execAndEndResult();
            INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(expOut, out);

            assertArrayEquals(expExpandShape, expand.getArr().shape());
            INDArray expExpand = inArr.dup('c').reshape(expExpandShape);
            assertEquals(expExpand, expand.getArr());

            String msg = "expandDim=" + i + ", source=" + p.getSecond();
            log.info("Starting: " + msg);
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
 
Example 5
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0
@Test
public void testSqueezeGradient() {
    val origShape = new long[]{3, 4, 5};

    for (int i = 0; i < 3; i++) {

        val shape = origShape.clone();
        shape[i] = 1;

        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape)) {
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable squeeze = sd.f().squeeze(in, i);
            //Using stdev here: mean/sum would backprop the same gradient for each input...
            SDVariable stdev = sd.standardDeviation("out", squeeze, true);

            long[] expShapePostSqueeze;
            switch (i) {
                case 0:
                    expShapePostSqueeze = new long[]{4, 5};
                    break;
                case 1:
                    expShapePostSqueeze = new long[]{3, 5};
                    break;
                case 2:
                    expShapePostSqueeze = new long[]{3, 4};
                    break;
                default:
                    throw new RuntimeException();
            }

            //Execute once so that the squeeze output array is populated before checking its shape
            sd.execAndEndResult();

            INDArray squeezed = squeeze.getArr();
            assertArrayEquals(expShapePostSqueeze, squeezed.shape());

            INDArray out = sd.execAndEndResult();
            INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(expOut, out);

            String msg = "squeezeDim=" + i + ", source=" + p.getSecond();
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
 
Example 6
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0
@Test
public void testGradientAutoBroadcast1() {

        Nd4j.getRandom().setSeed(12345);

        List<String> allFailed = new ArrayList<>();

        for (int dim_sz1 : new int[]{0, 1, 2}) {

            int[] in2Shape = {3, 4, 5};
            in2Shape[dim_sz1] = 1;

            for (int i = 2; i < 3; i++) {   //NOTE: as written, only case 2 ("mul") below is exercised; the other ops in the switch are skipped

                SameDiff sd = SameDiff.create();

                SDVariable in3 = sd.var("in3", Nd4j.rand(new int[]{3, 4, 5}));
                SDVariable in2 = sd.var("in2", in2Shape);

                SDVariable bcOp;
                String name;
                switch (i) {
                    case 0:
                        bcOp = in3.add(in2);
                        name = "add";
                        break;
                    case 1:
                        bcOp = in3.sub(in2);
                        name = "sub";
                        break;
                    case 2:
                        bcOp = in3.mul(in2);
                        name = "mul";
                        break;
                    case 3:
                        bcOp = in3.div(in2);
                        name = "div";
                        break;
                    case 4:
                        bcOp = in3.rsub(in2);
                        name = "rsub";
                        break;
                    case 5:
                        bcOp = in3.rdiv(in2);
                        name = "rdiv";
                        break;
                    case 6:
                        bcOp = sd.f().floorDiv(in3, in2);
                        name = "floordiv";
                        break;
                    case 7:
                        bcOp = sd.f().floorMod(in3, in2);
                        name = "floormod";
                        break;
                    default:
                        throw new RuntimeException();
                }

                SDVariable outVar = sd.sum(bcOp);

                String msg = "(test " + i + ": " + name + ", dimension=" + dim_sz1 + ")";
                log.info("*** Starting test: " + msg);

                INDArray in3Arr = Nd4j.randn(new int[]{3, 4, 5}).muli(100);
                INDArray in2Arr = Nd4j.randn(in2Shape).muli(100);

                sd.associateArrayWithVariable(in3Arr, in3);
                sd.associateArrayWithVariable(in2Arr, in2);

                try {
                    INDArray out = sd.execAndEndResult();
                    assertNotNull(out);
                    assertArrayEquals(new long[]{1, 1}, out.shape());

//                    System.out.println(sd.asFlatPrint());

                    boolean ok = GradCheckUtil.checkGradients(sd);
                    if (!ok) {
                        allFailed.add(msg);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    allFailed.add(msg + " - EXCEPTION");
                }
            }
        }

        assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
 
Example 7
Source File: GradCheckMisc.java    From nd4j with Apache License 2.0
@Test
public void testGradientAutoBroadcast2() {

        Nd4j.getRandom().setSeed(12345);

        List<String> allFailed = new ArrayList<>();

        for (int[] dim_sz1s : new int[][]{{0, 1}, {0, 2}, {1, 2}, {0,1,2}}) {

            int[] otherShape = {3, 4, 5};
            otherShape[dim_sz1s[0]] = 1;
            otherShape[dim_sz1s[1]] = 1;
            if(dim_sz1s.length == 3){
                otherShape[dim_sz1s[2]] = 1;
            }

            for (int i = 0; i < 6; i++) {   //NOTE: cases 6 ("floordiv") and 7 ("floormod") below are never reached by this loop

                SameDiff sd = SameDiff.create();

                SDVariable in3 = sd.var("in3", new int[]{3, 4, 5});
                SDVariable in2 = sd.var("inToBc", otherShape);

                String name;
                SDVariable bcOp;
                switch (i) {
                    case 0:
                        bcOp = in3.add(in2);
                        name = "add";
                        break;
                    case 1:
                        bcOp = in3.sub(in2);
                        name = "sub";
                        break;
                    case 2:
                        bcOp = in3.mul(in2);
                        name = "mul";
                        break;
                    case 3:
                        bcOp = in3.div(in2);
                        name = "div";
                        break;
                    case 4:
                        bcOp = in3.rsub(in2);
                        name = "rsub";
                        break;
                    case 5:
                        bcOp = in3.rdiv(in2);
                        name = "rdiv";
                        break;
                    case 6:
                        bcOp = sd.f().floorDiv(in3, in2);
                        name = "floordiv";
                        break;
                    case 7:
                        bcOp = sd.f().floorMod(in3, in2);
                        name = "floormod";
                        break;
                    default:
                        throw new RuntimeException();
                }

                SDVariable outVar = sd.sum(bcOp);

                String msg = "(test " + i + ": " + name + ", dimensions=" + Arrays.toString(dim_sz1s) + ")";
                log.info("*** Starting test: " + msg);

                INDArray in3Arr = Nd4j.randn(new int[]{3,4,5}).muli(100);
                INDArray in2Arr = Nd4j.randn(otherShape).muli(100);

                sd.associateArrayWithVariable(in3Arr, in3);
                sd.associateArrayWithVariable(in2Arr, in2);

                try {
                    INDArray out = sd.execAndEndResult();
                    assertNotNull(out);
                    assertArrayEquals(new long[]{1, 1}, out.shape());

//                    System.out.println(sd.asFlatPrint());

                    boolean ok = GradCheckUtil.checkGradients(sd);
                    if (!ok) {
                        allFailed.add(msg);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    allFailed.add(msg + " - EXCEPTION");
                }
            }
        }

        assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
 
Example 8
Source File: GradCheckLoss.java    From nd4j with Apache License 2.0
@Test
public void testLossSimple2d() {
        Nd4j.getRandom().setSeed(12345);

        for (String fn : new String[]{"mse", "l1", "l2", "mcxent"}) {

            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[]{
                    LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM}) {

                SameDiff sd = SameDiff.create();

                int nOut = 4;
                int minibatch = 10;
                SDVariable input = sd.var("in", new int[]{-1, nOut});
                SDVariable labels = sd.var("labels", new int[]{-1, nOut});

                INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
                INDArray labelsArr = Nd4j.randn(minibatch, nOut).muli(100);

                LossInfo lossInfo;
                INDArray expOut;
                switch (fn) {
                    case "mse":
                        lossInfo = LossFunctions.mse("out", input, labels, null, reduction, 1);
                        expOut = inputArr.sub(labelsArr);
                        expOut.muli(expOut);
                        expOut = expOut.mean(Integer.MAX_VALUE);
                        break;
                    case "l1":
                        lossInfo = LossFunctions.l1("out", input, labels, null, reduction, 1);
                        //L1 = sum abs error
                        expOut = Transforms.abs(inputArr.sub(labelsArr)).sum(1);
                        expOut = expOut.mean(Integer.MAX_VALUE);
                        break;
                    case "l2":
                        lossInfo = LossFunctions.l2("out", input, labels, null, reduction, 1);
                        //L2 = sum squared error
                        expOut = Transforms.pow(inputArr.sub(labelsArr), 2.0).sum(1).mean(Integer.MAX_VALUE);
                        break;
                    case "mcxent":
                        lossInfo = LossFunctions.mcxent("out", input, labels, null, reduction, 1);
                        //mcxent = sum label * log(prob)
                        expOut = labelsArr.mul(Transforms.log(inputArr)).sum(1).mean(Integer.MAX_VALUE);
                        break;
                    default:
                        throw new RuntimeException();
                }


                String msg = "test: " + lossInfo.getLossName() + ", reduction=" + reduction;
                log.info("*** Starting test: " + msg);


                sd.associateArrayWithVariable(inputArr, input);
                sd.associateArrayWithVariable(labelsArr, labels);

//            System.out.println(sd.asFlatPrint());

                INDArray out = sd.execAndEndResult();

                assertEquals(msg, expOut, out);

                System.out.println("STARTING GRADIENT CHECK");
                boolean ok = GradCheckUtil.checkGradients(sd);

                assertTrue(msg, ok);
            }
        }
}
 
Example 9
Source File: GradCheckLoss.java    From nd4j with Apache License 2.0
@Test
public void testLossWeights2d() {

    String[] weightTypes = new String[]{"none", "per-example", "per-output", "per-example-output"};

    Nd4j.getRandom().setSeed(12345);

    int nOut = 4;
    int minibatch = 10;

    for (String weightType : weightTypes) {

        for (boolean binary : new boolean[]{true, false}) {  //Binary mask (like DL4J) or arbitrary weights?

            int[] weightShape;
            switch (weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[]{minibatch, 1};
                    break;
                case "per-output":
                    weightShape = new int[]{1, nOut};
                    break;
                case "per-example-output":
                    weightShape = new int[]{minibatch, nOut};
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }

            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }

            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[]{
                    LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM}) {

                for (String fn : new String[]{"mse", "l1", "l2", "mcxent"}) {

                    SameDiff sd = SameDiff.create();


                    SDVariable input = sd.var("in", new int[]{-1, nOut});
                    SDVariable labels = sd.var("labels", new int[]{-1, nOut});
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }

                    INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
                    INDArray labelsArr = Nd4j.randn(minibatch, nOut).muli(100);

                    LossInfo lossInfo;
                    switch (fn) {
                        case "mse":
                            lossInfo = LossFunctions.mse("out", input, labels, weight, reduction, 1);
                            break;
                        case "l1":
                            lossInfo = LossFunctions.l1("out", input, labels, weight, reduction, 1);
                            //L1 = sum abs error
                            break;
                        case "l2":
                            lossInfo = LossFunctions.l2("out", input, labels, weight, reduction, 1);
                            //L2 = sum squared error
                            break;
                        case "mcxent":
                            lossInfo = LossFunctions.mcxent("out", input, labels, weight, reduction, 1);
                            //mcxent = sum label * log(prob)
                            break;
                        default:
                            throw new RuntimeException();
                    }


                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;
                    log.info("*** Starting test: " + msg);

                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }

                    INDArray out = sd.execAndEndResult();
                    assertEquals(1, out.length());

                    boolean ok = GradCheckUtil.checkGradients(sd);

                    assertTrue(msg, ok);
                }
            }
        }
    }
}
 
Example 10
Source File: GradCheckReductions.java    From nd4j with Apache License 2.0
@Test
public void testReductionGradients1() {
        //Test reductions where the reduction is the final op in the graph, but *not* the only op
        Nd4j.getRandom().setSeed(12345);

        List<String> allFailed = new ArrayList<>();

        for (int dim : new int[]{0, Integer.MAX_VALUE}) {    //These two cases are equivalent here

            for (int i = 0; i < 10; i++) {

                SameDiff sd = SameDiff.create();

                int nOut = 4;
                int minibatch = 10;
                SDVariable input = sd.var("in", new int[]{-1, nOut});
                SDVariable label = sd.var("label", new int[]{-1, nOut});

                SDVariable diff = input.sub(label);
                SDVariable sqDiff = diff.mul(diff);
                SDVariable msePerEx = sd.mean("msePerEx", sqDiff, 1);

                SDVariable loss;
                String name;
                switch (i) {
                    case 0:
                        loss = sd.mean("loss", msePerEx, dim);
                        name = "mean";
                        break;
                    case 1:
                        loss = sd.sum("loss", msePerEx, dim);
                        name = "sum";
                        break;
                    case 2:
                        loss = sd.standardDeviation("loss", msePerEx, true, dim);
                        name = "stdev";
                        break;
                    case 3:
                        loss = sd.min("loss", msePerEx, dim);
                        name = "min";
                        break;
                    case 4:
                        loss = sd.max("loss", msePerEx, dim);
                        name = "max";
                        break;
                    case 5:
                        loss = sd.variance("loss", msePerEx, true, dim);
                        name = "variance";
                        break;
                    case 6:
                        loss = sd.prod("loss", msePerEx, dim);
                        name = "prod";
                        break;
                    case 7:
                        loss = sd.norm1("loss", msePerEx, dim);
                        name = "norm1";
                        break;
                    case 8:
                        loss = sd.norm2("loss", msePerEx, dim);
                        name = "norm2";
                        break;
                    case 9:
                        loss = sd.normmax("loss", msePerEx, dim);
                        name = "normmax";
                        break;
                    default:
                        throw new RuntimeException();
                }


                String msg = "(test " + i + " - " + name + ", dimension=" + dim + ")";
                log.info("*** Starting test: " + msg);

                INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
                INDArray labelArr = Nd4j.randn(minibatch, nOut).muli(100);

                sd.associateArrayWithVariable(inputArr, input);
                sd.associateArrayWithVariable(labelArr, label);

                try {
                    INDArray out = sd.execAndEndResult();
                    assertNotNull(out);
                    assertArrayEquals(new long[]{1, 1}, out.shape());

//                    System.out.println(sd.asFlatPrint());

                    boolean ok = GradCheckUtil.checkGradients(sd);
                    if (!ok) {
                        allFailed.add(msg);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    allFailed.add(msg + " - EXCEPTION");
                }
            }
        }

        assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
 
Example 11
Source File: GradCheckReductions.java    From nd4j with Apache License 2.0
@Test
public void testReduce3() {

    Nd4j.getRandom().setSeed(12345);

    int d0 = 3;
    int d1 = 4;
    int d2 = 5;

    List<String> allFailed = new ArrayList<>();
    for (int[] reduceDims : new int[][]{{Integer.MAX_VALUE}, {0, 1, 2}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}}) {
        for (int i = 0; i < 6; i++) {

            SameDiff sd = SameDiff.create();
            sd.setLogExecution(false);


            SDVariable in = sd.var("in", new int[]{-1, d1, d2});
            SDVariable in2 = sd.var("in2", new int[]{-1, d1, d2});

            INDArray inArr = Nd4j.randn(new int[]{d0, d1, d2}).muli(100);
            INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100);

            SDVariable reduced;
            String name;
            switch (i) {
                case 0:
                    reduced = sd.manhattanDistance(in, in2, reduceDims);
                    name = "manhattan";
                    break;
                case 1:
                    reduced = sd.euclideanDistance(in, in2, reduceDims);
                    name = "euclidean";
                    break;
                case 2:
                    reduced = sd.cosineSimilarity(in, in2, reduceDims);
                    name = "cosine";
                    break;
                case 3:
                    reduced = sd.cosineDistance(in, in2, reduceDims);
                    name = "cosinedistance";
                    break;
                case 4:
                    reduced = sd.hammingDistance(in, in2, reduceDims);
                    name = "hamming";
                    break;
                case 5:
                    name = "jaccard";
                    reduced = sd.jaccardDistance(name, in, in2, reduceDims);
                    inArr.divi(100).addi(0.1);
                    in2Arr.divi(100).addi(0.1);
                    break;
                default:
                    throw new RuntimeException();
            }

            //Sum: note that this should be a no-op for the full array cases
            SDVariable sum = sd.sum(reduced, Integer.MAX_VALUE);


            String msg = "(test " + i + " - " + name + ", dimensions=" + Arrays.toString(reduceDims) + ")";
            log.info("*** Starting test: " + msg);

            sd.associateArrayWithVariable(inArr, in);
            sd.associateArrayWithVariable(in2Arr, in2);

            sd.execAndEndResult();

            // FIXME: exceptions can't be swallowed here for now; once the release is out and things have stabilized, the try/catch below can be restored
            //try {
                boolean ok = GradCheckUtil.checkGradients(sd, 1e-5, 1e-5, 1e-4, true, false);
                if (!ok) {
                    allFailed.add(msg);
                }
            /*
            } catch (Exception e) {
                e.printStackTrace();
                allFailed.add(msg + " - EXCEPTION");
            }
            */
        }
    }

    assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
 
Example 12
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0
@Test
@Ignore
public void importGraph4() throws Exception {
    SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_multiply.pb.txt").getInputStream());

    assertNotNull(graph);

    val p0 = Nd4j.create(10, 10).assign(2.0);
    val p1 = Nd4j.create(10, 10).assign(3.0);

    graph.associateArrayWithVariable(p0, graph.variableMap().get("Placeholder"));
    graph.associateArrayWithVariable(p1, graph.variableMap().get("Placeholder_1"));

    graph.var("Placeholder", p0);
    graph.var("Placeholder_1", p1);

    val res = graph.execAndEndResult();

    assertEquals(6.0, res.meanNumber().doubleValue(), 1e-5);
}
 
Example 13
Source File: BasicGraphExecutioner.java    From nd4j with Apache License 2.0
/**
 * This method executes the given graph and returns its results
 *
 * @param graph         the graph to execute
 * @param configuration the executor configuration (unused by this basic implementation)
 * @return the outputs of the graph as an array of INDArrays
 */
@Override
public INDArray[] executeGraph(SameDiff graph, ExecutorConfiguration configuration) {
    return new INDArray[]{graph.execAndEndResult()};
}