org.nd4j.linalg.primitives.Pair Java Examples

The following examples show how to use org.nd4j.linalg.primitives.Pair. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source file: SameDiffTests.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Checks that expandDims followed by squeeze on the same axis is an identity round trip,
 * for every axis 0..2 and every 3x4 test matrix from NDArrayCreationUtil.
 */
@Test
public void testExpandSqueezeChain() {

    val origShape = new long[]{3, 4};

    for (int i = 0; i < 3; i++) {
        for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345)) {
            INDArray inArr = p.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", inArr);
            SDVariable expand = sd.expandDims(in, i);
            // Not read afterwards, but the call adds the squeeze op to the graph;
            // execAndEndResult() presumably returns its output.
            SDVariable squeeze = sd.squeeze(expand, i);

            INDArray out = sd.execAndEndResult();

            String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond();

            // JUnit argument order is (message, expected, actual): the original input is the expected value.
            assertEquals(msg, inArr, out);  //expand -> squeeze: should be opposite ops
        }
    }
}
 
Example #2
Source file: NDArrayCreationUtil.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Get a list of 2d INDArrays, all with the specified shape. Each array is paired with a String
 * describing how to reproduce it (i.e., which function and which arguments), to aid debugging.
 * Each array has been obtained by applying an operation such as transpose, tensorAlongDimension,
 * etc. to an original array.
 *
 * @param rows number of rows of every returned matrix
 * @param cols number of columns of every returned matrix
 * @param seed random seed; set on {@code Nd4j.getRandom()} and passed on to each helper
 * @return a mutable list of (matrix, description) pairs
 */
public static List<Pair<INDArray, String>> getAllTestMatricesWithShape(long rows, long cols, long seed) {
    List<Pair<INDArray, String>> all = new ArrayList<>();
    Nd4j.getRandom().setSeed(seed);
    all.add(new Pair<>(Nd4j.linspace(1L, rows * cols, rows * cols).reshape(rows, cols),
                    "Nd4j.linspace(1,rows * cols,rows * cols).reshape(rows,cols)"));

    all.add(getTransposedMatrixWithShape(rows, cols, seed));

    all.addAll(getSubMatricesWithShape(rows, cols, seed));

    all.addAll(getTensorAlongDimensionMatricesWithShape(rows, cols, seed));

    all.add(getPermutedWithShape(rows, cols, seed));
    all.add(getReshapedWithShape(rows, cols, seed));

    return all;
}
 
Example #3
Source file: NDArrayCreationUtil.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Returns a rows x cols matrix obtained by reshaping, with the given ordering, a 3d linspace
 * array holding the values 1..rows*cols.
 * The 3d source shape is {rows/2, cols, 2} if rows is even, otherwise {rows, cols/2, 2} if cols
 * is even, otherwise {1, rows, cols}.
 *
 * @param ordering ordering used for both reshapes
 * @param rows     number of rows of the result
 * @param cols     number of columns of the result
 * @param seed     seed set on {@code Nd4j.getRandom()}; also recorded in the description
 * @return the reshaped matrix and a description that includes every argument needed to reproduce it
 */
public static Pair<INDArray, String> getReshapedWithShape(char ordering, long rows, long cols, long seed) {
    Nd4j.getRandom().setSeed(seed);
    long[] origShape = new long[3];
    if (rows % 2 == 0) {
        origShape[0] = rows / 2;
        origShape[1] = cols;
        origShape[2] = 2;
    } else if (cols % 2 == 0) {
        origShape[0] = rows;
        origShape[1] = cols / 2;
        origShape[2] = 2;
    } else {
        origShape[0] = 1;
        origShape[1] = rows;
        origShape[2] = cols;
    }

    int len = ArrayUtil.prod(origShape);
    INDArray orig = Nd4j.linspace(1, len, len).reshape(ordering, origShape);
    // The ordering changes the resulting array, so it must be part of the reproduction string.
    return new Pair<>(orig.reshape(ordering, rows, cols),
                    "getReshapedWithShape(" + ordering + "," + rows + "," + cols + "," + seed + ")");
}
 
Example #4
Source file: Nd4jTestsComparisonFortran.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Cross-checks add and subtract for every ordered pair of 3x5 test matrices using
 * CheckUtil.checkAdd / CheckUtil.checkSubtract with tolerances 1e-4 and 1e-6.
 */
