org.nd4j.linalg.exception.ND4JIllegalStateException Java Examples

The following examples show how to use org.nd4j.linalg.exception.ND4JIllegalStateException. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: DataSetUtil.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Merge the specified labels and label mask arrays (i.e., concatenate the examples).
 * The merge routine is selected from the rank of the first labels array only.
 *
 * @param labelsToMerge     Labels to merge. The rank of the first element (2, 3 or 4) selects the merge routine
 * @param labelMasksToMerge Mask arrays to merge. May be null
 * @return Merged labels and mask. Mask may be null
 * @throws ND4JIllegalStateException if the first labels array has a rank outside the range 2 to 4
 */
public static Pair<INDArray, INDArray> mergeLabels(INDArray[] labelsToMerge, INDArray[] labelMasksToMerge) {
    int rankLabels = labelsToMerge[0].rank();

    switch (rankLabels) {
        case 2:
            return DataSetUtil.merge2d(labelsToMerge, labelMasksToMerge);
        case 3:
            return DataSetUtil.mergeTimeSeries(labelsToMerge, labelMasksToMerge);
        case 4:
            return DataSetUtil.merge4d(labelsToMerge, labelMasksToMerge);
        default:
            // The message reports the labels array, since this method never sees features
            throw new ND4JIllegalStateException("Cannot merge examples: labels rank must be in range 2 to 4"
                            + " inclusive. First example labels shape: "
                            + Arrays.toString(labelsToMerge[0].shape()));
    }
}
 
Example #2
Source File: OpExecutionerUtil.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Throws if the given array contains NaN values, but only while the executioner's
 * profiling mode is NAN_PANIC or ANY_PANIC. In any other mode this is a no-op.
 *
 * @param z the array to inspect
 * @throws ND4JIllegalStateException if at least one NaN is found
 */
public static void checkForNaN(INDArray z) {
    OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    boolean panicOnNaN = mode == OpExecutioner.ProfilingMode.NAN_PANIC
                    || mode == OpExecutioner.ProfilingMode.ANY_PANIC;
    if (!panicOnNaN) {
        return;
    }

    int nanCount;
    if (z.isScalar()) {
        // Scalars are read directly, using the accessor that matches the buffer type
        boolean nan;
        if (z.data().dataType() == DataBuffer.Type.DOUBLE) {
            nan = Double.isNaN(z.getDouble(0));
        } else {
            nan = Float.isNaN(z.getFloat(0));
        }
        nanCount = nan ? 1 : 0;
    } else {
        // Non-scalars: count matches of the isNan condition over the whole array
        MatchCondition nanCondition = new MatchCondition(z, Conditions.isNan());
        nanCount = Nd4j.getExecutioner().exec(nanCondition, Integer.MAX_VALUE).getInt(0);
    }

    if (nanCount > 0) {
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + nanCount + " NaN value(s): ");
    }
}
 
Example #3
Source File: BaseOp.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Sets the y array of this op and resets the processed counter.
 * When {@code y} is null, the array is taken from the argument at index 1,
 * provided that argument is an SDVariable that currently holds an array.
 *
 * @param y the array to use, or null to infer it from the arguments
 * @throws ND4JIllegalStateException if {@code y} is null and there are fewer than two arguments
 */
@Override
public void setY(INDArray y) {
    if (y != null) {
        this.y = y;
    } else if (args() == null || args().length <= 1) {
        throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments");
    } else {
        // The argument at index 1 is the candidate source for y
        DifferentialFunction secondArg = args()[1];
        if (secondArg instanceof SDVariable) {
            INDArray inferred = ((SDVariable) secondArg).getArr();
            if (inferred != null) {
                this.y = inferred;
            }
        }
    }
    numProcessed = 0;
}
 
Example #4
Source File: SDVariable.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Returns the array for this variable, creating it when necessary.
 * If an array already exists and its shape equals the shape registered for
 * this variable's name, that array is returned unchanged. Otherwise a new
 * array is created through the weight init scheme and registered under the
 * variable name.
 *
 * @return the existing or newly allocated array
 * @throws ND4JIllegalStateException if the variable name is null or no shape is registered for it
 */
