Java Code Examples for org.deeplearning4j.nn.conf.inputs.InputType#feedForward()

The following examples show how to use org.deeplearning4j.nn.conf.inputs.InputType#feedForward(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ActivationLayerTest.java — from deeplearning4j (Apache License 2.0), 5 votes
@Test
public void testInputTypes() {
    // The activation layer is expected to pass both a feed-forward and a convolutional
    // input type through unchanged, and to need no preprocessor for either of them.
    org.deeplearning4j.nn.conf.layers.ActivationLayer activationLayer =
            new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder()
                    .activation(Activation.RELU)
                    .build();

    InputType[] inputs = {InputType.feedForward(20), InputType.convolutional(28, 28, 1)};

    // Output type must equal input type (checked for every input before the preprocessor checks).
    for (InputType input : inputs) {
        assertEquals(input, activationLayer.getOutputType(0, input));
    }
    for (InputType input : inputs) {
        assertNull(activationLayer.getPreProcessorForInputType(input));
    }
}
 
Example 2
Source File: ReshapePreprocessor.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Infers the output {@link InputType} from the configured target shape.
 *
 * <p>Rank 2 yields a feed-forward type, rank 3 a recurrent type and rank 4 a convolutional
 * type; any other rank is rejected.
 *
 * @param inputType the incoming type; only consulted for rank-4 target shapes
 * @return the inferred output type
 * @throws UnsupportedOperationException if the target shape's rank is not 2, 3 or 4
 */
@Override
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
    // Resolved target shape; only indices 1..rank-1 are read below.
    long[] shape = getShape(this.targetShape, 0);
    int rank = shape.length;

    if (rank == 2) {
        return InputType.feedForward(shape[1]);
    }

    if (rank == 3) {
        // instanceof is false for null, so a missing/non-RNN format falls back to NCW.
        RNNFormat rnnFormat = (this.format instanceof RNNFormat) ? (RNNFormat) this.format : RNNFormat.NCW;
        // shape[2] is passed as the size and shape[1] as the time-series length
        // (presumably the target shape is in [mb, timesteps, features] order — TODO confirm).
        return InputType.recurrent(shape[2], shape[1], rnnFormat);
    }

    if (rank == 4) {
        // Short-circuit order matters: inputType is only inspected when inputShape is not rank 1.
        boolean useDefaultLayout = inputShape.length == 1 || inputType.getType() == InputType.Type.RNN;
        if (useDefaultLayout) {
            return InputType.convolutional(shape[1], shape[2], shape[3]);
        }

        CNN2DFormat cnnFormat = (this.format instanceof CNN2DFormat) ? (CNN2DFormat) this.format : CNN2DFormat.NCHW;
        if (cnnFormat == CNN2DFormat.NCHW) {
            // [mb, c, h, w] -> convolutional(h, w, c)
            return InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
        }
        // [mb, h, w, c] -> convolutional(h, w, c)
        return InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
    }

    throw new UnsupportedOperationException(
            "Cannot infer input type for reshape array " + Arrays.toString(shape));
}
 
Example 3
Source File: Cnn3DToFeedForwardPreProcessor.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Flattens a 3D convolutional input type into a feed-forward type.
 *
 * @param inputType must be a non-null input of type CNN3D
 * @return a feed-forward type whose size is channels * depth * height * width
 * @throws IllegalStateException if {@code inputType} is null or not CNN3D
 */
@Override
public InputType getOutputType(InputType inputType) {
    boolean isCnn3d = inputType != null && inputType.getType() == InputType.Type.CNN3D;
    if (!isCnn3d) {
        throw new IllegalStateException("Invalid input type: Expected input of type CNN3D, got " + inputType);
    }

    InputType.InputTypeConvolutional3D conv3d = (InputType.InputTypeConvolutional3D) inputType;
    // Every activation is kept: one feed-forward unit per (channel, depth, row, column).
    return InputType.feedForward(conv3d.getChannels() * conv3d.getDepth() * conv3d.getHeight() * conv3d.getWidth());
}
 