@Test
public void testAddSubtractWithOpsCommonsMath() {
    List<Pair<INDArray, String>> lhsCases = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    List<Pair<INDArray, String>> rhsCases = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    int lhsCount = lhsCases.size();
    int rhsCount = rhsCases.size();
    for (int a = 0; a < lhsCount; a++) {
        Pair<INDArray, String> lhs = lhsCases.get(a);
        for (int b = 0; b < rhsCount; b++) {
            Pair<INDArray, String> rhs = rhsCases.get(b);
            String addMsg = getTestWithOpsErrorMsg(a, b, "add", lhs, rhs);
            String subMsg = getTestWithOpsErrorMsg(a, b, "sub", lhs, rhs);
            // assertTrue requires these to be true, so they are "passed" flags.
            boolean addPassed = CheckUtil.checkAdd(lhs.getFirst(), rhs.getFirst(), 1e-4, 1e-6);
            assertTrue(addMsg, addPassed);
            boolean subPassed = CheckUtil.checkSubtract(lhs.getFirst(), rhs.getFirst(), 1e-4, 1e-6);
            assertTrue(subMsg, subPassed);
        }
    }
}
 
Example #5
Source file: TestNdArrReadWriteTxt.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Round-trips every test array of the given rank and ordering through Nd4j.writeTxt /
 * Nd4j.readTxt and asserts that the read-back array equals the original.
 * Before writing, the first TAD along dimension 0 is set to 100000 and element 0 to 10001.1234,
 * so the text contains values outside the range where printing switches to scientific notation.
 *
 * @param rank     rank of the test arrays to generate
 * @param ordering ordering of the test arrays to generate
 */
public static void compareArrays(int rank, char ordering) {
    List<Pair<INDArray, String>> all = NDArrayCreationUtil.getTestMatricesWithVaryingShapes(rank,ordering);
    for (Pair<INDArray, String> currentPair : all) {
        INDArray origArray = currentPair.getFirst();
        //adding elements outside the bounds where print switches to scientific notation
        origArray.tensorAlongDimension(0,0).muli(0).addi(100000);
        origArray.putScalar(0,10001.1234);
        log.info("\nChecking shape ..." + currentPair.getSecond());
        log.info("\n"+ origArray.dup('c').toString());
        try {
            Nd4j.writeTxt(origArray, "someArr.txt");
            INDArray readBack = Nd4j.readTxt("someArr.txt");
            assertEquals("\nNot equal on shape " + ArrayUtils.toString(origArray.shape()), origArray, readBack);
        } finally {
            // Clean up even when the write, the read or the assertion fails, so no stray file is left
            // in the working directory; deleteIfExists also tolerates a file that was never written.
            try {
                Files.deleteIfExists(Paths.get("someArr.txt"));
            } catch (IOException e) {
                log.warn("Could not delete temporary file someArr.txt", e);
            }
        }
    }
}
 
Example #6
Source file: PLNetDyadRanker.java — from AILibs (GNU Affero General Public License v3.0), 6 votes
/**
 * Ranks the dyads of {@code xTest} by the utility the PL network predicts for each, highest first.
 * On the first call the network is created and initialised, with an input size equal to the
 * context length plus the alternative length of the first dyad in the label.
 *
 * @param xTest the instance whose dyads are ranked
 * @return the dyads in descending order of predicted utility
 */
@Override
public IRanking<IDyad> predict(final IDyadRankingInstance xTest) throws PredictionException, InterruptedException {
	if (this.plNet == null) {
		int dyadSize = (xTest.getLabel().get(0).getContext().length()) + (xTest.getLabel().get(0).getAlternative().length());
		this.plNet = this.createNetwork(dyadSize);
		this.plNet.init();
	}

	List<Pair<IDyad, Double>> scored = new ArrayList<>(xTest.getNumAttributes());
	for (IDyad dyad : xTest) {
		double utility = this.plNet.output(this.dyadToVector(dyad)).getDouble(0);
		scored.add(new Pair<>(dyad, utility));
	}

	// Descending utility == ascending order of the negated utility.
	scored.sort(Comparator.comparingDouble(p -> -p.getRight()));

	List<IDyad> ranking = new ArrayList<>(scored.size());
	for (Pair<IDyad, Double> entry : scored) {
		ranking.add(entry.getLeft());
	}
	return new Ranking<>(ranking);
}
 
Example #7
Source file: PipelineImageTransform.java — from DataVec (Apache License 2.0), 6 votes
/**
 * Runs the image through the configured pipeline of transforms.
 * If {@code shuffle} is set, the pipeline order is shuffled first. Each (transform, probability)
 * entry is applied when its probability is exactly 1.0, or otherwise when {@code rng.nextDouble()}
 * is below that probability. Every applied transform is recorded in {@code currentTransforms}.
 *
 * @param image  image to transform, null == end of stream
 * @param random random source passed to each transform, or null to use each transform's one-argument overload
 * @return the image after all applied transforms
 */
