org.nd4j.linalg.indexing.INDArrayIndex Java Examples

The following examples show how to use org.nd4j.linalg.indexing.INDArrayIndex. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Writes a 2-element vector of ones into zeros[1, 1..2, 1] of a 3x3x3 array and compares the
 * printed result against the expected rendering, accepting either the comma-decimal or the
 * dot-decimal form.
 */
@Test
@Ignore
public void testIndexPointInterval() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] region = {NDArrayIndex.point(1), NDArrayIndex.interval(1, 2, true), NDArrayIndex.point(1)};
    target.put(region, Nd4j.ones(1, 2));

    String commaDecimal = "[[[0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,1,00,0,00]\n" + " [0,00,1,00,0,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]]";

    String dotDecimal = "[[[0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,1.00,0.00]\n" + " [0.00,1.00,0.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]]";

    String actual = target.toString();
    boolean matchesEither = actual.equals(dotDecimal) || actual.equals(commaDecimal);
    if (!matchesEither)
        assertEquals(dotDecimal, actual);
}
 
Example #2
Source File: Nd4jVertex.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the component matrices of the first input key matrix, one after another, into the
 * output key matrix, then sets every value of the output value matrix to 0.
 * <p>
 * When {@code orientation} is true, each component occupies the next block of columns (all rows).
 * Otherwise it occupies the next block of rows (all columns).
 */
@Override
public void doForward() {
    GlobalMatrix inputKeyMatrix = GlobalMatrix.class.cast(inputKeyValues[0].getKey());
    // Not read below; the cast is kept because it still checks the runtime type of the value.
    GlobalMatrix inputValueMatrix = GlobalMatrix.class.cast(inputKeyValues[0].getValue());
    Nd4jMatrix outputKeyMatrix = Nd4jMatrix.class.cast(outputKeyValue.getKey());
    Nd4jMatrix outputValueMatrix = Nd4jMatrix.class.cast(outputKeyValue.getValue());

    INDArray concatenated = outputKeyMatrix.getArray();
    int offset = 0;
    for (MathMatrix component : inputKeyMatrix.getComponentMatrixes()) {
        INDArray piece = Nd4jMatrix.class.cast(component).getArray();
        if (!orientation) {
            // Stack vertically: this piece fills rows [offset, offset + rows).
            INDArrayIndex[] rowBlock = { NDArrayIndex.interval(offset, offset + piece.rows()), NDArrayIndex.all() };
            concatenated.put(rowBlock, piece);
            offset += piece.rows();
        } else {
            // Place side by side: this piece fills columns [offset, offset + columns).
            INDArrayIndex[] columnBlock = { NDArrayIndex.all(), NDArrayIndex.interval(offset, offset + piece.columns()) };
            concatenated.put(columnBlock, piece);
            offset += piece.columns();
        }
    }

    outputValueMatrix.setValues(0F);
}
 
Example #3
Source File: BaseNDArrayFactory.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Returns a row vector holding the elements of every ndarray in {@code matrices}, in iteration
 * order. Its length equals the sum of the lengths of the inputs.
 *
 * @param matrices the ndarrays to build a flattened representation of
 * @return the flattened ndarray, of shape [1, total length]
 */
@Override
public INDArray toFlattened(Collection<INDArray> matrices) {
    // Accumulate in long. INDArray.length() values are summed across all inputs, and an int
    // accumulator would silently wrap: compound assignment narrows without any check.
    long length = 0;
    for (INDArray m : matrices)
        length += m.length();
    INDArray ret = Nd4j.create(1, length);
    long linearIndex = 0;
    for (INDArray d : matrices) {
        // Each input fills the next d.length() consecutive slots.
        ret.put(new INDArrayIndex[] {NDArrayIndex.interval(linearIndex, linearIndex + d.length())}, d);
        linearIndex += d.length();
    }

    return ret;
}
 