Example 4
Source File: CnnToFeedForwardPreProcessor.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Flattens a 2D convolutional input type into a feed-forward type.
 *
 * @param inputType must be a non-null input of type CNN
 * @return a feed-forward type whose size is channels * height * width
 * @throws IllegalStateException if {@code inputType} is null or not CNN
 */
@Override
public InputType getOutputType(InputType inputType) {
    boolean isCnn = inputType != null && inputType.getType() == InputType.Type.CNN;
    if (!isCnn) {
        throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
    }

    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    // One feed-forward unit per (channel, row, column) activation.
    return InputType.feedForward(conv.getChannels() * conv.getHeight() * conv.getWidth());
}
 
Example 5
Source File: RnnToFeedForwardPreProcessor.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Maps a recurrent input type to a feed-forward type of the same size.
 *
 * @param inputType must be a non-null input of type RNN
 * @return a feed-forward type with the recurrent size and the recurrent input's format
 * @throws IllegalStateException if {@code inputType} is null or not RNN
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType != null && inputType.getType() == InputType.Type.RNN) {
        InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
        // The RNN format is carried over so the feed-forward type records where it came from.
        return InputType.feedForward(recurrent.getSize(), recurrent.getFormat());
    }
    throw new IllegalStateException("Invalid input: expected input of type RNN, got " + inputType);
}
 
Example 6
Source File: TimeDistributed.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Computes the output type of the time-distributed wrapper.
 *
 * <p>The recurrent input's size is presented to {@code underlying} as a feed-forward input.
 * The underlying layer's per-example output size becomes the size of a recurrent output
 * with the input's time-series length. The output is tagged with this layer's
 * {@code rnnDataFormat} field, not the input's format.
 *
 * @param layerIndex index of this layer, used in the error message
 * @param inputType  must be a non-null input of type RNN
 * @return the recurrent output type
 * @throws IllegalStateException if {@code inputType} is null or not RNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    // Null is rejected with the same exception as a wrong type (as the other getOutputType
    // implementations do) rather than surfacing as a NullPointerException.
    if (inputType == null || inputType.getType() != InputType.Type.RNN) {
        throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer (layer #" + layerIndex + ")");
    }

    InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
    InputType ff = InputType.feedForward(rnn.getSize());
    InputType ffOut = underlying.getOutputType(layerIndex, ff);
    return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength(), rnnDataFormat);
}
 
Example 7
Source File: TimeDistributed.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Configures the wrapped layer's input size from a recurrent input type.
 *
 * <p>Records the input's RNN format in {@code rnnDataFormat}. It then forwards a feed-forward
 * type of the recurrent size to {@code underlying.setNIn}.
 *
 * @param inputType must be a non-null input of type RNN
 * @param override  passed through unchanged to the underlying layer
 * @throws IllegalStateException if {@code inputType} is null or not RNN
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    // Guard null explicitly so callers get the same IllegalStateException as for a wrong type.
    if (inputType == null || inputType.getType() != InputType.Type.RNN) {
        throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer");
    }

    InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
    InputType ff = InputType.feedForward(rnn.getSize());
    this.rnnDataFormat = rnn.getFormat();
    underlying.setNIn(ff, override);
}
 
Example 8
Source File: LastTimeStep.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Computes the output type after taking the last time step of the wrapped layer's output.
 *
 * @param layerIndex index of this layer, passed to the underlying layer
 * @param inputType  must be a non-null input of type RNN
 * @return a feed-forward type whose size is the underlying layer's recurrent output size
 * @throws IllegalArgumentException if {@code inputType} is null or not RNN
 * @throws IllegalStateException    if the underlying layer does not produce a recurrent output type
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    if (inputType == null || inputType.getType() != InputType.Type.RNN) {
        throw new IllegalArgumentException("Require RNN input type - got " + inputType);
    }
    InputType outType = underlying.getOutputType(layerIndex, inputType);
    // The cast below requires a recurrent output; report a misconfigured underlying layer
    // clearly instead of failing with a ClassCastException (instanceof also rejects null).
    if (!(outType instanceof InputType.InputTypeRecurrent)) {
        throw new IllegalStateException("Underlying layer must produce an RNN output type (layer index "
                + layerIndex + ") - got " + outType);
    }
    InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) outType;
    return InputType.feedForward(r.getSize());
}
 