@Override
protected ImageWritable doTransform(ImageWritable image, Random random) {
    if (shuffle) {
        Collections.shuffle(imageTransforms);
    }

    currentTransforms.clear();

    ImageWritable result = image;
    for (Pair<ImageTransform, Double> step : imageTransforms) {
        ImageTransform transform = step.getFirst();
        Double probability = step.getSecond();
        // Probability 1.0 short-circuits, so it never consumes a value from rng.
        boolean apply = probability == 1.0 || rng.nextDouble() < probability;
        if (!apply) {
            continue;
        }
        currentTransforms.add(transform);
        if (random == null) {
            result = transform.transform(result);
        } else {
            result = transform.transform(result, random);
        }
    }

    return result;
}
 
Example #8
Source file: GradCheckMisc.java — from nd4j (Apache License 2.0), 6 votes
/**
 * For every permutation of a rank-3 array and every 3x4x5 test array, checks the forward
 * value of std(permute(x)) against INDArray.std, then gradient-checks the graph.
 */
@Test
public void testPermuteGradient() {
    int[] origShape = new int[]{3, 4, 5};
    int[][] permutations = {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}};

    for (int[] perm : permutations) {
        for (Pair<INDArray, String> source : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, origShape)) {
            INDArray scaled = source.getFirst().muli(100);

            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", scaled);
            SDVariable permuted = sd.f().permute(in, perm);
            //Using stdev here: mean/sum would backprop the same gradient for each input...
            sd.standardDeviation("out", permuted, true);

            INDArray actual = sd.execAndEndResult();
            INDArray expected = in.getArr().std(true, Integer.MAX_VALUE);
            assertEquals(expected, actual);

            String msg = "permute=" + Arrays.toString(perm) + ", source=" + source.getSecond();
            assertTrue(msg, GradCheckUtil.checkGradients(sd));
        }
    }
}
 
Example #9
Source file: FunctionalUtilsTest.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Checks FunctionalUtils.cogroup: every key seen on either side maps to
 * (values from the left list, values from the right list), with an empty list for a missing side.
 */
@Test
public void testCoGroup() {
    List<Pair<String,String>> leftMap = new ArrayList<>();
    leftMap.add(Pair.of("cat","adam"));
    leftMap.add(Pair.of("dog","adam"));

    List<Pair<String,String>> rightMap = new ArrayList<>();
    rightMap.add(Pair.of("fish","alex"));
    rightMap.add(Pair.of("cat","alice"));
    rightMap.add(Pair.of("dog","steve"));

    //[(fish,([],[alex])), (dog,([adam],[steve])), (cat,([adam],[alice]))]
    Map<String,Pair<List<String>,List<String>>> expected = new HashMap<>();
    expected.put("cat", Pair.of(Arrays.asList("adam"), Arrays.asList("alice")));
    expected.put("dog", Pair.of(Arrays.asList("adam"), Arrays.asList("steve")));
    expected.put("fish", Pair.of(Collections.<String>emptyList(), Arrays.asList("alex")));

    assertEquals(expected, FunctionalUtils.cogroup(leftMap, rightMap));
}
 
Example #10
Source file: SameDiffTests.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Builds a mean-squared-error graph: a per-example mean over dimension 1, then a mean over dimension 0.
 * Checks that the forward pass yields a single value, then runs the backward pass.
 * NOTE(review): the backward result is not asserted on; this only checks that execBackwards() does not throw.
 */
@Test
public void testMseBackwards() {
    final int nOut = 4;
    final int minibatch = 3;

    SameDiff sd = SameDiff.create();
    SDVariable input = sd.var("in", new long[]{-1, nOut});
    SDVariable label = sd.var("label", new long[]{-1, nOut});

    SDVariable diff = input.sub(label);
    SDVariable sqDiff = diff.mul(diff);
    SDVariable msePerEx = sd.mean("msePerEx", sqDiff, 1);
    sd.mean("loss", msePerEx, 0);

    INDArray inputArr = Nd4j.rand(minibatch, nOut);
    INDArray labelArr = Nd4j.rand(minibatch, nOut);
    sd.associateArrayWithVariable(inputArr, input);
    sd.associateArrayWithVariable(labelArr, label);

    assertEquals(1, sd.execAndEndResult().length());

    Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> backward = sd.execBackwards();
}
 
Example #11
Source file: JsonMappers.java — from DataVec (Apache License 2.0), 6 votes
/**
 * Registers classes for legacy JSON deserialization under custom names.
 * For each (name, class) pair, the name is mapped to the class's fully qualified name in the legacy
 * mapping of every entry of {@code REGISTERABLE_CUSTOM_CLASSES} that the class is assignable to.
 * Using this method directly should never be required (instead: use
 * {@link #registerLegacyCustomClassesForJSON(Class[])}), but it is available for non-standard circumstances.
 * NOTE(review): the accepted base types are whatever REGISTERABLE_CUSTOM_CLASSES contains; the list
 * previously documented here (Layer, GraphVertex, InputPreProcessor, ...) looks copied from DL4J — confirm.
 *
 * @param classes pairs of (legacy JSON name, class to register)
 * @throws IllegalArgumentException if a class is not a subtype of any registerable class
 */