public INDArray storeAndAllocateNewArray() {
    val registeredShape = sameDiff.getShapeForVarName(getVarName());

    INDArray current = getArr();
    if (current != null && Arrays.equals(current.shape(), registeredShape)) {
        return current;
    }

    if (varName == null) {
        throw new ND4JIllegalStateException("Unable to store array for null variable name!");
    }
    if (registeredShape == null) {
        throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
    }

    val allocated = getWeightInitScheme().create(registeredShape);
    sameDiff.putArrayForVarName(getVarName(), allocated);
    return allocated;
}
 
Example #5
Source File: Shape.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Normalizes an axis list against the given rank: negative entries are shifted by
 * {@code rank}, the result is sorted and then passed through {@code uniquify}.
 * A null or empty axis list yields {@code {Integer.MAX_VALUE}}; an entry equal to
 * {@code Integer.MAX_VALUE} is accepted as-is.
 *
 * @param rank the rank the axes refer to
 * @param axis the axes to normalize, may be null or empty
 * @return the normalized, sorted axes
 * @throws ND4JIllegalStateException if an entry is still negative after shifting, or is
 *         at or above {@code rank} (other than {@code Integer.MAX_VALUE})
 */
public static int[] normalizeAxis(int rank, int... axis) {
    if (axis == null || axis.length == 0)
        return new int[] {Integer.MAX_VALUE};

    int[] normalized = new int[axis.length];
    for (int i = 0; i < axis.length; i++) {
        // shift negative axes into the non-negative range
        int resolved = axis[i];
        if (resolved < 0)
            resolved += rank;

        boolean aboveRank = resolved >= rank && resolved != Integer.MAX_VALUE;
        if (aboveRank || resolved < 0)
            throw new ND4JIllegalStateException("Axis array " + Arrays.toString(axis) + " contains values above rank " + rank);

        normalized[i] = resolved;
    }

    Arrays.sort(normalized);

    // sorted input goes to uniquify to drop possible duplicates
    return uniquify(normalized);
}
 
Example #6
Source File: Mmul.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Computes the output shape of this matrix multiplication from the shapes of the
 * left and right arguments, reversing a shape first when its transpose flag is set.
 * A missing transpose configuration is replaced by {@code MMulTranspose.allFalse()}.
 *
 * @return an empty list if either shape is a placeholder shape; otherwise a list that
 *         holds the computed shape, or no element if either input shape is null
 * @throws ND4JIllegalStateException if any dimension of the computed shape is below 1
 */
@Override
public List<long[]> calculateOutputShape() {
    if (mMulTranspose == null) {
        mMulTranspose = MMulTranspose.allFalse();
    }

    long[] leftShape;
    if (mMulTranspose.isTransposeA()) {
        leftShape = ArrayUtil.reverseCopy(larg().getShape());
    } else {
        leftShape = larg().getShape();
    }

    long[] rightShape;
    if (mMulTranspose.isTransposeB()) {
        rightShape = ArrayUtil.reverseCopy(rarg().getShape());
    } else {
        rightShape = rarg().getShape();
    }

    if (Shape.isPlaceholderShape(leftShape) || Shape.isPlaceholderShape(rightShape)) {
        return Collections.emptyList();
    }

    List<long[]> shapes = new ArrayList<>(1);
    if (leftShape == null || rightShape == null) {
        return shapes;
    }

    long[] product = Shape.getMatrixMultiplyShape(leftShape, rightShape);
    shapes.add(product);
    for (int dim = 0; dim < product.length; dim++) {
        if (product[dim] < 1) {
            throw new ND4JIllegalStateException("Invalid shape computed at index " +  dim);
        }
    }
    return shapes;
}
 
Example #7
Source File: CustomOpsTests.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds an "add" custom op from two inputs with in-place execution switched off and
 * executes it. The test passes only if an ND4JIllegalStateException is raised.
 * NOTE(review): presumably the failure comes from the missing output array - confirm.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testNoneInplaceOp3() throws Exception {
    val lhs = Nd4j.create(10, 10);
    val rhs = Nd4j.create(10, 10);

    lhs.assign(4.0);
    rhs.assign(2.0);

    val expectedSum = Nd4j.create(10,10).assign(6.0);

    CustomOp addOp = DynamicCustomOp.builder("add")
            .addInputs(lhs, rhs)
            .callInplace(false)
            .build();

    Nd4j.getExecutioner().exec(addOp);

    assertEquals(expectedSum, lhs);
}
 