Example #4
Source File: DataSetUtil.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Merges per-array 2d masks into one mask of shape {@code outShape}, stacking them along
 * dimension 0 in array order.
 * <p>
 * The i-th entry owns {@code arrays[i].size(0)} consecutive rows. A {@code null} mask means
 * "all present", so those rows keep the initial value 1.
 *
 * @param outShape shape of the merged mask
 * @param arrays   the arrays whose masks are merged; only {@code size(0)} of each is read
 * @param masks    the per-array masks, where individual entries may be {@code null}
 * @return the merged mask
 */
public static INDArray mergePerOutputMasks2d(long[] outShape, INDArray[] arrays, INDArray[] masks) {
    val numExamplesPerArr = new long[arrays.length];
    for (int i = 0; i < numExamplesPerArr.length; i++) {
        numExamplesPerArr[i] = arrays[i].size(0);
    }

    INDArray outMask = Nd4j.ones(outShape); //Initialize to 'all present' (1s)

    long rowsSoFar = 0;
    for (int i = 0; i < masks.length; i++) {
        long thisRows = numExamplesPerArr[i]; //Mask itself may be null -> all present, but may include multiple examples
        if (masks[i] != null) {
            outMask.put(new INDArrayIndex[] {NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows),
                            NDArrayIndex.all()}, masks[i]);
        }
        // Advance even for a null mask: its rows are still occupied (left as 1s). Skipping the
        // increment would shift every later mask onto the wrong rows.
        rowsSoFar += thisRows;
    }
    return outMask;
}
 
Example #5
Source File: MtcnnUtil.java    From mtcnn-java with Apache License 2.0 6 votes vote down vote up
/**
 * Convert the bbox into square.
 *
 * Columns 0..3 of each row are read as (x1, y1, x2, y2). Each box is resized around its own
 * centre so that both sides equal max(width, height). {@code bbox} is modified in place and
 * also returned. Columns beyond index 3, if any, are not touched.
 *
 * original code:
 *  - https://github.com/kpzhang93/MTCNN_face_detection_alignment/blob/master/code/codes/MTCNNv2/rerec.m
 *  - https://github.com/davidsandberg/facenet/blob/master/src/align/detect_face.py#L646
 *
 * @param bbox      boxes, one per row
 * @param withFloor if true, columns 0..3 are floored after squaring
 * @return the same {@code bbox} instance, now holding squared boxes
 */
public static INDArray rerec(INDArray bbox, boolean withFloor) {
	// Current side lengths and the side length of the resulting square.
	INDArray height = bbox.get(all(), point(3)).sub(bbox.get(all(), point(1)));
	INDArray width = bbox.get(all(), point(2)).sub(bbox.get(all(), point(0)));
	INDArray side = Transforms.max(width, height);

	// Move the top-left corner so the square keeps the original box's centre.
	INDArray newX1 = bbox.get(all(), point(0)).add(width.mul(0.5)).sub(side.mul(0.5));
	bbox.put(new INDArrayIndex[] { all(), point(0) }, newX1);
	INDArray newY1 = bbox.get(all(), point(1)).add(height.mul(0.5)).sub(side.mul(0.5));
	bbox.put(new INDArrayIndex[] { all(), point(1) }, newY1);

	// Bottom-right corner = top-left corner + side length, for x and y alike.
	INDArray sidePair = Nd4j.repeat(side, 2).transpose();
	INDArray bottomRight = bbox.get(all(), interval(0, 2)).add(sidePair);
	bbox.put(new INDArrayIndex[] { all(), interval(2, 4) }, bottomRight);

	if (withFloor) {
		INDArray floored = Transforms.floor(bbox.get(all(), interval(0, 4)));
		bbox.put(new INDArrayIndex[] { all(), interval(0, 4) }, floored);
	}

	return bbox;
}
 
Example #6
Source File: NDArrayIndexResolveTests.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Resolving a single point index against a vector's shape either leaves it unchanged or
 * expands it into the pair (point 0, point 1).
 */