Example 9
Source File: CapsuleStrengthLayer.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Maps a recurrent input type to a feed-forward type of the same size.
 *
 * @param layerIndex unused
 * @param inputType  must be a non-null input of type RNN
 * @return a feed-forward type whose size equals the recurrent input's size
 * @throws IllegalStateException if {@code inputType} is null or not RNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    boolean isRecurrent = inputType != null && inputType.getType() == Type.RNN;
    if (!isRecurrent) {
        throw new IllegalStateException("Invalid input for Capsule Strength layer (layer name = \""
                + layerName + "\"): expect RNN input.  Got: " + inputType);
    }

    long size = ((InputTypeRecurrent) inputType).getSize();
    return InputType.feedForward(size);
}
 
Example 10
Source File: FeedForwardLayer.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Returns a feed-forward output type of width {@code nOut}.
 *
 * @param layerIndex index of this layer, used in the error message
 * @param inputType  must be a non-null FF or CNNFlat input type
 * @return a feed-forward type of size {@code nOut} with this layer's {@code timeDistributedFormat}
 * @throws IllegalStateException if {@code inputType} is null or of any other type
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    boolean accepted = inputType != null
            && (inputType.getType() == InputType.Type.FF || inputType.getType() == InputType.Type.CNNFlat);
    if (!accepted) {
        throw new IllegalStateException("Invalid input type (layer index = " + layerIndex + ", layer name=\""
                + getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
    }

    // The output width depends only on the configured nOut, not on the input's size.
    return InputType.feedForward(nOut, timeDistributedFormat);
}
 
Example 11
Source File: DropoutLayerTest.java — from deeplearning4j (Apache License 2.0), 5 votes
@Test
public void testInputTypes() {
    // A dropout layer is expected to report its input type unchanged, for both feed-forward
    // and convolutional inputs, and to need no preprocessor for either.
    DropoutLayer dropout = new DropoutLayer.Builder(0.5).build();

    InputType[] inputs = {InputType.feedForward(20), InputType.convolutional(28, 28, 1)};

    for (InputType input : inputs) {
        assertEquals(input, dropout.getOutputType(0, input));
    }
    for (InputType input : inputs) {
        assertNull(dropout.getPreProcessorForInputType(input));
    }
}
 
Example 12
Source File: LastTimeStepVertex.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Computes the output type of taking the last time step of a single recurrent input.
 *
 * @param layerIndex   unused
 * @param vertexInputs must contain exactly one non-null input of type RNN
 * @return a feed-forward type whose size equals the recurrent input's size
 * @throws InvalidInputTypeException if there is not exactly one input, or it is null or not RNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
    // The old message claimed "more than 1 input" even when zero inputs were supplied;
    // report the actual count instead (a null varargs array counts as zero inputs).
    int numInputs = vertexInputs == null ? 0 : vertexInputs.length;
    if (numInputs != 1) {
        throw new InvalidInputTypeException(
                        "Invalid input type: can only get last time step of exactly 1 input, got " + numInputs);
    }

    InputType input = vertexInputs[0];
    if (input == null || input.getType() != InputType.Type.RNN) {
        throw new InvalidInputTypeException(
                        "Invalid input type: cannot get subset of non RNN input (got: " + input + ")");
    }

    return InputType.feedForward(((InputType.InputTypeRecurrent) input).getSize());
}
 
Example 13
Source File: ReshapeVertex.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Infers the output type from the configured {@code newShape}; {@code vertexInputs} is not consulted.
 *
 * @return feed-forward for rank 2, recurrent for rank 3, convolutional for rank 4
 * @throws UnsupportedOperationException for any other rank
 */
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
    final int rank = newShape.length;
    if (rank == 2) {
        return InputType.feedForward(newShape[1]);
    } else if (rank == 3) {
        // NOTE(review): only newShape[1] is used; newShape[2] (presumably the time-series
        // length) is not propagated to the recurrent type — confirm this is intended.
        return InputType.recurrent(newShape[1]);
    } else if (rank == 4) {
        // newShape read as [mb, d, h, w] -> convolutional(h, w, d)
        return InputType.convolutional(newShape[2], newShape[3], newShape[1]);
    }
    throw new UnsupportedOperationException(
                    "Cannot infer input type for reshape array " + Arrays.toString(newShape));
}
 