public static void registerLegacyCustomClassesForJSON(List<Pair<String,Class>> classes){
    for (Pair<String,Class> entry : classes) {
        String legacyName = entry.getFirst();
        Class<?> customClass = entry.getRight();
        boolean registered = false;
        for (Class<?> baseClass : REGISTERABLE_CUSTOM_CLASSES) {
            if (!baseClass.isAssignableFrom(customClass)) {
                continue;
            }
            Map<String,String> legacyMapping = LegacyMappingHelper.legacyMappingForClass(baseClass);
            legacyMapping.put(legacyName, customClass.getName());
            registered = true;
        }

        if (!registered) {
            throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
                    customClass.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
        }
    }
}
 
Example #12
Source file: DeviceTADManager.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Discards all cached TAD buffers. The per-device cache list is replaced by a fresh list holding one
 * empty ConcurrentHashMap per device reported by the affinity manager; then the superclass purge runs.
 */
@Override
public void purgeBuffers() {
    log.info("Purging TAD buffers...");

    tadCache = new ArrayList<>();

    int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
    int device = 0;
    while (device < deviceCount) {
        log.info("Resetting device: [{}]", device);
        // Appended in order, so list index == device number.
        tadCache.add(new ConcurrentHashMap<TadDescriptor, Pair<DataBuffer, DataBuffer>>());
        device++;
    }

    super.purgeBuffers();
}
 
Example #13
Source file: DeepFMOutputLayer.java — from jstarcraft-rns (Apache License 2.0), 6 votes
/**
 * Backpropagates through this output layer.
 * The gradient and delta come from getGradientsAndDelta. The epsilon for the layer below is
 * (W * delta^T)^T, with the product written into an uninitialized ACTIVATION_GRAD workspace array
 * of shape [W.size(0), delta.size(0)] in 'f' order. backpropDropOutIfPresent is then applied to that epsilon.
 *
 * @param previous     not used by this implementation
 * @param workspaceMgr workspace manager used for the allocations
 * @return the layer gradient paired with the epsilon for the layer below
 */
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray previous, LayerWorkspaceMgr workspaceMgr) {
    assertInputSet(true);
    // Returns Gradient and delta^(this), not Gradient and epsilon^(this-1)
    INDArray preOut = preOutput2d(true, workspaceMgr);
    Pair<Gradient, INDArray> gradientAndDelta = getGradientsAndDelta(preOut, workspaceMgr);
    Gradient gradient = gradientAndDelta.getFirst();
    INDArray delta = gradientAndDelta.getSecond();

    INDArray weights = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr);
    long[] epsShape = new long[] { weights.size(0), delta.size(0) };
    INDArray epsBuffer = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsShape, 'f');
    INDArray epsilonNext = weights.mmuli(delta.transpose(), epsBuffer).transpose();

    // Normally we would clear weightNoiseParams here - but we want to reuse them
    // for forward + backward + score
    // So this is instead done in MultiLayerNetwork/CompGraph backprop methods
    return new Pair<>(gradient, backpropDropOutIfPresent(epsilonNext));
}
 