@Test
public void testResolvePointVector() {
    INDArray vector = Nd4j.linspace(1, 4, 4);
    INDArrayIndex[] requested = {NDArrayIndex.point(1)};
    INDArrayIndex[] resolved = NDArrayIndex.resolve(vector.shape(), requested);

    if (resolved.length != requested.length) {
        // Expanded form: a leading point(0) followed by the original point(1).
        assertEquals(2, resolved.length);
        assertTrue(resolved[0] instanceof PointIndex);
        assertEquals(0, resolved[0].current());
        assertTrue(resolved[1] instanceof PointIndex);
        assertEquals(1, resolved[1].current());
        return;
    }
    assertArrayEquals(requested, resolved);
}
 
Example #7
Source File: CnnSentenceDataSetIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Create the features based on the data of this batch.
 * <p>
 * The vector of token {@code j} of sentence {@code i} is written to
 * {@code features[i, 0, j, :]}. At most {@link #getMaxSentenceLength()} tokens per sentence are used.
 *
 * @param data Batch data
 * @param maxTokenSizeBatch Maximum token size in this batch
 * @return INDArray containing the features
 */
protected INDArray createFeatures(List<Datum> data, int maxTokenSizeBatch) {
  final int tokenLimit = getMaxSentenceLength();
  final INDArray features = Nd4j.create(getFeatureShape(maxTokenSizeBatch, data.size()));

  for (int sentenceIdx = 0; sentenceIdx < data.size(); sentenceIdx++) {
    final List<String> tokens = data.get(sentenceIdx).getTokens();
    final int usableTokens = Math.min(tokens.size(), tokenLimit);
    for (int tokenIdx = 0; tokenIdx < usableTokens; tokenIdx++) {
      INDArray tokenVector = getVector(tokens.get(tokenIdx));
      INDArrayIndex[] target = {point(sentenceIdx), point(0), point(tokenIdx), all()};
      features.put(target, tokenVector);
    }
  }
  return features;
}
 
Example #8
Source File: CnnSentenceDataSetIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Create the features based on the data of this batch.
 * <p>
 * The vector of token {@code j} of sentence {@code i} is written to
 * {@code features[i, 0, j, :]}. At most {@link #getMaxSentenceLength()} tokens per sentence are used.
 *
 * @param data Batch data
 * @param maxTokenSizeBatch Maximum token size in this batch
 * @return INDArray containing the features
 */
protected INDArray createFeatures(List<Datum> data, int maxTokenSizeBatch) {
  final int tokenLimit = getMaxSentenceLength();
  final INDArray features = Nd4j.create(getFeatureShape(maxTokenSizeBatch, data.size()));

  for (int sentenceIdx = 0; sentenceIdx < data.size(); sentenceIdx++) {
    final List<String> tokens = data.get(sentenceIdx).getTokens();
    final int usableTokens = Math.min(tokens.size(), tokenLimit);
    for (int tokenIdx = 0; tokenIdx < usableTokens; tokenIdx++) {
      INDArray tokenVector = getVector(tokens.get(tokenIdx));
      INDArrayIndex[] target = {point(sentenceIdx), point(0), point(tokenIdx), all()};
      features.put(target, tokenVector);
    }
  }
  return features;
}
 
Example #9
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Writes a 3-element vector of ones into zeros[1, :, 1] of a 3x3x3 array and compares the
 * printed result against the expected rendering, accepting either the comma-decimal or the
 * dot-decimal form.
 */
@Test
@Ignore
public void testIndexPointAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] region = {NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(1)};
    target.put(region, Nd4j.ones(1, 3));

    String commaDecimal = "[[[0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]\n" + "  [[0,00,1,00,0,00]\n"
                    + " [0,00,1,00,0,00]\n" + " [0,00,1,00,0,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]]";

    String dotDecimal = "[[[0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]\n" + "  [[0.00,1.00,0.00]\n"
                    + " [0.00,1.00,0.00]\n" + " [0.00,1.00,0.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]]";

    String actual = target.toString();
    boolean matchesEither = actual.equals(commaDecimal) || actual.equals(dotDecimal);
    if (!matchesEither)
        assertEquals(dotDecimal, actual);
}
 
