org.nd4j.linalg.exception.ND4JIllegalStateException Java Examples
The following examples show how to use
org.nd4j.linalg.exception.ND4JIllegalStateException.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: DataSetUtil.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Merge the specified labels and label mask arrays (i.e., concatenate the examples).
 *
 * <p>Dispatches on the rank of the first labels array: rank 2 is merged with {@code merge2d},
 * rank 3 with {@code mergeTimeSeries} and rank 4 with {@code merge4d}.
 *
 * @param labelsToMerge     Labels to merge; must be non-null and non-empty
 * @param labelMasksToMerge Mask arrays to merge. May be null
 * @return Merged labels and mask. Mask may be null
 * @throws ND4JIllegalStateException if no labels are given, or the labels rank is not in [2, 4]
 */
public static Pair<INDArray, INDArray> mergeLabels(INDArray[] labelsToMerge, INDArray[] labelMasksToMerge) {
    // Fail with a descriptive error instead of an NPE / ArrayIndexOutOfBoundsException below.
    if (labelsToMerge == null || labelsToMerge.length == 0) {
        throw new ND4JIllegalStateException("Cannot merge examples: no labels arrays were provided");
    }
    int rankLabels = labelsToMerge[0].rank();
    switch (rankLabels) {
        case 2:
            return DataSetUtil.merge2d(labelsToMerge, labelMasksToMerge);
        case 3:
            return DataSetUtil.mergeTimeSeries(labelsToMerge, labelMasksToMerge);
        case 4:
            return DataSetUtil.merge4d(labelsToMerge, labelMasksToMerge);
        default:
            throw new ND4JIllegalStateException("Cannot merge examples: labels rank must be in range 2 to 4"
                            + " inclusive. First example labels shape: "
                            + Arrays.toString(labelsToMerge[0].shape()));
    }
}
Example #2
Source File: OpExecutionerUtil.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Throws if {@code z} contains NaN values, but only when the executioner's profiling mode is
 * {@code NAN_PANIC} or {@code ANY_PANIC}; in every other mode this is a no-op.
 *
 * @param z array to inspect
 * @throws ND4JIllegalStateException if at least one NaN value is found
 */
public static void checkForNaN(INDArray z) {
    boolean panicOnNaN = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.NAN_PANIC
                    || Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ANY_PANIC;
    if (!panicOnNaN) {
        return;
    }

    int nanCount;
    if (z.isScalar()) {
        // Scalars are checked directly, reading the value with the buffer's precision.
        boolean isNan = z.data().dataType() == DataBuffer.Type.DOUBLE
                        ? Double.isNaN(z.getDouble(0))
                        : Float.isNaN(z.getFloat(0));
        nanCount = isNan ? 1 : 0;
    } else {
        MatchCondition nanMatcher = new MatchCondition(z, Conditions.isNan());
        nanCount = Nd4j.getExecutioner().exec(nanMatcher, Integer.MAX_VALUE).getInt(0);
    }

    if (nanCount > 0) {
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + nanCount + " NaN value(s): ");
    }
}
Example #3
Source File: BaseOp.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Sets the y (second input) array of this op and resets {@code numProcessed}.
 *
 * <p>When {@code y} is null, the array is taken from the second function argument if that argument
 * is an {@link SDVariable} with an array; otherwise y is left unchanged.
 *
 * @param y the y array, or null to infer it from the function arguments
 * @throws ND4JIllegalStateException if {@code y} is null and there is no second function argument
 */
@Override
public void setY(INDArray y) {
    if (y != null) {
        this.y = y;
    } else {
        val fnArgs = args();
        boolean hasSecondArg = fnArgs != null && fnArgs.length > 1;
        if (!hasSecondArg) {
            throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments");
        }
        DifferentialFunction secondArg = fnArgs[1];
        if (secondArg instanceof SDVariable) {
            SDVariable yVariable = (SDVariable) secondArg;
            if (yVariable.getArr() != null) {
                this.y = yVariable.getArr();
            }
        }
    }
    numProcessed = 0;
}
Example #4
Source File: SDVariable.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Allocate and return a new array based on this variable's shape and weight initialization scheme.
 *
 * <p>If an array is already stored for this variable and its shape equals the shape SameDiff has
 * recorded for the variable, that array is returned unchanged. Otherwise a new array is created
 * with {@code getWeightInitScheme()} and stored under this variable's name.
 *
 * @return the existing or newly allocated array
 * @throws ND4JIllegalStateException if the variable name is null or no shape is known for it
 */
