Java Code Examples for org.deeplearning4j.nn.conf.inputs.InputType#InputTypeConvolutionalFlat

The following examples show how to use org.deeplearning4j.nn.conf.inputs.InputType#InputTypeConvolutionalFlat. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: GlobalPoolingLayer.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Selects the input preprocessor, if any, for the given input type.
 *
 * @param inputType type of the activations arriving at this layer
 * @return a {@link FeedForwardToCnnPreProcessor} built from the height, width and depth of a
 *         flattened-CNN input; {@code null} for every other accepted input type
 * @throws UnsupportedOperationException if the input type is feed-forward
 */
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    switch (inputType.getType()) {
        case CNNFlat:
            // Flattened image data has to be reshaped back using its recorded dimensions
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            return new FeedForwardToCnnPreProcessor(flat.getHeight(), flat.getWidth(), flat.getDepth());
        case FF:
            throw new UnsupportedOperationException(
                            "Global max pooling cannot be applied to feed-forward input type. Got input type = " + inputType);
        default:
            // RNN, CNN, CNN3D and anything else: no preprocessor
            return null;
    }
}
 
Example 2
Source File: FeedForwardToRnnPreProcessor.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Maps a feed-forward or flattened-CNN input type to the recurrent type this preprocessor emits.
 *
 * @param inputType input type; must be non-null and of type FF or CNNFlat
 * @return a recurrent input type sized by the feed-forward size (FF) or by the flattened size
 *         (CNNFlat), created with this preprocessor's {@code rnnDataFormat}
 * @throws IllegalStateException if {@code inputType} is null or of any other type
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType != null) {
        if (inputType.getType() == InputType.Type.FF) {
            InputType.InputTypeFeedForward feedForward = (InputType.InputTypeFeedForward) inputType;
            return InputType.recurrent(feedForward.getSize(), rnnDataFormat);
        }
        if (inputType.getType() == InputType.Type.CNNFlat) {
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            return InputType.recurrent(flat.getFlattenedSize(), rnnDataFormat);
        }
    }

    // Null input, or a type that is neither FF nor CNNFlat
    throw new IllegalStateException("Invalid input: expected input of type FeedForward, got " + inputType);
}
 
Example 3
Source File: FeedForwardToCnnPreProcessor.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Validates the input type against this preprocessor's configured dimensions and returns the
 * convolutional type it produces.
 *
 * @param inputType input type; FF, CNN or CNNFlat are accepted
 * @return for FF input, a convolutional type of the configured height, width and channels; for CNN
 *         input, the input type itself; for CNNFlat input, its unflattened type
 * @throws IllegalStateException if the input size or dimensions do not match the configured
 *                               {@code inputHeight}, {@code inputWidth} and {@code numChannels},
 *                               or if the input type is not one of FF, CNN or CNNFlat
 */
@Override
public InputType getOutputType(InputType inputType) {

    switch (inputType.getType()) {
        case FF:
            InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
            // A flat vector can only be reshaped if it holds exactly h * w * d values
            val expSize = inputHeight * inputWidth * numChannels;
            if (c.getSize() != expSize) {
                throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
                                + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
                                + inputType);
            }
            return InputType.convolutional(inputHeight, inputWidth, numChannels);
        case CNN:
            InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;

            if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
                // Both tuples are printed in the (d,w,h) order announced by the message
                throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
                                + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
                                + "," + inputWidth + "," + inputHeight + ")");
            }
            return c2;
        case CNNFlat:
            InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
            if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
                // Both tuples are printed in the (d,w,h) order announced by the message
                throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
                                + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
                                + "," + inputWidth + "," + inputHeight + ")");
            }
            return c3.getUnflattenedType();
        default:
            throw new IllegalStateException("Invalid input type: got " + inputType);
    }
}
 
Example 4
Source File: BatchNormalization.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Selects the input preprocessor, if any, for the given input type.
 *
 * @param inputType type of the activations arriving at this layer
 * @return a {@link RnnToFeedForwardPreProcessor} for RNN input, a {@link FeedForwardToCnnPreProcessor}
 *         built from the input's height, width and depth for CNNFlat input, otherwise {@code null}
 */
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    final InputType.Type kind = inputType.getType();

    if (kind == InputType.Type.RNN) {
        return new RnnToFeedForwardPreProcessor();
    }
    if (kind != InputType.Type.CNNFlat) {
        // Every remaining type is used as-is
        return null;
    }

    InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
    return new FeedForwardToCnnPreProcessor(flat.getHeight(), flat.getWidth(), flat.getDepth());
}
 