Example #10
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Writes 12 ones (as a 2x6 array) into zeros[0..1, :, 1..2] of a 3x3x3 array and compares
 * the printed result against the expected rendering, accepting either the comma-decimal or the
 * dot-decimal form.
 */
@Test
@Ignore
public void testIndexIntervalAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] region = {NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true)};
    target.put(region, Nd4j.ones(2, 6));

    String commaDecimal = "[[[0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]]\n" + "  [[0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]]";

    String dotDecimal = "[[[0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]]\n" + "  [[0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]]";

    String actual = target.toString();
    boolean matchesEither = actual.equals(commaDecimal) || actual.equals(dotDecimal);
    if (!matchesEither)
        assertEquals(dotDecimal, actual);
}
 
Example #11
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Writes a 3x2 array of ones into zeros[1, :, 1..2] of a 3x3x3 array and compares the printed
 * result against the expected rendering, accepting either the comma-decimal or the dot-decimal
 * form.
 */
@Test
@Ignore
public void testIndexPointIntervalAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] region = {NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true)};
    target.put(region, Nd4j.ones(3, 2));

    String commaDecimal = "[[[0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]\n" + "  [[0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]]";

    String dotDecimal = "[[[0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]\n" + "  [[0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]]";

    String actual = target.toString();
    boolean matchesEither = actual.equals(commaDecimal) || actual.equals(dotDecimal);
    if (!matchesEither)
        assertEquals(dotDecimal, actual);
}
 
Example #12
Source File: Shape.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Converts each int index into an {@link NDArrayIndex} constructed from that value.
 *
 * @param indices the indices to convert
 * @return one index object per input value, in the same order
 */
public static INDArrayIndex[] toIndexes(int[] indices) {
    INDArrayIndex[] converted = new INDArrayIndex[indices.length];
    int slot = 0;
    for (int value : indices) {
        converted[slot++] = new NDArrayIndex(value);
    }
    return converted;
}
 
Example #13
Source File: DeepFMParameter.java    From jstarcraft-rns with Apache License 2.0 5 votes vote down vote up
/**
 * Creates the weight and bias parameters from row 0 of {@code view}.
 * <p>
 * Entries [0, numberOfFeatures * nOut) become the weights, and the following nOut entries
 * become the bias. Both keys are registered as variables on {@code configuration}.
 *
 * @return an insertion-ordered, synchronized map from parameter key to parameter array
 */
@Override
public Map<String, INDArray> init(NeuralNetConfiguration configuration, INDArray view, boolean initialize) {
    FeedForwardLayer layer = (FeedForwardLayer) configuration.getLayer();
    long outputs = layer.getNOut();
    long weightCount = numberOfFeatures * outputs;

    INDArray weightSlice = view.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, weightCount));
    INDArray biasSlice = view.get(NDArrayIndex.point(0), NDArrayIndex.interval(weightCount, weightCount + outputs));

    Map<String, INDArray> parameters = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
    parameters.put(WEIGHT_KEY, this.createWeightMatrix(configuration, weightSlice, initialize));
    parameters.put(BIAS_KEY, createBias(configuration, biasSlice, initialize));

    configuration.addVariable(WEIGHT_KEY);
    configuration.addVariable(BIAS_KEY);
    return parameters;
}
 
