org.nd4j.linalg.indexing.INDArrayIndex Java Examples
The following examples show how to use
org.nd4j.linalg.indexing.INDArrayIndex.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Writes ones into the region (point 1, rows 1..2 inclusive, point 1) of a 3x3x3 zero array and
 * checks the printed result. Two expected renderings are accepted: one with '.' and one with ','
 * as the decimal separator (presumably depending on the default locale).
 */
@Test
@Ignore
public void testIndexPointInterval() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] region = {
            NDArrayIndex.point(1),
            NDArrayIndex.interval(1, 2, true),
            NDArrayIndex.point(1)
    };
    target.put(region, Nd4j.ones(1, 2));

    String commaDecimal = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotDecimal = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String rendered = target.toString();
    boolean matchesAnyFormat = rendered.equals(dotDecimal) || rendered.equals(commaDecimal);
    if (!matchesAnyFormat) {
        assertEquals(dotDecimal, rendered);
    }
}
Example #2
Source File: Nd4jVertex.java From jstarcraft-ai with Apache License 2.0 | 6 votes |
/**
 * Copies every component matrix of the input key matrix into consecutive slices of the output key
 * array, then zeroes the output value matrix.
 *
 * <p>When {@code orientation} is true, components are laid side by side along the column axis.
 * Otherwise they are stacked along the row axis.
 */
@Override
public void doForward() {
    GlobalMatrix sourceKeys = GlobalMatrix.class.cast(inputKeyValues[0].getKey());
    // Only the cast is needed here (the value itself is unused); kept so a wrong type still fails fast.
    GlobalMatrix sourceValues = GlobalMatrix.class.cast(inputKeyValues[0].getValue());
    Nd4jMatrix targetKeys = Nd4jMatrix.class.cast(outputKeyValue.getKey());
    Nd4jMatrix targetValues = Nd4jMatrix.class.cast(outputKeyValue.getValue());

    INDArray destination = targetKeys.getArray();
    int offset = 0;
    for (MathMatrix part : sourceKeys.getComponentMatrixes()) {
        INDArray block = Nd4jMatrix.class.cast(part).getArray();
        int extent = orientation ? block.columns() : block.rows();
        INDArrayIndex[] window;
        if (orientation) {
            window = new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(offset, offset + extent)};
        } else {
            window = new INDArrayIndex[] {NDArrayIndex.interval(offset, offset + extent), NDArrayIndex.all()};
        }
        destination.put(window, block);
        offset += extent;
    }
    targetValues.setValues(0F);
}
Example #3
Source File: BaseNDArrayFactory.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Concatenates the elements of every ndarray in {@code matrices} into a single 1 x N array,
 * where N is the sum of the individual lengths.
 *
 * <p>Each input is written into the next consecutive range of the result, in iteration order.
 *
 * @param matrices the ndarrays to flatten
 * @return a new 1 x N array holding all inputs back to back
 */