Example #8
Source File: NetworkOrganizer.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a string by following the chain of hottest nodes: the character of the
 * hottest root-level node followed by the characters of the next 7 hottest
 * descendants, 8 characters in total.
 * NOTE(review): assumes the tree is at least 8 levels deep along that path; a missing
 * child surfaces as a NullPointerException, as before - confirm this is acceptable.
 *
 * @return the concatenated {@code ownChar} values along the hottest path
 * @throws ND4JIllegalStateException if no hottest starting node exists
 */
public String getHottestNetworkA() {
    VirtualNode node = getHottestNode();

    if (node == null)
        throw new ND4JIllegalStateException(
                        "VirtualTree wasn't properly initialized, and doesn't have any information within");

    StringBuilder builder = new StringBuilder();
    builder.append(node.ownChar);

    // descend 7 more levels, always taking the hottest child
    for (int i = 0; i < 7; i++) {
        node = node.getHottestNode();
        builder.append(node.ownChar);
    }

    return builder.toString();
}
 
Example #9
Source File: LinAlgExceptions.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Asserts that the two arrays can be matrix-multiplied. Accepted cases are a rank-2
 * left array whose column count equals the row count of a rank-2 right array, or equals
 * the length of a rank-1 right array.
 *
 * @param nd1 the left ndarray
 * @param nd2 the right ndarray
 * @throws ND4JIllegalStateException if neither accepted case applies
 */
public static void assertMultiplies(INDArray nd1, INDArray nd2) {
    if (nd1.rank() == 2) {
        // matrix x matrix
        if (nd2.rank() == 2 && nd1.columns() == nd2.rows()) {
            return;
        }
        // matrix x 1D vector
        if (nd2.rank() == 1 && nd1.columns() == nd2.length()) {
            return;
        }
    }

    String reason;
    if (nd1.rank() != 2 || nd2.rank() != 2) {
        reason = ": inputs are not matrices";
    } else {
        reason = ": Column of left array " + nd1.columns() + " != rows of right " + nd2.rows();
    }

    throw new ND4JIllegalStateException("Cannot execute matrix multiplication: " + Arrays.toString(nd1.shape())
                    + "x" + Arrays.toString(nd2.shape()) + reason);
}
 
Example #10
Source File: TensorMmul.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Computes the output shape of this tensor multiplication from the shapes of the left
 * and right arguments, reversing a shape first when its transpose flag is set.
 * A missing transpose configuration is treated as "no transpose" instead of failing
 * with a NullPointerException.
 *
 * @return an empty list if either shape is a placeholder shape; otherwise a list that
 *         holds the computed shape, or no element if either input shape is null
 * @throws ND4JIllegalStateException if any dimension of the computed shape is below 1
 */
@Override
public List<long[]> calculateOutputShape() {
    // null-safe flags: no transpose configuration means neither side is transposed
    boolean transposeA = mMulTranspose != null && mMulTranspose.isTransposeA();
    boolean transposeB = mMulTranspose != null && mMulTranspose.isTransposeB();

    List<long[]> ret = new ArrayList<>(1);
    long[] aShape = transposeA ? ArrayUtil.reverseCopy(larg().getShape()) : larg().getShape();
    long[] bShape = transposeB ? ArrayUtil.reverseCopy(rarg().getShape()) : rarg().getShape();
    if(Shape.isPlaceholderShape(aShape) || Shape.isPlaceholderShape(bShape))
        return Collections.emptyList();

    if(aShape != null && bShape != null) {
        val shape =  getTensorMmulShape(aShape,bShape, axes);
        ret.add(shape);
    }
    if(!ret.isEmpty()) {
        // every dimension of the computed shape must be at least 1
        for(int i = 0; i < ret.get(0).length; i++) {
            if(ret.get(0)[i] < 1)
                throw new ND4JIllegalStateException("Invalid shape computed at index " +  i);
        }
    }
    return ret;
}
 
Example #11
Source File: Concat.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Validates the argument counts of this op against its descriptor before execution.
 * Inputs: at least 2 are required whenever the descriptor declares a positive input count.
 * Outputs: must equal the descriptor's count when that count is positive.
 * Integer and floating point arguments: must equal the descriptor's count unless that
 * count is negative, which means dynamic size.
 *
 * @throws NoOpNameFoundException    if no descriptor exists for this op name
 * @throws ND4JIllegalStateException if any argument count is invalid
 */