Example #14
Source File: ComplexNDArrayUtil.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Pads an ndarray with zeros.
 * <p>
 * Returns {@code nd} itself when no padding is required, meaning the shapes already match or
 * {@code nd} has at least as many elements as {@code targetShape}. Otherwise, a new array of
 * {@code targetShape} is created with {@code Nd4j.createComplex}, and {@code nd} is copied into
 * the region given by {@code NDArrayIndex.createCoveringShape(nd.shape())}.
 *
 * @param nd          the ndarray to pad
 * @param targetShape the new shape
 * @return the padded ndarray, or {@code nd} when no padding is required
 */
public static IComplexNDArray padWithZeros(IComplexNDArray nd, long[] targetShape) {
    boolean noPaddingRequired = Arrays.equals(nd.shape(), targetShape)
                    || ArrayUtil.prod(nd.shape()) >= ArrayUtil.prod(targetShape);
    if (noPaddingRequired)
        return nd;

    IComplexNDArray padded = Nd4j.createComplex(targetShape);
    padded.put(NDArrayIndex.createCoveringShape(nd.shape()), nd);
    return padded;
}
 
Example #15
Source File: BaseSparseNDArray.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Converts {@code indices} into one {@link SpecifiedIndex} per slice, built from that slice's
 * values as ints, and delegates to {@code put(INDArrayIndex[], INDArray)}.
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    // NOTE(review): the number of indices comes from indices.rank(), not from the number of
    // slices. For a matrix of indices this is always 2. Confirm this is intended.
    int count = indices.rank();
    INDArrayIndex[] specified = new INDArrayIndex[count];
    for (int dim = 0; dim < count; dim++) {
        INDArray slice = indices.slice(dim).dup();
        specified[dim] = new SpecifiedIndex(slice.data().asInt());
    }
    return put(specified, element);
}
 
Example #16
Source File: NDArrayList.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Copies the run of elements starting at {@code index} (row 0 of {@code container}) and writes
 * it back one position to the right.
 */
private void moveForward(int index) {
    // NOTE(review): the run is size - index - 1 long, i.e. [index, size - 1). The element at
    // size - 1 is therefore not moved and gets overwritten. An insert-style shift usually moves
    // size - index elements. Confirm against the callers.
    int numMoved = size - index - 1;
    INDArray run = container.get(NDArrayIndex.point(0), NDArrayIndex.interval(index, index + numMoved)).dup();
    INDArrayIndex[] shiftedRange = {NDArrayIndex.point(0), NDArrayIndex.interval(index + 1, index + 1 + run.length())};
    container.put(shiftedRange, run);
}
 
Example #17
Source File: ComplexNDArrayUtil.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Center an array.
 * <p>
 * Extracts the central region of {@code arr} with the requested shape. Along each dimension,
 * the start offset is floor((arr.shape - shape) / 2). If {@code arr} has fewer elements than
 * {@code shape} describes, {@code arr} is returned unchanged. Entries of {@code shape} below 1
 * are treated as 1. The caller's {@code shape} array is not modified.
 *
 * @param arr   the array to center
 * @param shape the shape of the central region to extract
 * @return the center portion of the array based on the specified shape
 */
public static IComplexNDArray center(IComplexNDArray arr, long[] shape) {
    if (arr.length() < ArrayUtil.prod(shape))
        return arr;

    // Clamp a copy so the caller's shape array is not mutated as a side effect.
    long[] clamped = shape.clone();
    for (int i = 0; i < clamped.length; i++)
        if (clamped[i] < 1)
            clamped[i] = 1;

    INDArray shapeMatrix = NDArrayUtil.toNDArray(clamped);
    INDArray currShape = NDArrayUtil.toNDArray(arr.shape());

    INDArray startIndex = Transforms.floor(currShape.sub(shapeMatrix).divi(Nd4j.scalar(2)));
    INDArray endIndex = startIndex.add(shapeMatrix);
    INDArrayIndex[] indexes = Indices.createFromStartAndEnd(startIndex, endIndex);

    if (shapeMatrix.length() > 1)
        return arr.get(indexes);

    // Single-dimension shape: copy the central elements one at a time.
    IComplexNDArray ret = Nd4j.createComplex(new int[] {(int) shapeMatrix.getDouble(0)});
    int start = (int) startIndex.getDouble(0);
    int end = (int) endIndex.getDouble(0);
    int count = 0;
    for (int i = start; i < end; i++) {
        ret.putScalar(count++, arr.getComplex(i));
    }

    return ret;
}
 