Example #14
Source file: DeepFMSumVertex.java — from jstarcraft-rns (Apache License 2.0), 6 votes
/**
 * Backward pass of the DeepFM sum vertex.
 * Re-runs the forward pass and computes, for every input, inputs[i] * epsilon / output
 * (the epsilon and output column vectors are broadcast across each row).
 * NOTE(review): if doForward simply sums its inputs, the true gradient would be epsilon broadcast to
 * each input; confirm this formula against doForward (see the TODO below).
 *
 * @param tbptt        unused here
 * @param workspaceMgr workspace manager used to allocate the ACTIVATION_GRAD arrays
 * @return (null gradient, per-input epsilons)
 * @throws IllegalStateException if the errors needed for backward are not set
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: errors not set");
    }
    // epsilons[index] => {batchSize, numberOfEmbeds}
    INDArray[] epsilons = new INDArray[inputs.length];
    // epsilon => {batchSize, 1}
    // inputs[index] => {batchSize, numberOfEmbeds}
    // TODO: how to derive epsilons[index] from inputs[index] and epsilon
    INDArray output = doForward(true, workspaceMgr);
    int index = 0;
    for (INDArray input : inputs) {
        INDArray inputEpsilon = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, input);
        inputEpsilon.muliColumnVector(epsilon).diviColumnVector(output);
        epsilons[index++] = inputEpsilon;
    }
    return new Pair<>(null, epsilons);
}
 
Example #15
Source file: DeepFMProductVertex.java — from jstarcraft-rns (Apache License 2.0), 6 votes
/**
 * Backward pass of the DeepFM product vertex (two inputs).
 * For each row r of epsilon, the epsilon of the left input gets right[r] scaled by epsilon[r], and
 * the epsilon of the right input gets left[r] scaled by epsilon[r]. This is computed as
 * (row^T x epsilonRow)^T.
 * NOTE(review): this matches the gradient of a row-wise dot product if that is what doForward
 * computes — confirm. The per-row loop could presumably be replaced by a single
 * muliColumnVector(epsilon) on a dup of the other input.
 *
 * @param tbptt        unused here
 * @param workspaceMgr workspace manager used to allocate the ACTIVATION_GRAD arrays
 * @return (null gradient, epsilons for inputs 0 and 1)
 * @throws IllegalStateException if the errors needed for backward are not set
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: errors not set");
    }

    // epsilons[index] => {batchSize, numberOfEmbeds}
    // The dups only provide correctly-shaped arrays; rows 0..epsilon.rows()-1 are overwritten below.
    INDArray[] epsilons = new INDArray[inputs.length];
    INDArray leftEpsilon = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, inputs[0]);
    INDArray rightEpsilon = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, inputs[1]);
    epsilons[0] = leftEpsilon;
    epsilons[1] = rightEpsilon;
    // epsilon => {batchSize, 1}
    // inputs[index] => {batchSize, numberOfEmbeds}
    // TODO: how to derive epsilons[index] from inputs[index] and epsilon
    INDArray left = inputs[0];
    INDArray right = inputs[1];
    int rowCount = epsilon.rows();
    for (int row = 0; row < rowCount; row++) {
        INDArray epsilonRow = epsilon.getRow(row);
        leftEpsilon.putRow(row, right.getRow(row).transpose().mmul(epsilonRow).transpose());
        rightEpsilon.putRow(row, left.getRow(row).transpose().mmul(epsilonRow).transpose());
    }
    return new Pair<>(null, epsilons);
}
 
Example #16
Source file: ElementWiseStrideTests.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Diagnostic test over 2d-6d test arrays. It prints every case where the array's own element-wise
 * stride is -1, yet Shape.newShapeNoCopy can reshape it to [1, length] with an element-wise stride
 * other than -1. Nothing is asserted: the author left the failing assertion disabled, so such cases
 * are only reported.
 */
@Test
public void testEWS1() throws Exception {
    List<Pair<INDArray,String>> list = NDArrayCreationUtil.getAllTestMatricesWithShape(4,5,12345);
    list.addAll(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345,4,5,6));
    list.addAll(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345,4,5,6,7));
    list.addAll(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345,4,5,6,7,8));
    list.addAll(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345,4,5,6,7,8,9));

    for (Pair<INDArray,String> p : list) {
        INDArray arr = p.getFirst();
        int ewsBefore = Shape.elementWiseStride(arr.shapeInfo());
        INDArray reshapeAttempt = Shape.newShapeNoCopy(arr, new int[]{1, arr.length()}, Nd4j.order() == 'f');

        boolean reshapeGainedEws = reshapeAttempt != null && ewsBefore == -1 && reshapeAttempt.elementWiseStride() != -1;
        if (!reshapeGainedEws) {
            continue;
        }
        System.out.println("NDArrayCreationUtil." + p.getSecond());
        System.out.println("ews before: " + ewsBefore);
        System.out.println(arr.shapeInfoToString());
        System.out.println("ews returned by elementWiseStride(): " + arr.elementWiseStride());
        System.out.println("ews returned by reshape(): " + reshapeAttempt.elementWiseStride());
        System.out.println();
    }
}
 
Example #17
Source file: NDArrayTestsFortran.java — from nd4j (Apache License 2.0), 6 votes
/**
 * Checks dup(), dup('c') and dup('f') on every 4x5 test matrix. Every copy must equal the input,
 * dup() must use Nd4j's default order, and dup('c') / dup('f') must use the requested order.
 */
@Test
public void testDupAndDupWithOrder() {
    List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123);
    int count = 0;
    for (Pair<INDArray, String> pair : testInputs) {
        String msg = pair.getSecond();
        INDArray in = pair.getFirst();
        System.out.println("Count " + count);
        INDArray dup = in.dup();
        INDArray dupc = in.dup('c');
        INDArray dupf = in.dup('f');

        // JUnit argument order is (message, expected, actual), so failures report the orderings correctly.
        assertEquals(msg, in, dup);
        assertEquals(msg, (char) Nd4j.order(), dup.ordering());
        assertEquals(msg, 'c', dupc.ordering());
        assertEquals(msg, 'f', dupf.ordering());
        assertEquals(msg, in, dupc);
        assertEquals(msg, in, dupf);
        count++;
    }
}
 