@Override
public void assertValidForExecution() {
    val descriptor = getDescriptor();
    if(descriptor == null)
        throw new NoOpNameFoundException("No descriptor found for op name " + opName());


    // the message states the bound that is actually enforced (2), not the descriptor's count
    if(descriptor.getNumInputs() > 0 && numInputArguments() < 2)
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of inputs is invalid for execution. Specified " + numInputArguments() + " but should be at least 2");

    if(descriptor.getNumOutputs() > 0 && numOutputArguments() != descriptor.getNumOutputs())
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of outputs is invalid for execution. Specified " + numOutputArguments() + " but should be " + descriptor.getNumOutputs());

    //< 0 means dynamic size
    if(descriptor.getNumIArgs() >= 0 && numIArguments() != descriptor.getNumIArgs())
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of integer arguments is invalid for execution. Specified " + numIArguments() + " but should be " + descriptor.getNumIArgs());

    // this check concerns floating point (T) arguments, so the message names them
    if(descriptor.getNumTArgs() >= 0 && numTArguments() != descriptor.getNumTArgs())
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of floating point arguments is invalid for execution. Specified " + numTArguments() + " but should be " + descriptor.getNumTArgs());

}
 
Example #12
Source File: DynamicCustomOp.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Validates the argument counts of this op against its descriptor before execution.
 * Each supplied count (inputs, outputs, integer arguments, floating point arguments)
 * must be at least the descriptor's count. Input and output checks apply only for a
 * positive descriptor count; integer and floating point checks are skipped for a
 * negative descriptor count, which means dynamic size.
 *
 * @throws NoOpNameFoundException    if no descriptor exists for this op name
 * @throws ND4JIllegalStateException if any argument count is too small
 */
@Override
public void assertValidForExecution() {
    val descriptor = getDescriptor();
    if (descriptor == null)
        throw new NoOpNameFoundException("No descriptor found for op name " + opName());


    if (descriptor.getNumInputs() > 0 && numInputArguments() < descriptor.getNumInputs())
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of inputs is invalid for execution. Specified " + numInputArguments() + " but should be " + descriptor.getNumInputs());

    if (descriptor.getNumOutputs() > 0 && numOutputArguments() < descriptor.getNumOutputs())
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of outputs is invalid for execution. Specified " + numOutputArguments() + " but should be " + descriptor.getNumOutputs());

    //< 0 means dynamic size
    if (descriptor.getNumIArgs() >= 0 && numIArguments() < descriptor.getNumIArgs())
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of integer arguments is invalid for execution. Specified " + numIArguments() + " but should be " + descriptor.getNumIArgs());

    // this check concerns floating point (T) arguments, so the message names them
    if (descriptor.getNumTArgs() >= 0 && numTArguments() < descriptor.getNumTArgs())
        throw new ND4JIllegalStateException("Op failure for " + opName() + " Number of floating point arguments is invalid for execution. Specified " + numTArguments() + " but should be " + descriptor.getNumTArgs());

}
 
Example #13
Source File: SpecialTests.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Merges three data sets, each with 25x25 features and a single-element label, and
 * shuffles the merged result. The test passes only if an ND4JIllegalStateException is
 * raised somewhere along the way.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testScalarShuffle1() throws Exception {
    List<DataSet> examples = new ArrayList<>();
    int remaining = 3;
    while (remaining-- > 0) {
        INDArray onesFeatures = Nd4j.ones(25, 25);
        INDArray singleLabel = Nd4j.create(new float[] {1}, new int[] {1});
        examples.add(new DataSet(onesFeatures, singleLabel));
    }

    DataSet merged = DataSet.merge(examples);
    merged.shuffle();
}
 
Example #14
Source File: DynamicCustomOp.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * This method takes an arbitrary number of Double arguments for the op.
 * Note that this ACCUMULATES arguments: it may be called multiple times and each call
 * appends its arguments to the list.
 * PLEASE NOTE: this method does NOT validate values, only their count.
 *
 * @param targs floating point arguments to append; must not be null
 * @return this builder
 * @throws ND4JIllegalStateException if {@code targs} is null, or if a non-negative
 *         expected count is configured and fewer arguments are passed
 */