Example 14
Source File: L2Vertex.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Always reports a feed-forward output of size 1, regardless of the input types.
 */
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
    final int outputSize = 1;
    return InputType.feedForward(outputSize);
}
 
Example 15
Source File: DeepFMProductConfiguration.java — from jstarcraft-rns (Apache License 2.0), 4 votes
/**
 * Always reports a feed-forward output of size 1; the input types are ignored.
 */
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
    final int outputSize = 1;
    return InputType.feedForward(outputSize);
}
 
Example 16
Source File: CustomLayer.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Always reports a feed-forward output of size 10; the input type is ignored.
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    final int outputSize = 10;
    return InputType.feedForward(outputSize);
}
 
Example 17
Source File: CustomLayer.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Reports a fixed feed-forward output of size 10, independent of {@code inputType}.
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    final int fixedWidth = 10;
    return InputType.feedForward(fixedWidth);
}
 
Example 18
Source File: GlobalPoolingLayer.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Computes the output type of global pooling for RNN, CNN, CNN3D and CNNFlat inputs.
 *
 * <p>With {@code collapseDimensions} the result is a feed-forward type sized by the
 * channel/size dimension. Otherwise the spatial dimensions become 1 and the input's data
 * format is kept. For RNN input the result is the input type itself.
 *
 * @param layerIndex unused
 * @param inputType  the incoming type
 * @return the pooled output type
 * @throws UnsupportedOperationException for feed-forward or any other unsupported input type
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {

    switch (inputType.getType()) {
        case FF:
            throw new UnsupportedOperationException(
                            "Global max pooling cannot be applied to feed-forward input type. Got input type = "
                                            + inputType);
        case RNN:
            InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
            if (collapseDimensions) {
                //Return 2d (feed-forward) activations
                return InputType.feedForward(recurrent.getSize());
            } else {
                //Return 3d activations, with shape [minibatch, timeStepSize, 1]
                // NOTE(review): the input type is returned unchanged, so its original
                // time-series length is reported rather than 1 — confirm this is intended.
                return recurrent;
            }
        case CNN:
            InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
            if (collapseDimensions) {
                return InputType.feedForward(conv.getChannels());
            } else {
                return InputType.convolutional(1, 1, conv.getChannels(), conv.getFormat());
            }
        case CNN3D:
            InputType.InputTypeConvolutional3D conv3d = (InputType.InputTypeConvolutional3D) inputType;
            if (collapseDimensions) {
                return InputType.feedForward(conv3d.getChannels());
            } else {
                // Keep the input's 3D data format, as the CNN branch keeps conv.getFormat();
                // the 4-argument overload would substitute a default format instead.
                return InputType.convolutional3D(conv3d.getDataFormat(), 1, 1, 1, conv3d.getChannels());
            }
        case CNNFlat:
            InputType.InputTypeConvolutionalFlat convFlat = (InputType.InputTypeConvolutionalFlat) inputType;
            if (collapseDimensions) {
                return InputType.feedForward(convFlat.getDepth());
            } else {
                return InputType.convolutional(1, 1, convFlat.getDepth());
            }
        default:
            throw new UnsupportedOperationException("Unknown or not supported input type: " + inputType);
    }
}
 
Example 19
Source File: SameDiffMSEOutputLayer.java — from deeplearning4j (Apache License 2.0), 4 votes
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    return InputType.feedForward(nOut);
}
 
Example 20
Source File: KerasFlattenRnnPreprocessor.java — from deeplearning4j (Apache License 2.0), 2 votes
/**
 * Reports a feed-forward type of size {@code depth * tsLength}, built from this preprocessor's
 * configured fields; {@code inputType} is not inspected.
 */
@Override
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
    InputType flattened = InputType.feedForward(depth * tsLength);
    return flattened;
}