Example 5
Source File: Yolo2OutputLayer.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Selects the input preprocessor, if any, for the given input type.
 *
 * @param inputType type of the activations arriving at this layer
 * @return a {@link FeedForwardToCnnPreProcessor} built from the input's height, width and depth
 *         for CNNFlat input; {@code null} for CNN and any other non-rejected type
 * @throws UnsupportedOperationException if the input type is FF or RNN
 */
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    switch (inputType.getType()) {
        case CNNFlat: {
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            return new FeedForwardToCnnPreProcessor(flat.getHeight(), flat.getWidth(), flat.getDepth());
        }
        case FF:
        case RNN:
            throw new UnsupportedOperationException("Cannot use FF or RNN input types");
        default:
            // CNN input (and any remaining type) gets no preprocessor
            return null;
    }
}
 
Example 6
Source File: InputTypeUtil.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Picks the preprocessor, if any, to place in front of a CNN-style layer such as
 * {@link ConvolutionLayer} or {@link SubsamplingLayer}.
 *
 * @param inputType     Input type to get the preprocessor for
 * @param layerName     Layer name, included in the log message emitted for FF and RNN input
 * @return              Null for CNN input and for FF/RNN input (automatic conversion is not implemented
 *                      for those; a message is logged instead); a {@link FeedForwardToCnnPreProcessor}
 *                      built from the input's height, width and depth for CNNFlat input
 * @throws RuntimeException for any other input type
 */
public static InputPreProcessor getPreProcessorForInputTypeCnnLayers(InputType inputType, String layerName) {
    switch (inputType.getType()) {
        case CNN:
            //Already convolutional: nothing to insert
            return null;
        case CNNFlat: {
            //Flattened CNN data carries its own dimensions, so the reshape can be configured from it
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            return new FeedForwardToCnnPreProcessor(flat.getHeight(), flat.getWidth(), flat.getDepth());
        }
        case FF: {
            //An x-to-CNN preprocessor needs the image channels/width/height after reshaping,
            //which cannot be inferred from FF activations directly
            String message = "Automatic addition of FF -> CNN preprocessors: not yet implemented (layer name: \""
                            + layerName + "\")";
            log.info(message);
            return null;
        }
        case RNN: {
            //Same limitation as FF: the target image dimensions are unknown
            String message = "Automatic addition of RNN -> CNN preprocessors: not yet implemented (layer name: \""
                            + layerName + "\")";
            log.warn(message);
            return null;
        }
        default:
            throw new RuntimeException("Unknown input type: " + inputType);
    }
}
 
Example 7
Source File: GlobalPoolingLayer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Computes the output type of this layer for the given input type.
 *
 * <p>When {@code collapseDimensions} is set the result is a feed-forward type sized by the
 * recurrent size (RNN) or by the channel/depth count (CNN, CNN3D, CNNFlat). Otherwise RNN input
 * is returned unchanged and convolutional inputs become a type with every spatial dimension
 * equal to 1.
 *
 * @param layerIndex index of this layer (not used by this method)
 * @param inputType  type of the activations arriving at this layer
 * @return the resulting output type
 * @throws UnsupportedOperationException if the input type is feed-forward or not one of the
 *                                       handled types
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    switch (inputType.getType()) {
        case FF:
            throw new UnsupportedOperationException(
                            "Global max pooling cannot be applied to feed-forward input type. Got input type = " + inputType);
        case RNN: {
            InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
            // Collapsed: feed-forward activations; otherwise the recurrent type is passed through
            return collapseDimensions ? InputType.feedForward(rnn.getSize()) : rnn;
        }
        case CNN: {
            InputType.InputTypeConvolutional cnn = (InputType.InputTypeConvolutional) inputType;
            return collapseDimensions
                            ? InputType.feedForward(cnn.getChannels())
                            : InputType.convolutional(1, 1, cnn.getChannels(), cnn.getFormat());
        }
        case CNN3D: {
            InputType.InputTypeConvolutional3D cnn3d = (InputType.InputTypeConvolutional3D) inputType;
            return collapseDimensions
                            ? InputType.feedForward(cnn3d.getChannels())
                            : InputType.convolutional3D(1, 1, 1, cnn3d.getChannels());
        }
        case CNNFlat: {
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            return collapseDimensions
                            ? InputType.feedForward(flat.getDepth())
                            : InputType.convolutional(1, 1, flat.getDepth());
        }
        default:
            throw new UnsupportedOperationException("Unknown or not supported input type: " + inputType);
    }
}