public DynamicCustomOpsBuilder addFloatingPointArguments(Double... targs) {
    // Reject null up front in every mode; with a dynamic (negative) expected count
    // a null array would otherwise fail with a NullPointerException in the loop below
    if (targs == null)
        throw new ND4JIllegalStateException("CustomOp [" + opName + "] expects "
                        + (numTArguments >= 0 ? "at least " + numTArguments + " " : "")
                        + "floating point arguments. Null was passed instead.");

    if (numTArguments >= 0 && numTArguments > targs.length)
        throw new ND4JIllegalStateException("CustomOp [" + opName + "] expects at least " + numTArguments + " floating point arguments, but " + targs.length + " were passed");

    for (Double in : targs)
        tArguments.add(in);

    return this;
}
 
Example #15
Source File: CustomOpsTests.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Constructs a ScatterUpdate (ADD) on a 5x5 matrix with 2x5 updates, dimension 1 and
 * indices {0, 6}. The test passes only if an ND4JIllegalStateException is raised.
 * NOTE(review): presumably index 6 is rejected as out of range - confirm.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testScatterUpdate3() throws Exception {
    val target = Nd4j.create(5, 5);
    val onesUpdates = Nd4j.create(2, 5).assign(1.0);
    int[] dimensions = {1};
    int[] targetIndices = {0, 6};

    val expectedZeros = Nd4j.create(1, 5).assign(0);
    val expectedOnes = Nd4j.create(1, 5).assign(1);

    ScatterUpdate scatter = new ScatterUpdate(target, onesUpdates, targetIndices, dimensions, ScatterUpdate.UpdateOp.ADD);
}
 
Example #16
Source File: CustomOpsTests.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Constructs a ScatterUpdate (ADD) on a 5x5 matrix with 2x5 updates, dimension 0 and
 * indices {0, 1}. The test passes only if an ND4JIllegalStateException is raised.
 * NOTE(review): which validation rejects this combination is not visible here - confirm.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testScatterUpdate2() throws Exception {
    val target = Nd4j.create(5, 5);
    val onesUpdates = Nd4j.create(2, 5).assign(1.0);
    int[] dimensions = {0};
    int[] targetIndices = {0, 1};

    val expectedZeros = Nd4j.create(1, 5).assign(0);
    val expectedOnes = Nd4j.create(1, 5).assign(1);

    ScatterUpdate scatter = new ScatterUpdate(target, onesUpdates, targetIndices, dimensions, ScatterUpdate.UpdateOp.ADD);
}
 
Example #17
Source File: AbstractDataSetNormalizer.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Applies the normalization strategy to the given features using the stored
 * feature statistics.
 *
 * @param features     the features passed on to the strategy
 * @param featuresMask the mask passed on to the strategy
 * @throws ND4JIllegalStateException if no feature statistics are available yet
 */
@Override
public void transform(INDArray features, INDArray featuresMask) {
    S stats = getFeatureStats();
    if (stats != null) {
        strategy.preProcess(features, featuresMask, stats);
        return;
    }

    throw new ND4JIllegalStateException("Features statistics were not yet calculated. Make sure to run fit() first.");
}
 
Example #18
Source File: RandomProjection.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Determines the shape {@code {shape[1], components}} of the projection target.
 * With {@code auto} set, the component count is the first value returned by
 * {@code johnsonLindenStraussMinDim(shape[0], eps)} and must lie in
 * {@code 1..shape[1]}; otherwise {@code targetDimension} is used without a range check.
 *
 * @param shape           shape of the input; indices 0 and 1 are read
 * @param eps             passed to the minimum-dimension estimate when {@code auto} is set
 * @param targetDimension the component count to use when {@code auto} is not set
 * @param auto            whether to estimate the component count
 * @return the target shape
 * @throws ND4JIllegalStateException if the estimated component count is out of range
 */
private static long[] targetShape(long[] shape, double eps, int targetDimension, boolean auto){
    long components;
    if (auto) {
        components = johnsonLindenStraussMinDim(shape[0], eps).get(0);
        // the estimate must be positive and must not exceed the second input dimension
        boolean outOfRange = components <= 0 || components > shape[1];
        if (outOfRange) {
            throw new ND4JIllegalStateException(String.format("Estimation led to a target dimension of %d, which is invalid", components));
        }
    } else {
        components = targetDimension;
    }
    return new long[] {shape[1], components};
}
 