Example #18
Source file: NDArrayCreationUtil.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Creates 3d test arrays as tensor-along-dimension (TAD) views of 4d linspace arrays.
 * Each 4d array has the requested dimensions plus one extra dimension of size 3, inserted at a
 * different position. This is not an exhaustive list of the 3d arrays obtainable from 4d via TAD.
 *
 * @param seed  seed set on {@code Nd4j.getRandom()}; also recorded in the descriptions
 * @param shape requested 3d shape; only shape[0], shape[1] and shape[2] are read
 * @return list of (3d TAD, description) pairs, described as baseMsg.get(0) .. baseMsg.get(4)
 */
public static List<Pair<INDArray, String>> get3dTensorAlongDimensionWithShape(long seed, long... shape) {
    List<Pair<INDArray, String>> list = new ArrayList<>();
    String baseMsg = "get3dTensorAlongDimensionWithShape(" + seed + "," + Arrays.toString(shape) + ")";

    Nd4j.getRandom().setSeed(seed);

    // Extra dimension last. Two separately created (identical) sources, so the two TADs do not share a buffer.
    long[] extraLast = new long[]{shape[0], shape[1], shape[2], 3};
    int extraLastLen = ArrayUtil.prod(extraLast);
    INDArray sourceA = Nd4j.linspace(1, extraLastLen, extraLastLen).reshape(extraLast);
    INDArray tadA = sourceA.javaTensorAlongDimension(0, 0, 1, 2);
    INDArray sourceB = Nd4j.linspace(1, extraLastLen, extraLastLen).reshape(extraLast);
    INDArray tadB = sourceB.javaTensorAlongDimension(1, 0, 1, 2);
    list.add(new Pair<>(tadA, baseMsg + ".get(0)"));
    list.add(new Pair<>(tadB, baseMsg + ".get(1)"));

    // Extra dimension first.
    long[] extraFirst = {3, shape[0], shape[1], shape[2]};
    int extraFirstLen = ArrayUtil.prod(extraFirst);
    INDArray sourceC = Nd4j.linspace(1, extraFirstLen, extraFirstLen).reshape(extraFirst);
    INDArray tadC = sourceC.javaTensorAlongDimension(1, 1, 2, 3);
    list.add(new Pair<>(tadC, baseMsg + ".get(2)"));

    // Extra dimension third.
    long[] extraThird = {shape[0], shape[1], 3, shape[2]};
    int extraThirdLen = ArrayUtil.prod(extraThird);
    INDArray sourceD = Nd4j.linspace(1, extraThirdLen, extraThirdLen).reshape(extraThird);
    INDArray tadD = sourceD.javaTensorAlongDimension(1, 1, 3, 0);
    list.add(new Pair<>(tadD, baseMsg + ".get(3)"));

    // Extra dimension second.
    long[] extraSecond = {shape[0], 3, shape[1], shape[2]};
    int extraSecondLen = ArrayUtil.prod(extraSecond);
    INDArray sourceE = Nd4j.linspace(1, extraSecondLen, extraSecondLen).reshape(extraSecond);
    INDArray tadE = sourceE.javaTensorAlongDimension(1, 2, 0, 3);
    list.add(new Pair<>(tadE, baseMsg + ".get(4)"));

    return list;
}
 
Example #19
Source file: FunctionalUtilsTest.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Groups 10 keys x 5 values with FunctionalUtils.groupByKey and checks that there are 10 groups
 * and that key 0 holds 5 values.
 */
@Test
public void testGroupBy() {
    final int keyCount = 10;
    final int valuesPerKey = 5;
    List<Pair<Integer,Integer>> list = new ArrayList<>();
    for (int key = 0; key < keyCount; key++) {
        for (int value = 0; value < valuesPerKey; value++) {
            list.add(Pair.of(key, value));
        }
    }

    Map<Integer, List<Integer>> grouped = FunctionalUtils.groupByKey(list);
    assertEquals(10, grouped.keySet().size());
    assertEquals(5, grouped.get(0).size());
}
 
Example #20
Source file: NDArrayCreationUtil.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Returns a single 4d test array of the given shape. It is created by reshaping a random 2d array of
 * shape [shape[0] * shape[2], shape[1] * shape[3]], drawn after seeding Nd4j's RNG with {@code seed}.
 *
 * @param seed  RNG seed; also recorded in the description
 * @param shape the 4d target shape
 * @return a singleton list with the (array, description) pair
 */