@Override
public INDArray toFlattened(Collection<INDArray> matrices) {
    // NOTE(review): lengths are summed into an int; confirm inputs never exceed Integer.MAX_VALUE in total.
    int total = 0;
    for (INDArray part : matrices)
        total += part.length();

    INDArray flat = Nd4j.create(1, total);
    int offset = 0;
    for (INDArray part : matrices) {
        INDArrayIndex[] range = {NDArrayIndex.interval(offset, offset + part.length())};
        flat.put(range, part);
        offset += part.length();
    }
    return flat;
}
Example #4
Source File: DataSetUtil.java From nd4j with Apache License 2.0 | 6 votes |
public static INDArray mergePerOutputMasks2d(long[] outShape, INDArray[] arrays, INDArray[] masks) { val numExamplesPerArr = new long[arrays.length]; for (int i = 0; i < numExamplesPerArr.length; i++) { numExamplesPerArr[i] = arrays[i].size(0); } INDArray outMask = Nd4j.ones(outShape); //Initialize to 'all present' (1s) int rowsSoFar = 0; for (int i = 0; i < masks.length; i++) { long thisRows = numExamplesPerArr[i]; //Mask itself may be null -> all present, but may include multiple examples if (masks[i] == null) { continue; } outMask.put(new INDArrayIndex[] {NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows), NDArrayIndex.all()}, masks[i]); rowsSoFar += thisRows; } return outMask; }
Example #5
Source File: MtcnnUtil.java From mtcnn-java with Apache License 2.0 | 6 votes |
/** * Convert the bbox into square. * * original code: * - https://github.com/kpzhang93/MTCNN_face_detection_alignment/blob/master/code/codes/MTCNNv2/rerec.m * - https://github.com/davidsandberg/facenet/blob/master/src/align/detect_face.py#L646 * * @param bbox * @param withFloor * @return Returns array representing the squared bbox */ public static INDArray rerec(INDArray bbox, boolean withFloor) { // convert bbox to square INDArray h = bbox.get(all(), point(3)).sub(bbox.get(all(), point(1))); INDArray w = bbox.get(all(), point(2)).sub(bbox.get(all(), point(0))); INDArray l = Transforms.max(w, h); bbox.put(new INDArrayIndex[] { all(), point(0) }, bbox.get(all(), point(0)).add(w.mul(0.5)).sub(l.mul(0.5))); bbox.put(new INDArrayIndex[] { all(), point(1) }, bbox.get(all(), point(1)).add(h.mul(0.5)).sub(l.mul(0.5))); INDArray lTile = Nd4j.repeat(l, 2).transpose(); bbox.put(new INDArrayIndex[] { all(), interval(2, 4) }, bbox.get(all(), interval(0, 2)).add(lTile)); if (withFloor) { bbox.put(new INDArrayIndex[] { all(), interval(0, 4) }, Transforms.floor(bbox.get(all(), interval(0, 4)))); } return bbox; }
Example #6
Source File: NDArrayIndexResolveTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Resolving a single point index against a 4-element vector either leaves it unchanged or
 * expands it to the two points (0, 1).
 */
@Test
public void testResolvePointVector() {
    INDArray vector = Nd4j.linspace(1, 4, 4);
    INDArrayIndex[] requested = {NDArrayIndex.point(1)};
    INDArrayIndex[] resolved = NDArrayIndex.resolve(vector.shape(), requested);

    if (resolved.length != requested.length) {
        // Expanded form: a leading point(0) for the row dimension, then the requested point(1).
        assertEquals(2, resolved.length);
        assertTrue(resolved[0] instanceof PointIndex);
        assertEquals(0, resolved[0].current());
        assertTrue(resolved[1] instanceof PointIndex);
        assertEquals(1, resolved[1].current());
        return;
    }
    assertArrayEquals(requested, resolved);
}
Example #7
Source File: CnnSentenceDataSetIterator.java From wekaDeeplearning4j with GNU General Public License v3.0 | 6 votes |
/** * Create the features based on the data of this batch. * * @param data Batch data * @param maxTokenSizeBatch Maximum token size in this batch * @return INDArray containing the features */ protected INDArray createFeatures(List<Datum> data, int maxTokenSizeBatch) { int tokenLimit = getMaxSentenceLength(); // Determine feature shape int[] featuresShape = getFeatureShape(maxTokenSizeBatch, data.size()); // Create features from tokens INDArray features = Nd4j.create(featuresShape); for (int i = 0; i < data.size(); i++) { List<String> currSentence = data.get(i).getTokens(); for (int j = 0; j < currSentence.size() && j < tokenLimit; j++) { String token = currSentence.get(j); INDArray vectorizedToken = getVector(token); INDArrayIndex[] indices = new INDArrayIndex[4]; indices[0] = point(i); indices[1] = point(0); indices[2] = point(j); indices[3] = all(); features.put(indices, vectorizedToken); } } return features; }
Example #8
Source File: CnnSentenceDataSetIterator.java From wekaDeeplearning4j with GNU General Public License v3.0 | 6 votes |
/** * Create the features based on the data of this batch. * * @param data Batch data * @param maxTokenSizeBatch Maximum token size in this batch * @return INDArray containing the features */ protected INDArray createFeatures(List<Datum> data, int maxTokenSizeBatch) { int tokenLimit = getMaxSentenceLength(); // Determine feature shape int[] featuresShape = getFeatureShape(maxTokenSizeBatch, data.size()); // Create features from tokens INDArray features = Nd4j.create(featuresShape); for (int i = 0; i < data.size(); i++) { List<String> currSentence = data.get(i).getTokens(); for (int j = 0; j < currSentence.size() && j < tokenLimit; j++) { String token = currSentence.get(j); INDArray vectorizedToken = getVector(token); INDArrayIndex[] indices = new INDArrayIndex[4]; indices[0] = point(i); indices[1] = point(0); indices[2] = point(j); indices[3] = all(); features.put(indices, vectorizedToken); } } return features; }
Example #9
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Writes ones into the column (point 1, all rows, point 1) of a 3x3x3 zero array and checks the
 * printed result. Both ','- and '.'-decimal renderings are accepted.
 */
@Test
@Ignore
public void testIndexPointAll() {
    INDArray cube = Nd4j.zeros(3, 3, 3);
    INDArrayIndex first = NDArrayIndex.point(1);
    INDArrayIndex second = NDArrayIndex.all();
    INDArrayIndex third = NDArrayIndex.point(1);
    INDArray ones = Nd4j.ones(1, 3);
    cube.put(new INDArrayIndex[] {first, second, third}, ones);

    String expectedComma = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String expectedDot = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String actual = cube.toString();
    if (!(actual.equals(expectedComma) || actual.equals(expectedDot)))
        assertEquals(expectedDot, actual);
}
Example #10
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Writes ones into slices 0..1 (inclusive), all rows, columns 1..2 (inclusive) of a 3x3x3 zero
 * array and checks the printed result. Both ','- and '.'-decimal renderings are accepted.
 */
@Test
@Ignore
public void testIndexIntervalAll() {
    INDArray cube = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] selection = {
            NDArrayIndex.interval(0, 1, true),
            NDArrayIndex.all(),
            NDArrayIndex.interval(1, 2, true)
    };
    INDArray filler = Nd4j.ones(2, 6);
    cube.put(selection, filler);

    String commaForm = "[[[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotForm = "[[[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String printed = cube.toString();
    boolean ok = printed.equals(commaForm) || printed.equals(dotForm);
    if (!ok)
        assertEquals(dotForm, printed);
}
Example #11
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Writes ones into slice 1, all rows, columns 1..2 (inclusive) of a 3x3x3 zero array and checks
 * the printed result. Both ','- and '.'-decimal renderings are accepted.
 */
@Test
@Ignore
public void testIndexPointIntervalAll() {
    INDArray grid = Nd4j.zeros(3, 3, 3);
    INDArrayIndex sliceIdx = NDArrayIndex.point(1);
    INDArrayIndex rowIdx = NDArrayIndex.all();
    INDArrayIndex colIdx = NDArrayIndex.interval(1, 2, true);
    INDArray patch = Nd4j.ones(3, 2);
    grid.put(new INDArrayIndex[] {sliceIdx, rowIdx, colIdx}, patch);

    String withComma = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String withDot = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String text = grid.toString();
    if (text.equals(withComma) || text.equals(withDot)) {
        return;
    }
    assertEquals(withDot, text);
}
Example #12
Source File: Shape.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Wraps each int in {@code indices} in a new {@link NDArrayIndex}, preserving order.
 *
 * @param indices the indices to convert
 * @return an array of the same length with one {@code NDArrayIndex} per input value
 */
public static INDArrayIndex[] toIndexes(int[] indices) {
    INDArrayIndex[] converted = new INDArrayIndex[indices.length];
    int pos = 0;
    for (int value : indices) {
        converted[pos++] = new NDArrayIndex(value);
    }
    return converted;
}
Example #13
Source File: DeepFMParameter.java From jstarcraft-rns with Apache License 2.0 | 5 votes |
/**
 * Splits row 0 of the flattened parameter {@code view} into a weight segment followed by a bias
 * segment and registers both with the configuration.
 *
 * <p>Layout along row 0: {@code [0, numberOfFeatures * nOut)} holds the weights and the next
 * {@code nOut} entries hold the biases.
 *
 * @return an insertion-ordered, synchronized map of {@code WEIGHT_KEY} and {@code BIAS_KEY}
 */
@Override
public Map<String, INDArray> init(NeuralNetConfiguration configuration, INDArray view, boolean initialize) {
    Map<String, INDArray> parameters = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
    FeedForwardLayer layer = (FeedForwardLayer) configuration.getLayer();
    long outputs = layer.getNOut();
    long weightCount = numberOfFeatures * outputs;

    INDArray weightView = view.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, weightCount));
    INDArray biasView = view.get(NDArrayIndex.point(0), NDArrayIndex.interval(weightCount, weightCount + outputs));

    parameters.put(WEIGHT_KEY, this.createWeightMatrix(configuration, weightView, initialize));
    parameters.put(BIAS_KEY, createBias(configuration, biasView, initialize));
    configuration.addVariable(WEIGHT_KEY);
    configuration.addVariable(BIAS_KEY);
    return parameters;
}
Example #14
Source File: ComplexNDArrayUtil.java From nd4j with Apache License 2.0 | 5 votes |
/** * Pads an ndarray with zeros * * @param nd the ndarray to pad * @param targetShape the the new shape * @return the padded ndarray */ public static IComplexNDArray padWithZeros(IComplexNDArray nd, long[] targetShape) { if (Arrays.equals(nd.shape(), targetShape)) return nd; //no padding required if (ArrayUtil.prod(nd.shape()) >= ArrayUtil.prod(targetShape)) return nd; IComplexNDArray ret = Nd4j.createComplex(targetShape); INDArrayIndex[] targetShapeIndex = NDArrayIndex.createCoveringShape(nd.shape()); ret.put(targetShapeIndex, nd); return ret; }
Example #15
Source File: BaseSparseNDArray.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Converts {@code indices} into one {@link SpecifiedIndex} per slice and delegates to
 * {@code put(INDArrayIndex[], INDArray)}.
 *
 * <p>NOTE(review): the number of indexes built is {@code indices.rank()}, while each index comes
 * from {@code indices.slice(i)}. These agree only when the number of slices equals the rank.
 * Confirm this is intended.
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    int count = indices.rank();
    INDArrayIndex[] specified = new INDArrayIndex[count];
    for (int dim = 0; dim < count; dim++) {
        int[] positions = indices.slice(dim).dup().data().asInt();
        specified[dim] = new SpecifiedIndex(positions);
    }
    return put(specified, element);
}
Example #16
Source File: NDArrayList.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Shifts the {@code size - index - 1} elements starting at {@code index} (row 0 of the container)
 * one slot to the right. The moved range is copied first, so the overlapping write is safe.
 */
private void moveForward(int index) {
    int numMoved = size - index - 1;
    INDArray moved = container.get(NDArrayIndex.point(0), NDArrayIndex.interval(index, index + numMoved)).dup();
    INDArrayIndex[] destination = {
            NDArrayIndex.point(0),
            NDArrayIndex.interval(index + 1, index + 1 + moved.length())
    };
    container.put(destination, moved);
}
Example #17
Source File: ComplexNDArrayUtil.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Center an array.
 *
 * <p>Extracts the region of {@code arr} with the requested shape whose start offset in each
 * dimension is {@code floor((arr.shape - shape) / 2)}.
 *
 * @param arr   the arr to center
 * @param shape the shape of the region to extract; entries below 1 are treated as 1. The
 *              caller's array is not modified.
 * @return {@code arr} itself when it has fewer elements than {@code prod(shape)}; otherwise the
 *         center portion of the array based on the specified shape
 */
public static IComplexNDArray center(IComplexNDArray arr, long[] shape) {
    if (arr.length() < ArrayUtil.prod(shape))
        return arr;
    // Clamp a copy: the original clamped the caller's array in place, a surprising side effect.
    long[] clampedShape = shape.clone();
    for (int i = 0; i < clampedShape.length; i++)
        if (clampedShape[i] < 1)
            clampedShape[i] = 1;

    INDArray shapeMatrix = NDArrayUtil.toNDArray(clampedShape);
    INDArray currShape = NDArrayUtil.toNDArray(arr.shape());
    INDArray startIndex = Transforms.floor(currShape.sub(shapeMatrix).divi(Nd4j.scalar(2)));
    INDArray endIndex = startIndex.add(shapeMatrix);
    INDArrayIndex[] indexes = Indices.createFromStartAndEnd(startIndex, endIndex);
    if (shapeMatrix.length() > 1)
        return arr.get(indexes);

    // One-dimensional case: copy the centered range element by element into a new array.
    IComplexNDArray ret = Nd4j.createComplex(new int[] {(int) shapeMatrix.getDouble(0)});
    int start = (int) startIndex.getDouble(0);
    int end = (int) endIndex.getDouble(0);
    int count = 0;
    for (int i = start; i < end; i++) {
        ret.putScalar(count++, arr.getComplex(i));
    }
    return ret;
}
Example #18
Source File: NDArrayList.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Ensures the backing container can address position {@code idx}.
 *
 * <p>The container is allocated lazily. When it is too small, it is replaced by one of at least
 * double the length, and the old contents are copied into the leading range.
 *
 * @param idx the position that must be addressable after this call
 */
private void growCapacity(int idx) {
    if(container == null) {
        // Start at 10, but never smaller than what is needed for idx.
        container = Nd4j.create(Math.max(10, idx + 1));
    }
    else if(idx >= container.length()) {
        // Position idx needs capacity idx + 1; plain doubling (or using idx itself) can fall short.
        val max = Math.max(container.length() * 2, idx + 1);
        INDArray newContainer = Nd4j.create(max);
        newContainer.put(new INDArrayIndex[]{NDArrayIndex.interval(0,container.length())},container);
        container = newContainer;
    }
}
Example #19
Source File: LossMixtureDensity.java From nd4j with Apache License 2.0 | 5 votes |
private INDArray labelsMinusMu(INDArray labels, INDArray mu) { // Now that we have the mixtures, let's compute the negative // log likelihodd of the label against the long nSamples = labels.size(0); long labelsPerSample = labels.size(1); // This worked, but was actually much // slower than the for loop below. // labels = samples, mixtures, labels // mu = samples, mixtures // INDArray labelMinusMu = labels // .reshape('f', nSamples, labelsPerSample, 1) // .repeat(2, mMixtures) // .permute(0, 2, 1) // .subi(mu); // The above code does the same thing as the loop below, // but it does it with index magix instead of a for loop. // It turned out to be way less efficient than the simple 'for' here. INDArray labelMinusMu = Nd4j.zeros(nSamples, mMixtures, labelsPerSample); for (int k = 0; k < mMixtures; k++) { labelMinusMu.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(k), NDArrayIndex.all()}, labels); } labelMinusMu.subi(mu); return labelMinusMu; }
Example #20
Source File: BaseNDArrayList.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Attempts to remove the first element equal to {@code o} from the list.
 *
 * <p>{@code (double) o} unboxes through {@code Double}. A non-{@code Double} argument therefore
 * throws {@code ClassCastException}, and {@code null} throws {@code NullPointerException}.
 *
 * <p>NOTE(review): this method looks broken. It is not changed here because the invariants
 * linking {@code size} to {@code container} are not visible from this block:
 * <ul>
 *   <li>assuming {@code interval(a, b)} excludes {@code b}, the put target
 *       {@code interval(idx, container.length())} is one element longer than the source
 *       {@code interval(idx + 1, container.length())};</li>
 *   <li>source and target overlap, and the source is not {@code dup()}'d (matters if
 *       {@code get} returns a view);</li>
 *   <li>{@code size} is never decremented, yet {@code reshape(1,size)} presumably requires the
 *       container to hold exactly {@code size} elements;</li>
 *   <li>the search spans the whole container, including slots at or past {@code size};</li>
 *   <li>the {@code idx < 0} check assumes {@code firstIndex} yields a negative value when nothing
 *       matches. TODO confirm.</li>
 * </ul>
 *
 * @param o the value to remove
 * @return {@code false} if no matching index was found, otherwise {@code true}
 */
@Override
public boolean remove(Object o) {
    int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0);
    if(idx < 0)
        return false;
    // Intended to shift the elements after idx one slot left (see NOTE(review) above).
    container.put(new INDArrayIndex[]{NDArrayIndex.interval(idx,container.length())},container.get(NDArrayIndex.interval(idx + 1,container.length())));
    container = container.reshape(1,size);
    return true;
}
Example #21
Source File: BaseNDArrayList.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Ensures the backing container can address position {@code idx}.
 *
 * <p>The container is allocated lazily. When it is too small, it is replaced by one of at least
 * double the length, and the old contents are copied into the leading range.
 *
 * @param idx the position that must be addressable after this call
 */
private void growCapacity(int idx) {
    if(container == null) {
        // Start at 10, but never smaller than what is needed for idx.
        container = Nd4j.create(Math.max(10, idx + 1));
    }
    else if(idx >= container.length()) {
        // Position idx needs capacity idx + 1; plain doubling (or using idx itself) can fall short.
        val max = Math.max(container.length() * 2, idx + 1);
        INDArray newContainer = Nd4j.create(max);
        newContainer.put(new INDArrayIndex[]{NDArrayIndex.interval(0,container.length())},container);
        container = newContainer;
    }
}
Example #22
Source File: NDArrayList.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Attempts to remove the first element equal to {@code o} from the list.
 *
 * <p>{@code (double) o} unboxes through {@code Double}. A non-{@code Double} argument therefore
 * throws {@code ClassCastException}, and {@code null} throws {@code NullPointerException}.
 *
 * <p>NOTE(review): this method looks broken. It is not changed here because the invariants
 * linking {@code size} to {@code container} are not visible from this block:
 * <ul>
 *   <li>assuming {@code interval(a, b)} excludes {@code b}, the put target
 *       {@code interval(idx, container.length())} is one element longer than the source
 *       {@code interval(idx + 1, container.length())};</li>
 *   <li>source and target overlap, and the source is not {@code dup()}'d (matters if
 *       {@code get} returns a view);</li>
 *   <li>{@code size} is never decremented, yet {@code reshape(1,size)} presumably requires the
 *       container to hold exactly {@code size} elements;</li>
 *   <li>the search spans the whole container, including slots at or past {@code size};</li>
 *   <li>the {@code idx < 0} check assumes {@code firstIndex} yields a negative value when nothing
 *       matches. TODO confirm.</li>
 * </ul>
 *
 * @param o the value to remove
 * @return {@code false} if no matching index was found, otherwise {@code true}
 */
@Override
public boolean remove(Object o) {
    int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0);
    if(idx < 0)
        return false;
    // Intended to shift the elements after idx one slot left (see NOTE(review) above).
    container.put(new INDArrayIndex[]{NDArrayIndex.interval(idx,container.length())},container.get(NDArrayIndex.interval(idx + 1,container.length())));
    container = container.reshape(1,size);
    return true;
}
Example #23
Source File: IndexingTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * For every slice s and every start pair (i, j) of a 5x5x5 linspace cube, chained indexing
 * {@code get(point(s)).get(interval(i, rows), interval(j, cols))} must equal the single call
 * {@code get(point(s), interval(i, rows), interval(j, cols))}. Failures are collected rather than
 * thrown, so every combination is reported.
 */
@Test
public void testPointIndexing() {
    int slices = 5;
    int rows = 5;
    int cols = 5;
    int total = slices * rows * cols;
    INDArray cube = Nd4j.linspace(1, total, total).reshape(slices, rows, cols);

    for (int s = 0; s < slices; s++) {
        INDArrayIndex sliceIndex = NDArrayIndex.point(s);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                log.info("Running for ( {}, {} - {} , {} - {} )", s, i, rows, j, cols);
                INDArrayIndex rowRange = NDArrayIndex.interval(i, rows);
                INDArrayIndex colRange = NDArrayIndex.interval(j, cols);
                INDArray chained = cube.get(sliceIndex).get(rowRange, colRange);
                INDArray direct = cube.get(sliceIndex, rowRange, colRange);
                String message = String.format("Fails for (%d , %d - %d, %d - %d)\n", s, i, rows, j, cols);
                try {
                    assertEquals(message, chained, direct);
                } catch (Throwable failure) {
                    collector.addError(failure);
                }
            }
        }
    }
}
Example #24
Source File: IndexShapeTests.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testNewAxis() { //normal prepend int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1}; INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), }; assertArrayEquals(prependAssertion, Indices.shape(shape, prependTest)); //test setting for particular indexes. //when an all is encountered before a new axis, //it is assumed that new axis must occur at the destination //where the new axis was specified int[] addToMiddle = {1, 1, 2, 1, 1, 1, 3, 4, 5, 1}; INDArrayIndex[] setInMiddleTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),}; assertArrayEquals(addToMiddle, Indices.shape(shape, setInMiddleTest)); //test prepending AND adding to middle int[] prependAndAddToMiddleAssertion = {1, 1, 1, 1, 2, 1, 1, 1, 3, 4, 5, 1}; INDArrayIndex[] prependAndMiddle = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),}; assertArrayEquals(prependAndAddToMiddleAssertion, Indices.shape(shape, prependAndMiddle)); }
Example #25
Source File: IndexShapeTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Checks the shapes produced by {@code Indices.shape} for mixes of {@code point(0)} and
 * {@code all()} indexes. The first case assumes every index is filled out.
 */
@Test
public void testSinglePoint() {
    int[] firstExpected = {2, 1, 4, 5, 1};
    INDArrayIndex[] firstCase = {
            NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all()
    };
    int[] firstActual = Indices.shape(shape, firstCase);
    assertArrayEquals(firstExpected, firstActual);

    int[] secondExpected = {1, 2, 1, 5, 1};
    INDArrayIndex[] secondCase = {
            NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0)
    };
    assertArrayEquals(secondExpected, Indices.shape(shape, secondCase));

    int[] thirdExpected = {1, 2, 1, 4, 5, 1};
    INDArrayIndex[] thirdCase = {
            NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.point(0),
    };
    assertArrayEquals(thirdExpected, Indices.shape(shape, thirdCase));
}
Example #26
Source File: IndexShapeTests2d.java From nd4j with Apache License 2.0 | 5 votes |
/** A single newAxis() inserts a size-1 dimension at its position in a 2d shape. */
@Test
public void testNewAxis2d() {
    INDArrayIndex[] leading = {NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all()};
    assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape, leading));

    INDArrayIndex[] between = {NDArrayIndex.all(), NDArrayIndex.newAxis(), NDArrayIndex.all()};
    assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape, between));
}
Example #27
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testVectorIndexPointPointOutOfRange2() { INDArray zeros = Nd4j.zeros(1, 4); INDArrayIndex x = NDArrayIndex.point(1); INDArrayIndex y = NDArrayIndex.point(2); INDArray value = Nd4j.ones(1, 1); try { zeros.put(new INDArrayIndex[] {x, y}, value); fail("Out of range index should throw an IllegalArgumentException"); } catch (IllegalArgumentException e) { //do nothing } }
Example #28
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testVectorIndexPointPointOutOfRange() { INDArray zeros = Nd4j.zeros(1, 4); INDArrayIndex x = NDArrayIndex.point(0); INDArrayIndex y = NDArrayIndex.point(4); INDArray value = Nd4j.ones(1, 1); try { zeros.put(new INDArrayIndex[] {x, y}, value); fail("Out of range index should throw an IllegalArgumentException"); } catch (IllegalArgumentException e) { //do nothing } }
Example #29
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/** Putting a 1x1 one at (0, 2) of a 1x4 zero array sets exactly that element. */
@Test
public void testVectorIndexPointPoint() {
    INDArray row = Nd4j.zeros(1, 4);
    INDArrayIndex[] cell = {NDArrayIndex.point(0), NDArrayIndex.point(2)};
    INDArray one = Nd4j.ones(1, 1);
    row.put(cell, one);

    INDArray expected = Nd4j.create(new double[] {0.0, 0.0, 1.0, 0.0});
    assertEquals(expected, row);
}
Example #30
Source File: ShapeResolutionTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/** Putting 1x2 ones into columns 1..2 (inclusive) of row 0 of a 1x4 zero array. */
@Test
public void testFlatIndexPointInterval() {
    INDArray row = Nd4j.zeros(1, 4);
    INDArrayIndex[] span = {NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, true)};
    INDArray ones = Nd4j.ones(1, 2);
    row.put(span, ones);

    INDArray expected = Nd4j.create(new double[] {0.0, 1.0, 1.0, 0.0});
    assertEquals(expected, row);
}