Example #19
Source File: TestRandomProjection.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Expects an ND4JIllegalStateException from {@code targetShape} calls with invalid input.
 *
 * NOTE(review): {@code exception} is presumably a JUnit ExpectedException rule. If so,
 * the first call that throws ends the test, and the calls after it are never executed -
 * confirm whether each case should become its own test.
 */
@Test
public void testTargetShapeTooHigh() {
    exception.expect(ND4JIllegalStateException.class);
    // original dimension too small
    targetShape(Nd4j.createUninitialized(new int[]{(int)1e2, 1}), 0.5);
    // target dimension too high
    targetShape(z1, 1001);
    // suggested dimension too high
    targetShape(z1, 0.1);
    // original samples too small
    targetShape(Nd4j.createUninitialized(new int[]{1, 1000}), 0.5);
}
 
Example #20
Source File: SameDiffTests.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Registers a variable created with the full shape {2, 2} as a placeholder, then
 * resolves it with an array produced by {@code Nd4j.linspace(1, 4, 4)}.
 * The test passes only if an ND4JIllegalStateException is raised.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testPlaceHolderWithFullShape() {
    val sameDiff = SameDiff.create();
    val variable = sameDiff.var("somevar", new long[]{2, 2});

    sameDiff.addAsPlaceHolder(variable.getVarName());
    assertTrue(sameDiff.isPlaceHolder(variable.getVarName()));

    val replacement = Nd4j.linspace(1, 4, 4);
    sameDiff.resolveVariablesWith(Collections.singletonMap(variable.getVarName(), replacement));
}
 
Example #21
Source File: CpuThreshold.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * This method allows you to configure threshold for delta extraction. Pass it as float/double value.
 * The absolute value of the first argument is stored as the threshold; further
 * arguments are ignored.
 *
 * Default value: 1e-3
 *
 * @param vars configuration values; the first element must be a {@link Number}
 * @throws ND4JIllegalStateException if no arguments are given or the first one is not a Number
 */
@Override
public void configure(Object... vars) {
    // A null or empty argument list is reported the same way as a non-numeric value,
    // instead of failing on the vars[0] access
    if (vars == null || vars.length == 0 || !(vars[0] instanceof Number))
        throw new ND4JIllegalStateException("Threshold value should be Number");

    Number t = (Number) vars[0];
    threshold = FastMath.abs(t.floatValue());
    log.info("Setting threshold to [{}]", threshold);
}
 
Example #22
Source File: DifferentialFunction.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * The left argument for this function, i.e. the first element of {@code args()}.
 *
 * @return the first argument
 * @throws ND4JIllegalStateException if there are no arguments
 */
public SDVariable larg() {
    val args = args();
    if(args == null || args.length == 0)
        throw new ND4JIllegalStateException("No arguments found.");
    // return from the array that was just validated rather than calling args() again
    return args[0];
}
 
Example #23
Source File: StridedSlice.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validates the argument counts of this op before execution: the number of input
 * arguments must be 1, 3 or 4, and at least 5 integer arguments are required.
 *
 * @throws ND4JIllegalStateException if either count is invalid
 */
@Override
public void assertValidForExecution() {
    int inputCount = numInputArguments();
    boolean supportedInputCount = inputCount == 1 || inputCount == 3 || inputCount == 4;
    if (!supportedInputCount) {
        throw new ND4JIllegalStateException("Num input arguments must be 1 3 or 4.");
    }

    if (numIArguments() < 5) {
        throw new ND4JIllegalStateException("Number of integer arguments must >= 5");
    }
}
 
Example #24
Source File: Indices.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates one interval index per element of the given arrays: element {@code i} of the
 * result is {@code NDArrayIndex.interval(start.getInt(i), end.getInt(i))}.
 *
 * @param start the start indexes
 * @param end the end indexes
 * @return the interval indexes built from the paired start and end values
 * @throws IllegalArgumentException  if the two arrays differ in length
 * @throws ND4JIllegalStateException if the length exceeds Integer.MAX_VALUE
 */