public static List<Pair<INDArray, String>> get4dReshapedWithShape(int seed, int... shape) {
    Nd4j.getRandom().setSeed(seed);
    int[] flatShape = {shape[0] * shape[2], shape[1] * shape[3]};
    INDArray flat = Nd4j.rand(flatShape);
    INDArray reshaped4d = flat.reshape(ArrayUtil.toLongArray(shape));
    String description = "get4dReshapedWithShape(" + seed + "," + Arrays.toString(shape) + ").get(0)";
    return Collections.singletonList(new Pair<>(reshaped4d, description));
}
 
Example #21
Source file: NDArrayCreationUtil.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Returns a single 3d test array of the given shape. It is created by reshaping a 2d linspace array
 * of shape [shape[0] * shape[2], shape[1]].
 *
 * @param seed  seed set on {@code Nd4j.getRandom()}; also recorded in the description
 * @param shape the 3d target shape
 * @return a singleton list with the (array, description) pair
 */
public static List<Pair<INDArray, String>> get3dReshapedWithShape(long seed, long... shape) {
    Nd4j.getRandom().setSeed(seed);
    long[] flatShape = {shape[0] * shape[2], shape[1]};
    int flatLength = ArrayUtil.prod(flatShape);
    INDArray flat = Nd4j.linspace(1, flatLength, flatLength).reshape(flatShape);
    INDArray reshaped3d = flat.reshape(shape);
    String description = "get3dReshapedWithShape(" + seed + "," + Arrays.toString(shape) + ").get(0)";
    return Collections.singletonList(new Pair<>(reshaped3d, description));
}
 
Example #22
Source file: ProtectedCudaShapeInfoProvider.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Returns the cached shape information for this descriptor on the current device, creating and
 * caching it on a miss. The offset is always stored as 0 (see the comment below).
 * A lookup that finds the entry before taking the lock increments {@code cacheHit}.
 * A miss does the following: creates the buffer via the superclass, marks it constant, moves it to
 * constant space under the IMMEDIATE memory model, persists it through {@code protector}, and
 * updates {@code bytes} and {@code cacheMiss}.
 */
@Override
public Pair<DataBuffer, long[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    // We enforce offset to 0 in shapeBuffer, since we need it for cache efficiency + we don't actually use offset value @ native side
    offset = 0;

    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order);

    // Fast path: already cached, no lock taken.
    if (protector.containsDataBuffer(deviceId, descriptor)) {
        cacheHit.incrementAndGet();
        return protector.getDataBuffer(deviceId, descriptor);
    }

    synchronized (this) {
        // Re-check under the lock: another thread may have created the entry in the meantime.
        // Such a late hit is counted neither as a hit nor as a miss.
        if (protector.containsDataBuffer(deviceId, descriptor)) {
            return protector.getDataBuffer(deviceId, descriptor);
        }

        Pair<DataBuffer, long[]> buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
        buffer.getFirst().setConstant(true);

        if (CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) {
            Nd4j.getConstantHandler().moveToConstantSpace(buffer.getFirst());
        }

        protector.persistDataBuffer(deviceId, descriptor, buffer);

        // NOTE(review): accounts 4 bytes per element, twice; confirm against the buffer's actual element size.
        bytes.addAndGet(buffer.getFirst().length() * 4 * 2);

        cacheMiss.incrementAndGet();
        return buffer;
    }
}
 
Example #23
Source file: FunctionalUtils.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Flattens a map into a list of {@link Pair}s, one (key, value) pair per entry, in the map's
 * iteration order.
 *
 * @param map the map to collapse
 * @param <K> the key type
 * @param <V> the value type
 * @return a new mutable list holding one pair per map entry
 */
public static <K,V> List<Pair<K,V>> mapToPair(Map<K,V> map) {
    List<Pair<K,V>> pairs = new ArrayList<>(map.size());
    map.forEach((key, value) -> pairs.add(Pair.of(key, value)));
    return pairs;
}
 
Example #24
Source file: FilesAsBytesFunction.java — from DataVec (Apache License 2.0), 5 votes
/**
 * Converts a (name, stream) pair into (Text name, BytesWritable contents) by reading the whole
 * stream into memory. This method does not close the stream itself.
 *
 * @param in the file name and an open stream of its contents
 * @return the name as Text and the stream's bytes as BytesWritable
 * @throws IllegalStateException wrapping any IOException raised while reading
 */
@Override
public Pair<Text, BytesWritable> apply(Pair<String, InputStream> in) {
    try {
        // The name is converted before the stream is read, matching left-to-right argument evaluation.
        Text name = new Text(in.getFirst());
        byte[] contents = IOUtils.toByteArray(in.getSecond());
        return Pair.of(name, new BytesWritable(contents));
    } catch (IOException e) {
        throw new IllegalStateException(e);
    }
}
 