Example #18
Source File: NDArrayList.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Makes sure {@code container} exists and that position {@code idx} is addressable.
 * <p>
 * If there is no container yet, one with 10 slots is created; {@code idx} is not consulted in
 * that case. Otherwise, when {@code idx} is past the end, a container of
 * max(2 * current length, idx + 1) slots is allocated and the old contents are copied to its front.
 *
 * @param idx the position that must be valid after growing
 */
private void growCapacity(int idx) {
    if(container == null) {
        container = Nd4j.create(10);
    }
    else if(idx >= container.length()) {
        // idx + 1, not idx: a container of exactly idx slots still has no position idx.
        val max = Math.max(container.length() * 2, idx + 1);
        INDArray newContainer = Nd4j.create(max);
        newContainer.put(new INDArrayIndex[]{NDArrayIndex.interval(0,container.length())},container);
        container = newContainer;
    }
}
 
Example #19
Source File: LossMixtureDensity.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes (labels - mu) for every mixture component.
 * <p>
 * {@code labels} is copied into each mixture slot k of a [nSamples, mMixtures, labelsPerSample]
 * buffer, and {@code mu} is then subtracted from that buffer in place.
 *
 * @param labels labels, with dimension 0 = samples and dimension 1 = label values
 * @param mu     values subtracted from the replicated labels
 * @return the difference buffer
 */
private INDArray labelsMinusMu(INDArray labels, INDArray mu) {
    long nSamples = labels.size(0);
    long labelsPerSample = labels.size(1);

    // A broadcast-based version (reshape -> repeat over mixtures -> permute -> subi) gave the same
    // result, but it was measured to be much slower than filling each mixture slot in a loop.
    INDArray difference = Nd4j.zeros(nSamples, mMixtures, labelsPerSample);
    for (int k = 0; k < mMixtures; k++) {
        INDArrayIndex[] mixtureSlot = {NDArrayIndex.all(), NDArrayIndex.point(k), NDArrayIndex.all()};
        difference.put(mixtureSlot, labels);
    }
    difference.subi(mu);

    return difference;
}
 
Example #20
Source File: BaseNDArrayList.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Removes the first element equal to {@code o}, if present.
 *
 * @param o the value to remove; it is cast to {@code double}, so it must be a {@link Double}
 * @return {@code false} if no matching element was found, otherwise {@code true}
 */
@Override
public boolean remove(Object o) {
    double target = (double) o;
    int idx = BooleanIndexing.firstIndex(container, new EqualsCondition(target)).getInt(0);
    if (idx < 0) {
        return false;
    }
    // Shift the elements after idx one slot to the left.
    // NOTE(review): the destination [idx, length) is one element longer than the source
    // [idx + 1, length), and size is not decremented here. Confirm that put() and the
    // reshape(1, size) below behave as intended.
    INDArray tail = container.get(NDArrayIndex.interval(idx + 1, container.length()));
    container.put(new INDArrayIndex[] {NDArrayIndex.interval(idx, container.length())}, tail);
    container = container.reshape(1, size);
    return true;
}
 
Example #21
Source File: BaseNDArrayList.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Makes sure {@code container} exists and that position {@code idx} is addressable.
 * <p>
 * If there is no container yet, one with 10 slots is created; {@code idx} is not consulted in
 * that case. Otherwise, when {@code idx} is past the end, a container of
 * max(2 * current length, idx + 1) slots is allocated and the old contents are copied to its front.
 *
 * @param idx the position that must be valid after growing
 */