public INDArray storeAndAllocateNewArray() {
    // Validate the name before using it for any lookup.
    if (varName == null)
        throw new ND4JIllegalStateException("Unable to store array for null variable name!");

    val shape = sameDiff.getShapeForVarName(getVarName());
    val existing = getArr();
    if (existing != null && Arrays.equals(existing.shape(), shape))
        return existing;

    if (shape == null) {
        throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
    }

    val arr = getWeightInitScheme().create(shape);
    sameDiff.putArrayForVarName(getVarName(), arr);
    return arr;
}
Example #5
Source File: Shape.java From nd4j with Apache License 2.0 | 6 votes |
public static int[] normalizeAxis(int rank, int... axis) { if (axis == null || axis.length == 0) return new int[] {Integer.MAX_VALUE}; // first we should get rid of all negative axis int[] tmp = new int[axis.length]; int cnt = 0; for (val v: axis) { val t = v < 0 ? v + rank : v; if ((t >= rank && t != Integer.MAX_VALUE)|| t < 0) throw new ND4JIllegalStateException("Axis array " + Arrays.toString(axis) + " contains values above rank " + rank); tmp[cnt++] = t; } // now we're sorting array Arrays.sort(tmp); // and getting rid of possible duplicates return uniquify(tmp); }
Example #6
Source File: Mmul.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Computes the output shape of this matrix multiplication from the (optionally transposed)
 * shapes of its left and right arguments.
 *
 * <p>An unset transpose configuration is replaced by "no transposes".
 *
 * @return a single-element list with the output shape, an empty list if either input shape is a
 *         placeholder shape, or an empty list if either input shape is null
 * @throws ND4JIllegalStateException if a computed dimension is smaller than 1
 */
@Override
public List<long[]> calculateOutputShape() {
    if (mMulTranspose == null)
        mMulTranspose = MMulTranspose.allFalse();

    long[] leftShape = larg().getShape();
    if (mMulTranspose.isTransposeA())
        leftShape = ArrayUtil.reverseCopy(leftShape);

    long[] rightShape = rarg().getShape();
    if (mMulTranspose.isTransposeB())
        rightShape = ArrayUtil.reverseCopy(rightShape);

    if (Shape.isPlaceholderShape(leftShape) || Shape.isPlaceholderShape(rightShape))
        return Collections.emptyList();

    List<long[]> result = new ArrayList<>(1);
    if (leftShape != null && rightShape != null) {
        long[] outShape = Shape.getMatrixMultiplyShape(leftShape, rightShape);
        for (int idx = 0; idx < outShape.length; idx++) {
            if (outShape[idx] < 1)
                throw new ND4JIllegalStateException("Invalid shape computed at index " + idx);
        }
        result.add(outShape);
    }
    return result;
}
Example #7
Source File: CustomOpsTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Executing a non-inplace "add" op that was built with inputs but no output arrays is expected to
 * throw {@link ND4JIllegalStateException}.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testNoneInplaceOp3() throws Exception {
    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    arrayX.assign(4.0);
    arrayY.assign(2.0);
    val expected = Nd4j.create(10, 10).assign(6.0);

    CustomOp addOp = DynamicCustomOp.builder("add")
                    .addInputs(arrayX, arrayY)
                    .callInplace(false)
                    .build();
    Nd4j.getExecutioner().exec(addOp);

    // Only reached if exec() unexpectedly does not throw.
    assertEquals(expected, arrayX);
}
Example #8
Source File: NetworkOrganizer.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Builds a string from the {@code ownChar} of this tree's hottest node, followed by the
 * {@code ownChar} of each of the next 7 nodes reached by repeatedly taking the hottest child.
 *
 * @return the concatenated {@code ownChar} values (8 in total)
 * @throws ND4JIllegalStateException if the tree has no hottest node, or the chain of hottest
 *         nodes ends before 8 nodes were visited
 */
public String getHottestNetworkA() {
    VirtualNode node = getHottestNode();
    if (node == null)
        throw new ND4JIllegalStateException(
                        "VirtualTree wasn't properly initialized, and doesn't have any information within");

    StringBuilder builder = new StringBuilder();
    builder.append(node.ownChar);

    for (int i = 0; i < 7; i++) {
        node = node.getHottestNode();
        // Report a too-shallow tree explicitly instead of failing with an NPE.
        if (node == null)
            throw new ND4JIllegalStateException(
                            "VirtualTree is too shallow: no hottest node found at depth " + (i + 1));
        builder.append(node.ownChar);
    }
    return builder.toString();
}
Example #9
Source File: LinAlgExceptions.java From nd4j with Apache License 2.0 | 6 votes |
/** * Asserts matrix multiply rules (columns of left == rows of right or rows of left == columns of right) * * @param nd1 the left ndarray * @param nd2 the right ndarray */ public static void assertMultiplies(INDArray nd1, INDArray nd2) { if (nd1.rank() == 2 && nd2.rank() == 2 && nd1.columns() == nd2.rows()) { return; } // 1D edge case if (nd1.rank() == 2 && nd2.rank() == 1 && nd1.columns() == nd2.length()) return; throw new ND4JIllegalStateException("Cannot execute matrix multiplication: " + Arrays.toString(nd1.shape()) + "x" + Arrays.toString(nd2.shape()) + (nd1.rank() != 2 || nd2.rank() != 2 ? ": inputs are not matrices" : ": Column of left array " + nd1.columns() + " != rows of right " + nd2.rows())); }
Example #10
Source File: TensorMmul.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Computes the output shape of this tensor multiplication from the (optionally transposed) shapes
 * of its two arguments and the configured {@code axes}.
 *
 * @return a single-element list with the output shape, an empty list if either input shape is a
 *         placeholder shape, or an empty list if either input shape is null
 * @throws ND4JIllegalStateException if a computed dimension is smaller than 1
 */
@Override
public List<long[]> calculateOutputShape() {
    // Treat an unset transpose configuration as "no transposes" instead of failing with an NPE.
    if (mMulTranspose == null)
        mMulTranspose = MMulTranspose.allFalse();

    List<long[]> ret = new ArrayList<>(1);
    long[] aShape = mMulTranspose.isTransposeA() ? ArrayUtil.reverseCopy(larg().getShape()) : larg().getShape();
    long[] bShape = mMulTranspose.isTransposeB() ? ArrayUtil.reverseCopy(rarg().getShape()) : rarg().getShape();
    if (Shape.isPlaceholderShape(aShape) || Shape.isPlaceholderShape(bShape))
        return Collections.emptyList();

    if (aShape != null && bShape != null) {
        val shape = getTensorMmulShape(aShape, bShape, axes);
        ret.add(shape);
    }

    // Every output dimension must be at least 1.
    if (!ret.isEmpty()) {
        for (int i = 0; i < ret.get(0).length; i++) {
            if (ret.get(0)[i] < 1)
                throw new ND4JIllegalStateException("Invalid shape computed at index " + i);
        }
    }
    return ret;
}
Example #11
Source File: Concat.java From nd4j with Apache License 2.0 | 6 votes |
@Override public void assertValidForExecution() { val descriptor = getDescriptor(); if(descriptor == null) throw new NoOpNameFoundException("No descriptor found for op name " + opName()); if(descriptor.getNumInputs() > 0 && numInputArguments() < 2) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of inputs is invalid for execution. Specified " + numInputArguments() + " but should be " + descriptor.getNumInputs()); if(descriptor.getNumOutputs() > 0 && numOutputArguments() != descriptor.getNumOutputs()) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of outputs is invalid for execution. Specified " + numOutputArguments() + " but should be " + descriptor.getNumOutputs()); //< 0 means dynamic size if(descriptor.getNumIArgs() >= 0 && numIArguments() != descriptor.getNumIArgs()) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of integer arguments is invalid for execution. Specified " + numIArguments() + " but should be " + descriptor.getNumIArgs()); if(descriptor.getNumTArgs() >= 0 && numTArguments() != descriptor.getNumTArgs()) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of inputs is invalid for execution. Specified " + numTArguments() + " but should be " + descriptor.getNumTArgs()); }
Example #12
Source File: DynamicCustomOp.java From nd4j with Apache License 2.0 | 6 votes |
@Override public void assertValidForExecution() { val descriptor = getDescriptor(); if (descriptor == null) throw new NoOpNameFoundException("No descriptor found for op name " + opName()); if (descriptor.getNumInputs() > 0 && numInputArguments() < descriptor.getNumInputs()) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of inputs is invalid for execution. Specified " + numInputArguments() + " but should be " + descriptor.getNumInputs()); if (descriptor.getNumOutputs() > 0 && numOutputArguments() < descriptor.getNumOutputs()) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of outputs is invalid for execution. Specified " + numOutputArguments() + " but should be " + descriptor.getNumOutputs()); //< 0 means dynamic size if (descriptor.getNumIArgs() >= 0 && numIArguments() < descriptor.getNumIArgs()) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of integer arguments is invalid for execution. Specified " + numIArguments() + " but should be " + descriptor.getNumIArgs()); if (descriptor.getNumTArgs() >= 0 && numTArguments() < descriptor.getNumTArgs()) throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of inputs is invalid for execution. Specified " + numTArguments() + " but should be " + descriptor.getNumTArgs()); }
Example #13
Source File: SpecialTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Merges three DataSets (25x25 features of ones, label of shape {1}) and shuffles the result.
 * Per the annotation, {@link ND4JIllegalStateException} is expected.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testScalarShuffle1() throws Exception {
    List<DataSet> examples = new ArrayList<>();
    int added = 0;
    while (added < 3) {
        INDArray features = Nd4j.ones(25, 25);
        INDArray label = Nd4j.create(new float[] {1}, new int[] {1});
        examples.add(new DataSet(features, label));
        added++;
    }

    DataSet merged = DataSet.merge(examples);
    merged.shuffle();
}
Example #14
Source File: DynamicCustomOp.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * This method takes arbitrary number of Double arguments for op.
 * Note that this ACCUMULATES arguments. You are able to call this method
 * multiple times and it will add arguments to a list.
 * PLEASE NOTE: this method does NOT validate values.
 *
 * @param targs floating point arguments to append
 * @return this builder
 * @throws ND4JIllegalStateException if {@code targs} is null, or if the op declares a non-negative
 *         number of floating point arguments and fewer are passed
 */
public DynamicCustomOpsBuilder addFloatingPointArguments(Double... targs) {
    if (numTArguments >= 0) {
        if (targs == null)
            throw new ND4JIllegalStateException("CustomOp [" + opName + "] expects at least " + numTArguments
                            + " floating point arguments. Null was passed instead.");

        if (numTArguments > targs.length)
            throw new ND4JIllegalStateException("CustomOp [" + opName + "] expects at least " + numTArguments
                            + " floating point arguments, but " + targs.length + " was passed to constructor");
    } else if (targs == null) {
        // Dynamic argument count: still reject null instead of failing with an NPE in the loop below.
        throw new ND4JIllegalStateException("CustomOp [" + opName
                        + "] expects floating point arguments. Null was passed instead.");
    }

    for (val in : targs)
        tArguments.add(in);

    return this;
}
Example #15
Source File: CustomOpsTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Constructing a ScatterUpdate whose indices include 6 (both dimensions of the 5x5 target have size 5)
 * is expected to throw {@link ND4JIllegalStateException}.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testScatterUpdate3() throws Exception {
    val matrix = Nd4j.create(5, 5);
    val updates = Nd4j.create(2, 5).assign(1.0);
    int[] dims = new int[] {1};
    // NOTE(review): presumably index 6 is what the constructor rejects — confirm against ScatterUpdate.
    int[] indices = new int[] {0, 6};

    // Construction alone is expected to throw; the op is never executed or inspected.
    new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
}
Example #16
Source File: CustomOpsTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Constructing a ScatterUpdate with dims {0} on a 5x5 target is expected to throw
 * {@link ND4JIllegalStateException}.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testScatterUpdate2() throws Exception {
    val matrix = Nd4j.create(5, 5);
    val updates = Nd4j.create(2, 5).assign(1.0);
    // NOTE(review): the indices are in range, so presumably dims {0} is what the constructor rejects — confirm.
    int[] dims = new int[] {0};
    int[] indices = new int[] {0, 1};

    // Construction alone is expected to throw; the op is never executed or inspected.
    new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
}
Example #17
Source File: AbstractDataSetNormalizer.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Applies the normalization strategy to {@code features} (and its mask) using the fitted feature
 * statistics.
 *
 * @param features     features to transform
 * @param featuresMask mask for the features (passed through to the strategy)
 * @throws ND4JIllegalStateException if no feature statistics are available (fit() not run yet)
 */
@Override
public void transform(INDArray features, INDArray featuresMask) {
    S stats = getFeatureStats();
    if (stats != null) {
        strategy.preProcess(features, featuresMask, stats);
        return;
    }
    throw new ND4JIllegalStateException("Features statistics were not yet calculated. Make sure to run fit() first.");
}
Example #18
Source File: RandomProjection.java From nd4j with Apache License 2.0 | 5 votes |
private static long[] targetShape(long[] shape, double eps, int targetDimension, boolean auto){ long components = targetDimension; if (auto) components = johnsonLindenStraussMinDim(shape[0], eps).get(0); // JL or user spec edge cases if (auto && (components <= 0 || components > shape[1])){ throw new ND4JIllegalStateException(String.format("Estimation led to a target dimension of %d, which is invalid", components)); } return new long[]{ shape[1], components}; }
Example #19
Source File: TestRandomProjection.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testTargetShapeTooHigh() { exception.expect(ND4JIllegalStateException.class); // original dimension too small targetShape(Nd4j.createUninitialized(new int[]{(int)1e2, 1}), 0.5); // target dimension too high targetShape(z1, 1001); // suggested dimension too high targetShape(z1, 0.1); // original samples too small targetShape(Nd4j.createUninitialized(new int[]{1, 1000}), 0.5); }
Example #20
Source File: SameDiffTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Registers a fully-shaped {2, 2} variable as a placeholder, then resolves it with a linspace array.
 * Per the annotation, {@link ND4JIllegalStateException} is expected.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testPlaceHolderWithFullShape() {
    val sd = SameDiff.create();
    val placeholder = sd.var("somevar", new long[] {2, 2});
    val name = placeholder.getVarName();

    sd.addAsPlaceHolder(name);
    assertTrue(sd.isPlaceHolder(name));

    val values = Collections.singletonMap(name, Nd4j.linspace(1, 4, 4));
    sd.resolveVariablesWith(values);
}
Example #21
Source File: CpuThreshold.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * This method allows you to configure threshold for delta extraction. Pass it as float/double value
 * (any {@link Number} is accepted; the absolute value of its float representation is used).
 *
 * Default value: 1e-3
 * @param vars configuration values; the first one must be a Number
 * @throws ND4JIllegalStateException if no value is passed or the first value is not a Number
 */
@Override
public void configure(Object... vars) {
    // Reject missing values with the same error as a non-numeric one, instead of NPE / AIOOBE.
    if (vars == null || vars.length == 0 || !(vars[0] instanceof Number))
        throw new ND4JIllegalStateException("Threshold value should be Number");

    Number t = (Number) vars[0];
    threshold = FastMath.abs(t.floatValue());
    log.info("Setting threshold to [{}]", threshold);
}
Example #22
Source File: DifferentialFunction.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * The left argument for this function, i.e. the first element of {@code args()}.
 *
 * @return the first argument
 * @throws ND4JIllegalStateException if there are no arguments
 */
public SDVariable larg() {
    val args = args();
    if (args == null || args.length == 0)
        throw new ND4JIllegalStateException("No arguments found.");
    // Reuse the array fetched above instead of calling args() a second time.
    return args[0];
}
Example #23
Source File: StridedSlice.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Validates the strided slice configuration: 1, 3 or 4 input arrays and at least 5 integer arguments.
 *
 * @throws ND4JIllegalStateException if either count is invalid
 */
@Override
public void assertValidForExecution() {
    val numInputs = numInputArguments();
    if (numInputs != 1 && numInputs != 3 && numInputs != 4) {
        throw new ND4JIllegalStateException("Num input arguments must be 1, 3 or 4. Specified " + numInputs);
    }

    val numIArgs = numIArguments();
    if (numIArgs < 5) {
        throw new ND4JIllegalStateException("Number of integer arguments must be >= 5. Specified " + numIArgs);
    }
}
Example #24
Source File: Indices.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Create an n-dimensional index from per-dimension intervals: dimension {@code i} becomes
 * {@code NDArrayIndex.interval(start.getInt(i), end.getInt(i))}.
 *
 * @param start the start indexes
 * @param end   the end indexes
 * @return one interval index per element of {@code start}
 * @throws IllegalArgumentException  if {@code start} and {@code end} differ in length
 * @throws ND4JIllegalStateException if the length exceeds {@code Integer.MAX_VALUE}
 */
public static INDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end) {
    if (start.length() != end.length())
        throw new IllegalArgumentException("Start length must be equal to end length");

    if (start.length() > Integer.MAX_VALUE)
        throw new ND4JIllegalStateException("Can't proceed with INDArray with length > Integer.MAX_VALUE");

    INDArrayIndex[] intervals = new INDArrayIndex[(int) start.length()];
    for (int dim = 0; dim < intervals.length; dim++)
        intervals[dim] = NDArrayIndex.interval(start.getInt(dim), end.getInt(dim));
    return intervals;
}
Example #25
Source File: NativeOpExecutioner.java From nd4j with Apache License 2.0 | 5 votes |
@Override public long bitmapEncode(INDArray indArray, INDArray target, double threshold) { long length = indArray.lengthLong(); long tLen = target.data().length(); if (tLen != (length / 16 + 5)) throw new ND4JIllegalStateException("Length of target array should be " + (length / 16 + 5)); if (target.data().dataType() != DataBuffer.Type.INT) throw new ND4JIllegalStateException("Target array should have INT dataType"); DataBuffer buffer = target.data(); buffer.put(0, (int) length); buffer.put(1, (int) length); buffer.put(2, Float.floatToIntBits((float) threshold)); // format id buffer.put(3, ThresholdCompression.BITMAP_ENCODING); long affected = 0; if (indArray.data().dataType() == DataBuffer.Type.FLOAT) { affected = loop.encodeBitmapFloat(null, (FloatPointer) indArray.data().addressPointer(), length, (IntPointer) buffer.addressPointer(), (float) threshold); } else if (indArray.data().dataType() == DataBuffer.Type.DOUBLE) { affected = loop.encodeBitmapDouble(null, (DoublePointer) indArray.data().addressPointer(), length, (IntPointer) buffer.addressPointer(), (float) threshold); } else throw new UnsupportedOperationException("HALF precision isn't supported on CPU yet"); return affected; }
Example #26
Source File: DynamicCustomOp.java From nd4j with Apache License 2.0 | 5 votes |
@Override public void addInputArgument(INDArray... arg) { for (int i = 0; i < arg.length; i++) { if (arg[i] == null) throw new ND4JIllegalStateException("Input " + i + " was null!"); } inputArguments.addAll(Arrays.asList(arg)); val args = sameDiff != null ? args() : null; val arrsSoFar = inputArguments(); //validate arrays passed in, keep in mind that //this is a cumulative algorithm so we should always //refresh the current list if (args != null) { for (int i = 0; i < args.length; i++) { // it's possible to get into situation where number of args > number of arrays AT THIS MOMENT if (i >= arrsSoFar.length) continue; if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape())) throw new ND4JIllegalStateException("Illegal array passed in. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arg[i].shape())); } } }
Example #27
Source File: OnnxGraphMapper.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Returns the string value of the attribute named {@code key} on the given ONNX node.
 *
 * @param nodeProto the node to inspect
 * @param key       attribute name to look up
 * @return the attribute's {@code s} field decoded as UTF-8
 * @throws ND4JIllegalStateException if the node has no attribute with that name
 */
@Override
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
    for (OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
        if (attributeProto.getName().equals(key)) {
            // getS() is a protobuf ByteString: toString() yields a debug description,
            // not the attribute text, so decode the bytes explicitly.
            return attributeProto.getS().toStringUtf8();
        }
    }

    throw new ND4JIllegalStateException("No key found for " + key);
}
Example #28
Source File: OnnxGraphMapper.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Builds an INDArray from the graph initializer named {@code tensorName}.
 *
 * <p>The initializer's raw bytes are copied into a direct buffer in native byte order. The data type
 * and shape are taken from {@code tensorProto}.
 *
 * @return the array, or null if the graph has no initializer with that name
 * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    DataBuffer.Type type = dataTypeForTensor(tensorProto);
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }

    OnnxProto3.TensorProto initializerTensor = null;
    for (int idx = 0; idx < graph.getInitializerCount() && initializerTensor == null; idx++) {
        val candidate = graph.getInitializer(idx);
        if (candidate.getName().equals(tensorName))
            initializerTensor = candidate;
    }
    if (initializerTensor == null)
        return null;

    ByteBuffer source = initializerTensor.getRawData().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    ByteBuffer direct = ByteBuffer.allocateDirect(source.capacity()).order(ByteOrder.nativeOrder());
    direct.put(source);
    direct.rewind();

    long[] shape = getShapeFromTensor(tensorProto);
    DataBuffer buffer = Nd4j.createBuffer(direct, type, ArrayUtil.prod(shape));
    return Nd4j.create(buffer).reshape(shape);
}
Example #29
Source File: Reshape.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v) { val origShape = arg().getShape(); if (origShape == null) { //TODO need a more robust way to do this throw new ND4JIllegalStateException("Cannot reshape: original array input shape is null"); } SDVariable ret = f().reshape(i_v.get(0), origShape); return Arrays.asList(ret); }
Example #30
Source File: Shape.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Computes the result shape of the matrix multiplication {@code left x right}.
 *
 * <p>If either shape is a scalar shape, the other shape is returned unchanged. Otherwise every
 * dimension must be at least 1, and when {@code left} has more than one dimension its columns must
 * equal the rows of {@code right}.
 *
 * @param left  shape of the left operand
 * @param right shape of the right operand
 * @return the result shape
 * @throws IllegalArgumentException  if neither shape has length 2, or the inner dimensions differ
 * @throws ND4JIllegalStateException if a shape contains a dimension smaller than 1
 */
public static long[] getMatrixMultiplyShape(long[] left, long[] right) {
    if (Shape.shapeIsScalar(left)) {
        return right;
    }

    if (Shape.shapeIsScalar(right)) {
        return left;
    }

    // NOTE(review): this rejects only the case where *both* shapes have a length other than 2,
    // so one rank-1 operand is let through to the vector handling below.
    if (left.length != 2 && right.length != 2) {
        throw new IllegalArgumentException("Illegal shapes for matrix multiply. Must be of length 2");
    }

    for (int i = 0; i < left.length; i++) {
        if (left[i] < 1)
            throw new ND4JIllegalStateException("Left shape contained value < 1 at index " + i);
    }

    for (int i = 0; i < right.length; i++) {
        if (right[i] < 1)
            throw new ND4JIllegalStateException("Right shape contained value < 1 at index " + i);
    }

    if (left.length > 1 && left[1] != right[0])
        throw new IllegalArgumentException("Columns of left not equal to rows of right");

    // Rank-1 left whose length equals the rows of right: treated as a row vector.
    if (left.length < right.length && left[0] == right[0]) {
        return new long[] {1, right[1]};
    }

    // NOTE(review): a rank-1 right operand reaching this line would make right[1] out of bounds — confirm intended handling.
    return new long[] {left[0], right[1]};
}