Example #25
Source file: BaseShapeInfoProvider.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Creates the shape-information buffer for a non-view array with the given shape and order.
 * Strides come from {@code Nd4j.getStrides(shape, order)}, the offset is 0 and the
 * element-wise stride is 1.
 *
 * @param shape the array shape
 * @param order the array ordering
 * @return the result of the five-argument createShapeInformation overload
 */
@Override
public Pair<DataBuffer, long[]> createShapeInformation(int[] shape, char order) {
    // Not a view, so the element-wise stride is always 1.
    return createShapeInformation(shape, Nd4j.getStrides(shape, order), 0, 1, order);
}
 
Example #26
Source file: LossHinge.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Computes the score and the gradient together by delegating to computeScore and then
 * computeGradient (score first).
 *
 * @return (score, gradient)
 */
@Override
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels,
                INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
    //TODO: probably a more efficient way to do this...
    double score = computeScore(labels, preOutput, activationFn, mask, average);
    INDArray gradient = computeGradient(labels, preOutput, activationFn, mask);
    return new Pair<>(score, gradient);
}
 
Example #27
Source file: PipelineImageTransform.java — from DataVec (Apache License 2.0), 5 votes
/**
 * Adds the given transform to this pipeline with the given invocation probability.
 * Probabilities below 0.0 are clamped to 0.0, and probabilities above 1.0 are clamped to 1.0.
 *
 * @param transform   the transform to add; must not be null
 * @param probability invocation probability in [0, 1]; must not be null
 * @return this builder
 * @throws NullPointerException if {@code transform} or {@code probability} is null
 */
public Builder addImageTransform(@NonNull ImageTransform transform, Double probability) {
    if (probability == null) {
        // Fail with a descriptive message rather than an anonymous NPE from auto-unboxing below.
        throw new NullPointerException("probability must not be null");
    }
    if (probability < 0.0) {
        probability = 0.0;
    }
    if (probability > 1.0) {
        probability = 1.0;
    }

    imageTransforms.add(Pair.makePair(transform, probability));
    return this;
}
 
Example #28
Source file: ActivationRReLU.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Backpropagates through the randomized ReLU. dL/dz is epsilon where the input is positive, and
 * alpha * epsilon where the input is {@code <= 0}. This uses the {@code alpha} field, which is
 * presumably set by the training-mode forward pass; confirm it is non-null and has the shape of {@code in}.
 *
 * The condition is evaluated on a copy of {@code in}. BooleanIndexing.replaceWhere(to, from, cond)
 * tests {@code cond} against the elements of {@code to}. Testing it against an all-ones array never
 * matches {@code <= 0}, which silently returned plain epsilon and ignored alpha.
 *
 * @param in      pre-activation input
 * @param epsilon gradient of the loss with respect to this activation's output
 * @return (dL/dz, null) — no parameter gradients
 */
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {

    INDArray dLdz = in.dup();
    // Order matters: first set the positive entries to 1 ...
    BooleanIndexing.replaceWhere(dLdz, Nd4j.ones(in.shape()), Conditions.greaterThan(0.0));
    // ... then replace the remaining non-positive entries (still holding the raw input) with alpha.
    // The reverse order would overwrite alpha values that are > 0 with 1.
    BooleanIndexing.replaceWhere(dLdz, alpha, Conditions.lessThanOrEqual(0.0));
    dLdz.muli(epsilon);

    return new Pair<>(dLdz, null);
}
 
Example #29
Source file: BaseNDArray.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Restores this array's state from an ObjectInputStream: the shape-information buffer is read
 * first, then the data buffer.
 *
 * NOTE(review): the shape buffer is sized from rank(), which is evaluated before shapeInformation is
 * reassigned here — confirm rank() returns the right value at this point of deserialization.
 * NOTE(review): the shape buffer is created from an int[] but later converted with asLong(); confirm
 * the buffer type matches what the corresponding write method emits.
 *
 * @param s the stream to read from
 */
protected void read(ObjectInputStream s) {
    // allocate a shape-info buffer sized for the current rank, then fill it from the stream
    shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())], 0);
    shapeInformation.read(s);
    // install the shape info together with its long[] copy
    setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong()));
    // presumably length() is derived from the shape info set just above — confirm
    data = Nd4j.createBuffer(length(), false);
    data().read(s);
}
 
Example #30
Source file: LossBinaryXENT.java — from nd4j (Apache License 2.0), 5 votes
/**
 * Computes the score and the gradient together by delegating to computeScore and then
 * computeGradient (score first).
 *
 * @return (score, gradient)
 */
@Override
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
                INDArray mask, boolean average) {
    //TODO: probably a more efficient way to do this...
    double score = computeScore(labels, preOutput, activationFn, mask, average);
    INDArray gradient = computeGradient(labels, preOutput, activationFn, mask);
    return new Pair<>(score, gradient);
}