private void growCapacity(int idx) {
    if(container == null) {
        container = Nd4j.create(10);
    }
    else if(idx >= container.length()) {
        // idx + 1, not idx: a container of exactly idx slots still has no position idx.
        val max = Math.max(container.length() * 2, idx + 1);
        INDArray newContainer = Nd4j.create(max);
        newContainer.put(new INDArrayIndex[]{NDArrayIndex.interval(0,container.length())},container);
        container = newContainer;
    }
}
 
Example #22
Source File: NDArrayList.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Removes the first element equal to {@code o}, if present.
 *
 * @param o the value to remove; it is cast to {@code double}, so it must be a {@link Double}
 * @return {@code false} if no matching element was found, otherwise {@code true}
 */
@Override
public boolean remove(Object o) {
    double target = (double) o;
    int idx = BooleanIndexing.firstIndex(container, new EqualsCondition(target)).getInt(0);
    if (idx < 0) {
        return false;
    }
    // Shift the elements after idx one slot to the left.
    // NOTE(review): the destination [idx, length) is one element longer than the source
    // [idx + 1, length), and size is not decremented here. Confirm that put() and the
    // reshape(1, size) below behave as intended.
    INDArray tail = container.get(NDArrayIndex.interval(idx + 1, container.length()));
    container.put(new INDArrayIndex[] {NDArrayIndex.interval(idx, container.length())}, tail);
    container = container.reshape(1, size);
    return true;
}
 
Example #23
Source File: IndexingTestsC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * For every slice s and every start offset (i, j), indexing the slice first and then the
 * row/column intervals must give the same result as indexing all three at once.
 * Failures are collected rather than thrown, so every combination is checked.
 */
@Test
public void testPointIndexing() {
    int slices = 5;
    int rows = 5;
    int cols = 5;
    int total = slices * rows * cols;
    INDArray cube = Nd4j.linspace(1, total, total).reshape(slices, rows, cols);

    for (int s = 0; s < slices; s++) {
        INDArrayIndex slicePoint = NDArrayIndex.point(s);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                log.info("Running for ( {}, {} - {} , {} - {} )", s, i, rows, j, cols);
                INDArrayIndex rowRange = NDArrayIndex.interval(i, rows);
                INDArrayIndex colRange = NDArrayIndex.interval(j, cols);
                INDArray twoStep = cube.get(slicePoint).get(rowRange, colRange);
                INDArray oneStep = cube.get(slicePoint, rowRange, colRange);
                String failureMessage = String.format("Fails for (%d , %d - %d, %d - %d)\n", s, i, rows, j, cols);
                try {
                    assertEquals(failureMessage, twoStep, oneStep);
                } catch (Throwable t) {
                    collector.addError(t);
                }
            }
        }
    }
}
 
Example #24
Source File: IndexShapeTests.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks the shape produced by {@link Indices#shape} when new-axis indexes are placed at the
 * front, in the middle, and in both places of an all-index list.
 */
@Test
public void testNewAxis() {
    // Two new axes in front of eight all() indexes.
    INDArrayIndex[] leadingAxes = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()};
    int[] expectedLeading = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1};
    assertArrayEquals(expectedLeading, Indices.shape(shape, leadingAxes));

    // New axes after some all() indexes: when an all() precedes a new axis, the new axis is
    // expected at the position where it was specified.
    INDArrayIndex[] middleAxes = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.newAxis(), NDArrayIndex.newAxis(),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()};
    int[] expectedMiddle = {1, 1, 2, 1, 1, 1, 3, 4, 5, 1};
    assertArrayEquals(expectedMiddle, Indices.shape(shape, middleAxes));

    // Both at once: new axes at the front and in the middle.
    INDArrayIndex[] leadingAndMiddleAxes = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.newAxis(), NDArrayIndex.newAxis(),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()};
    int[] expectedLeadingAndMiddle = {1, 1, 1, 1, 2, 1, 1, 1, 3, 4, 5, 1};
    assertArrayEquals(expectedLeadingAndMiddle, Indices.shape(shape, leadingAndMiddleAxes));
}
 