public static INDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end) {
    final long count = start.length();
    if (count != end.length()) {
        throw new IllegalArgumentException("Start length must be equal to end length");
    }
    if (count > Integer.MAX_VALUE) {
        throw new ND4JIllegalStateException("Can't proceed with INDArray with length > Integer.MAX_VALUE");
    }

    INDArrayIndex[] intervals = new INDArrayIndex[(int) count];
    int pos = 0;
    while (pos < intervals.length) {
        intervals[pos] = NDArrayIndex.interval(start.getInt(pos), end.getInt(pos));
        pos++;
    }
    return intervals;
}
 
Example #25
Source File: NativeOpExecutioner.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Bitmap-encodes {@code indArray} into {@code target}. The first four entries of the
 * target buffer are written as a header (length, length, threshold bits, format id)
 * before the native encoder for FLOAT or DOUBLE source data is invoked.
 *
 * @param indArray  the array to encode
 * @param target    the destination; its buffer must be of INT type and hold
 *                  {@code length / 16 + 5} elements
 * @param threshold the encoding threshold, narrowed to float
 * @return the value returned by the native encoder
 * @throws ND4JIllegalStateException     if the target has the wrong length or data type
 * @throws UnsupportedOperationException if the source data is neither FLOAT nor DOUBLE
 */
@Override
public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
    final long length = indArray.lengthLong();
    final long expectedTargetLength = length / 16 + 5;

    if (target.data().length() != expectedTargetLength)
        throw new ND4JIllegalStateException("Length of target array should be " + expectedTargetLength);

    if (target.data().dataType() != DataBuffer.Type.INT)
        throw new ND4JIllegalStateException("Target array should have INT dataType");

    DataBuffer encoded = target.data();
    final float floatThreshold = (float) threshold;

    // header: length twice, then the raw bits of the threshold, then the format id
    encoded.put(0, (int) length);
    encoded.put(1, (int) length);
    encoded.put(2, Float.floatToIntBits(floatThreshold));
    encoded.put(3, ThresholdCompression.BITMAP_ENCODING);

    DataBuffer.Type sourceType = indArray.data().dataType();
    if (sourceType == DataBuffer.Type.FLOAT) {
        return loop.encodeBitmapFloat(null, (FloatPointer) indArray.data().addressPointer(), length, (IntPointer) encoded.addressPointer(), floatThreshold);
    }
    if (sourceType == DataBuffer.Type.DOUBLE) {
        return loop.encodeBitmapDouble(null, (DoublePointer) indArray.data().addressPointer(), length, (IntPointer) encoded.addressPointer(), floatThreshold);
    }
    throw new UnsupportedOperationException("HALF precision isn't supported on CPU yet");
}
 
Example #26
Source File: DynamicCustomOp.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Appends the given arrays to this op's input arguments. When a SameDiff instance is
 * attached, the shapes of all input arrays accumulated so far are then checked against
 * the shapes of the op's arguments at the same positions.
 *
 * @param arg the arrays to append; no element may be null
 * @throws ND4JIllegalStateException if an element is null, or if an accumulated input
 *         array's shape differs from the shape of the corresponding argument
 */
@Override
public void addInputArgument(INDArray... arg) {
    for (int i = 0; i < arg.length; i++) {
        if (arg[i] == null)
            throw new ND4JIllegalStateException("Input " + i + " was null!");
    }


    inputArguments.addAll(Arrays.asList(arg));

    val args = sameDiff != null ? args() : null;
    val arrsSoFar = inputArguments();
    //validate arrays passed in, keep in mind that
    //this is a cumulative algorithm so we should always
    //refresh the current list
    if (args != null) {
        // it's possible to get into situation where number of args > number of arrays AT THIS MOMENT,
        // so only positions that already have an array are compared
        int comparable = Math.min(args.length, arrsSoFar.length);
        for (int i = 0; i < comparable; i++) {
            // report the accumulated array that was actually compared; arg[] holds only this
            // call's inputs and is not aligned with the cumulative index i
            if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape()))
                throw new ND4JIllegalStateException("Illegal array passed in. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arrsSoFar[i].shape()));
        }
    }
}
 
Example #27
Source File: OnnxGraphMapper.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Returns the string value of the first attribute of the node whose name equals the key.
 *
 * @param nodeProto the node whose attributes are searched
 * @param key       the attribute name to look for
 * @return the attribute's string payload, decoded as UTF-8
 * @throws ND4JIllegalStateException if the node has no attribute with that name
 */