Example #25
Source File: IndexShapeTests.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks the shape produced by {@link Indices#shape} for index lists that mix point(0) and
 * all() in different positions. The first two lists cover every dimension; the third leaves
 * the last one out.
 */
@Test
public void testSinglePoint() {
    INDArrayIndex[] firstCase = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(0),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all()};
    int[] firstExpected = {2, 1, 4, 5, 1};
    assertArrayEquals(firstExpected, Indices.shape(shape, firstCase));

    INDArrayIndex[] secondCase = new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(0),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0)};
    int[] secondExpected = {1, 2, 1, 5, 1};
    assertArrayEquals(secondExpected, Indices.shape(shape, secondCase));

    INDArrayIndex[] thirdCase = new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(0),
                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0)};
    int[] thirdExpected = {1, 2, 1, 4, 5, 1};
    assertArrayEquals(thirdExpected, Indices.shape(shape, thirdCase));
}
 
Example #26
Source File: IndexShapeTests2d.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * A new axis placed first or second in a 2d index list is expected to insert a size-1
 * dimension at that position.
 */
@Test
public void testNewAxis2d() {
    INDArrayIndex[] axisFirst = {NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all()};
    assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape, axisFirst));

    INDArrayIndex[] axisSecond = {NDArrayIndex.all(), NDArrayIndex.newAxis(), NDArrayIndex.all()};
    assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape, axisSecond));
}
 
Example #27
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * put() on a 1x4 array with a row point of 1 (only row 0 exists) must throw
 * IllegalArgumentException.
 */
@Test
public void testVectorIndexPointPointOutOfRange2() {
    INDArray zeros = Nd4j.zeros(1, 4);
    INDArrayIndex[] outOfRange = {NDArrayIndex.point(1), NDArrayIndex.point(2)};
    INDArray value = Nd4j.ones(1, 1);

    boolean rejected = false;
    try {
        zeros.put(outOfRange, value);
    } catch (IllegalArgumentException expected) {
        rejected = true;
    }
    if (!rejected)
        fail("Out of range index should throw an IllegalArgumentException");
}
 
Example #28
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * put() on a 1x4 array with a column point of 4 (valid columns are 0..3) must throw
 * IllegalArgumentException.
 */
@Test
public void testVectorIndexPointPointOutOfRange() {
    INDArray zeros = Nd4j.zeros(1, 4);
    INDArrayIndex[] outOfRange = {NDArrayIndex.point(0), NDArrayIndex.point(4)};
    INDArray value = Nd4j.ones(1, 1);

    boolean rejected = false;
    try {
        zeros.put(outOfRange, value);
    } catch (IllegalArgumentException expected) {
        rejected = true;
    }
    if (!rejected)
        fail("Out of range index should throw an IllegalArgumentException");
}
 
Example #29
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Putting a single 1 at (0, 2) of a 1x4 zero array sets exactly that element.
 */
@Test
public void testVectorIndexPointPoint() {
    INDArray zeros = Nd4j.zeros(1, 4);
    zeros.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(2)}, Nd4j.ones(1, 1));

    INDArray expected = Nd4j.create(new double[] {0.0, 0.0, 1.0, 0.0});
    assertEquals(expected, zeros);
}
 
Example #30
Source File: ShapeResolutionTestsC.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Putting two ones into row 0, columns 1..2 (inclusive interval) of a 1x4 zero array sets
 * exactly those two elements.
 */
@Test
public void testFlatIndexPointInterval() {
    INDArray zeros = Nd4j.zeros(1, 4);
    zeros.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, true)}, Nd4j.ones(1, 2));

    INDArray expected = Nd4j.create(new double[] {0.0, 1.0, 1.0, 0.0});
    assertEquals(expected, zeros);
}