@Override
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
    for(OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
        if(attributeProto.getName().equals(key)) {
            // ByteString.toString() yields a debug representation, not the content,
            // so the bytes are decoded explicitly
            return attributeProto.getS().toStringUtf8();
        }
    }

    throw new ND4JIllegalStateException("No key found for " + key);
}
 
Example #28
Source File: OnnxGraphMapper.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds an ndarray for the named tensor from the raw data of the graph initializer
 * with the same name. The raw bytes are copied into a direct buffer in native byte
 * order and reshaped to the shape taken from {@code tensorProto}.
 *
 * @param tensorName  name of the initializer to look up in the graph
 * @param tensorProto tensor type description providing data type and shape
 * @param graph       the graph whose initializers are searched
 * @return the array, or null if the graph has no initializer with that name
 * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    DataBuffer.Type elementType = dataTypeForTensor(tensorProto);
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }

    // first initializer with a matching name wins
    OnnxProto3.TensorProto match = null;
    for (int idx = 0; match == null && idx < graph.getInitializerCount(); idx++) {
        OnnxProto3.TensorProto candidate = graph.getInitializer(idx);
        if (candidate.getName().equals(tensorName)) {
            match = candidate;
        }
    }
    if (match == null) {
        return null;
    }

    ByteBuffer raw = match.getRawData().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    ByteBuffer direct = ByteBuffer.allocateDirect(raw.capacity()).order(ByteOrder.nativeOrder());
    direct.put(raw);
    direct.rewind();

    long[] shape = getShapeFromTensor(tensorProto);
    DataBuffer data = Nd4j.createBuffer(direct, elementType, ArrayUtil.prod(shape));
    return Nd4j.create(data).reshape(shape);
}
 
Example #29
Source File: Reshape.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Differentiates the reshape by reshaping the first incoming gradient back to the
 * shape of this op's input.
 *
 * @param i_v the incoming gradients; only the first element is used
 * @return a single-element list holding the reshaped gradient
 * @throws ND4JIllegalStateException if the shape of the input is not known
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    val inputShape = arg().getShape();
    if (inputShape != null) {
        SDVariable reshapedGradient = f().reshape(i_v.get(0), inputShape);
        return Arrays.asList(reshapedGradient);
    }

    //TODO need a more robust way to do this
    throw new ND4JIllegalStateException("Cannot reshape: original array input shape is null");
}
 
Example #30
Source File: Shape.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes the result shape of multiplying an array of shape {@code left} by an array
 * of shape {@code right}. If either shape is scalar, the other shape is returned as-is.
 *
 * NOTE(review): the length check only rejects the case where NEITHER shape has length 2,
 * so a right shape of length 1 paired with a length-2 left shape reaches the
 * {@code right[1]} access below - confirm whether that combination should be supported
 * or rejected.
 *
 * @param left  shape of the left operand
 * @param right shape of the right operand
 * @return the result shape
 * @throws IllegalArgumentException  if neither shape has length 2, or the second dimension
 *         of left differs from the first dimension of right
 * @throws ND4JIllegalStateException if any dimension of either shape is below 1
 */
public static long[] getMatrixMultiplyShape(long[] left, long[] right) {
    if(Shape.shapeIsScalar(left)) {
        return right;
    }

    if(Shape.shapeIsScalar(right)) {
        return left;
    }

    if (left.length != 2 && right.length != 2) {
        throw new IllegalArgumentException("Illegal shapes for matrix multiply. Must be of length 2");
    }

    // every dimension must be at least 1; the messages state the bound that is enforced
    for(int i = 0; i < left.length; i++) {
        if(left[i] < 1)
            throw new ND4JIllegalStateException("Left shape contained value < 1 at index " + i);
    }

    for(int i = 0; i < right.length; i++) {
        if(right[i] < 1)
            throw new ND4JIllegalStateException("Right shape contained value < 1 at index " + i);
    }

    if (left.length > 1 && left[1] != right[0])
        throw new IllegalArgumentException("Columns of left not equal to rows of right");

    // left has fewer dimensions than right and its first dimension matches right's first:
    // the result is a single row
    if(left.length < right.length) {
        if(left[0] == right[0]) {
            return new long[] {1, right[1]};
        }
    }

    return new long[] {left[0], right[1]};
}