Java Code Examples for org.deeplearning4j.nn.conf.inputs.InputType#InputTypeFeedForward

The following examples show how to use org.deeplearning4j.nn.conf.inputs.InputType#InputTypeFeedForward. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: KerasLoss.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Get the DL4J loss layer matching the given input type.
 *
 * <p>Feed-forward input yields a {@link LossLayer}, recurrent input a {@link RnnLossLayer}
 * and convolutional input a {@link CnnLossLayer}. Each is built with this layer's loss
 * function, this layer's name and an identity activation. The built layer is also stored in
 * {@code this.layer}.
 *
 * @param type input type feeding into the loss layer
 * @return the newly built loss layer
 * @throws UnsupportedKerasConfigurationException if {@code type} is {@code null} or is not a
 *     feed-forward, recurrent or convolutional input type
 */
public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException {
    if (type instanceof InputType.InputTypeFeedForward) {
        this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
    } else if (type instanceof InputType.InputTypeRecurrent) {
        this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
    } else if (type instanceof InputType.InputTypeConvolutional) {
        this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
    } else {
        // String concatenation (instead of type.toString()) renders a null type as "null", so the
        // caller gets the declared exception rather than a NullPointerException.
        throw new UnsupportedKerasConfigurationException("Unsupported output layer type, got: " + type);
    }
    return (FeedForwardLayer) this.layer;
}
 
Example 2
Source File: Bidirectional.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Derives this wrapper's output type from the output type of the wrapped forward layer.
 *
 * <p>When the forward layer is a {@link LastTimeStep}, its output is treated as feed-forward;
 * otherwise it is treated as recurrent. In {@code CONCAT} mode the reported size is doubled;
 * in any other mode the forward layer's type is returned unchanged.
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    InputType fwdOut = fwd.getOutputType(layerIndex, inputType);
    boolean concat = mode == Mode.CONCAT;

    if (!(fwd instanceof LastTimeStep)) {
        InputType.InputTypeRecurrent rnnOut = (InputType.InputTypeRecurrent) fwdOut;
        return concat ? InputType.recurrent(2 * rnnOut.getSize(), getRNNDataFormat()) : rnnOut;
    }

    InputType.InputTypeFeedForward ffOut = (InputType.InputTypeFeedForward) fwdOut;
    return concat ? InputType.feedForward(2 * ffOut.getSize()) : ffOut;
}
 
Example 3
Source File: KerasPermute.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Gets appropriate DL4J InputPreProcessor for given InputTypes.
 *
 * <p>Convolutional input gets a {@link PermutePreprocessor}. Recurrent input is supported
 * only for the permutation {@code (2, 1)}. Feed-forward input needs no preprocessor.
 *
 * @param inputType Array of InputTypes; exactly one is expected
 * @return DL4J InputPreProcessor, or {@code null} when no preprocessing is needed
 * @throws InvalidKerasConfigurationException if not exactly one input type is given, the input
 *     type is unsupported, or the permutation is invalid for recurrent input
 * @see InputPreProcessor
 */
@Override
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws
        InvalidKerasConfigurationException {
    // "!= 1" rather than "> 1": an empty array would otherwise fail below with an
    // ArrayIndexOutOfBoundsException instead of the declared configuration exception.
    if (inputType.length != 1) {
        throw new InvalidKerasConfigurationException(
                "Keras Permute layer accepts only one input (received " + inputType.length + ")");
    }
    InputType in = inputType[0];

    if (in instanceof InputType.InputTypeConvolutional) {
        switch (this.getDimOrder()) {
            case THEANO:
                return new PermutePreprocessor(permutationIndices);
            case NONE: // TF by default
            case TENSORFLOW:
                // account for channels last
                // NOTE(review): this reassigns the permutationIndices field on every call, so a second
                // call rotates the indices again; confirm callers invoke this only once per layer.
                permutationIndices = new int[] {permutationIndices[2], permutationIndices[0], permutationIndices[1]};
                return new PermutePreprocessor(new int[]{1, 3, 2});
            default:
                return null;
        }
    } else if (in instanceof InputType.InputTypeRecurrent) {
        if (!Arrays.equals(permutationIndices, new int[] {2, 1})) {
            throw new InvalidKerasConfigurationException("For RNN type input data, permutation dims have to be " +
                    "(2, 1) in Permute layer, got " + Arrays.toString(permutationIndices));
        }
        return new PermutePreprocessor(permutationIndices);
    } else if (in instanceof InputType.InputTypeFeedForward) {
        // Feed-forward input is passed through without a preprocessor.
        return null;
    } else {
        throw new InvalidKerasConfigurationException("Input type not supported: " + in);
    }
}
 
Example 4
Source File: FeedForwardToRnnPreProcessor.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Maps a feed-forward or flattened-convolutional input type to a recurrent type.
 *
 * <p>The recurrent size is the feed-forward size (FF input) or the flattened size (CNNFlat
 * input). The result uses this preprocessor's RNN data format.
 *
 * @throws IllegalStateException if the input type is null or neither FF nor CNNFlat
 */
@Override
public InputType getOutputType(InputType inputType) {
    InputType.Type kind = (inputType == null) ? null : inputType.getType();

    if (kind == InputType.Type.FF) {
        InputType.InputTypeFeedForward feedForward = (InputType.InputTypeFeedForward) inputType;
        return InputType.recurrent(feedForward.getSize(), rnnDataFormat);
    }
    if (kind == InputType.Type.CNNFlat) {
        InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
        return InputType.recurrent(flat.getFlattenedSize(), rnnDataFormat);
    }
    throw new IllegalStateException("Invalid input: expected input of type FeedForward, got " + inputType);
}
 
Example 5
Source File: FeedForwardToCnnPreProcessor.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Computes the convolutional output type for the given input type. The input is validated
 * against this preprocessor's configured height, width and channel count.
 *
 * <ul>
 *   <li>FF input must have exactly {@code inputHeight * inputWidth * numChannels} features.</li>
 *   <li>CNN input must match the configured dimensions and is returned as-is.</li>
 *   <li>CNNFlat input must match the configured dimensions and is returned unflattened.</li>
 * </ul>
 *
 * @param inputType type of the incoming activations
 * @return the convolutional output type
 * @throws IllegalStateException if the input type is null, of an unsupported kind, or its
 *     dimensions do not match this preprocessor's configuration
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType == null) {
        // Without this guard the switch below would throw a NullPointerException.
        throw new IllegalStateException("Invalid input type: got null");
    }

    switch (inputType.getType()) {
        case FF:
            InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
            // Widen before multiplying so large configurations cannot overflow int arithmetic.
            long expSize = (long) inputHeight * inputWidth * numChannels;
            if (c.getSize() != expSize) {
                throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
                                + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
                                + inputType);
            }
            return InputType.convolutional(inputHeight, inputWidth, numChannels);
        case CNN:
            InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;

            if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
                // Expected values are printed in the same (d,w,h) order as the actual ones.
                throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
                                + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
                                + "," + inputWidth + "," + inputHeight + ")");
            }
            return c2;
        case CNNFlat:
            InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
            if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
                throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
                                + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
                                + "," + inputWidth + "," + inputHeight + ")");
            }
            return c3.getUnflattenedType();
        default:
            throw new IllegalStateException("Invalid input type: got " + inputType);
    }
}
 
Example 6
Source File: FeedForwardToCnn3DPreProcessor.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Computes the 3D convolutional output type for the given input type. The input is validated
 * against this preprocessor's configured depth, height, width and channel count.
 *
 * <ul>
 *   <li>FF input must have exactly {@code inputDepth * inputHeight * inputWidth * numChannels}
 *       features.</li>
 *   <li>2D CNN input must match channels, height and width, and becomes a depth-1 3D type.</li>
 *   <li>CNN3D input must match all four dimensions and is returned as-is.</li>
 * </ul>
 *
 * @param inputType type of the incoming activations
 * @return the 3D convolutional output type
 * @throws IllegalStateException if the input type is null, of an unsupported kind, or its
 *     dimensions do not match this preprocessor's configuration
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType == null) {
        // Without this guard the switch below would throw a NullPointerException.
        throw new IllegalStateException("Invalid input type: got null");
    }

    switch (inputType.getType()) {
        case FF:
            InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
            // Widen before multiplying: a product of four int dimensions can overflow int.
            long expSize = (long) inputDepth * inputHeight * inputWidth * numChannels;
            if (c.getSize() != expSize) {
                throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
                        + " = (c=" + numChannels + " * d=" + inputDepth + " * w=" + inputWidth
                        + " * h=" + inputHeight + "), got " + inputType);
            }
            return InputType.convolutional3D(inputDepth, inputHeight, inputWidth, numChannels);
        case CNN:
            InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;

            if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
                // Expected values are printed in the same (c,w,h) order as the actual ones.
                throw new IllegalStateException("Invalid input: Got CNN input type with (c,w,h)=(" + c2.getChannels()
                        + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
                        + "," + inputWidth + "," + inputHeight + ")");
            }
            return InputType.convolutional3D(1, c2.getHeight(), c2.getWidth(), c2.getChannels());
        case CNN3D:
            InputType.InputTypeConvolutional3D c3 = (InputType.InputTypeConvolutional3D) inputType;

            if (c3.getChannels() != numChannels || c3.getDepth() != inputDepth ||
                    c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
                throw new IllegalStateException("Invalid input: Got CNN input type with (c, d,w,h)=("
                        + c3.getChannels() + "," + c3.getDepth() + "," + c3.getWidth() + "," + c3.getHeight()
                        + ") but expected (" + numChannels + "," + inputDepth + ","
                        + inputWidth + "," + inputHeight + ")");
            }
            return c3;
        default:
            throw new IllegalStateException("Invalid input type: got " + inputType);
    }
}
 
Example 7
Source File: RepeatVector.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Builds the recurrent output type of this layer from a feed-forward input type.
 *
 * <p>The result is built from the input's size, this layer's {@code n} and this layer's data
 * format.
 *
 * @throws IllegalStateException if the input type is null or not feed-forward
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    boolean isFeedForward = inputType != null && inputType.getType() == InputType.Type.FF;
    if (!isFeedForward) {
        throw new IllegalStateException("Invalid input for RepeatVector layer (layer name=\"" + getLayerName()
                        + "\"): Expected FF input, got " + inputType);
    }
    InputType.InputTypeFeedForward feedForward = (InputType.InputTypeFeedForward) inputType;
    return InputType.recurrent(feedForward.getSize(), n, this.dataFormat);
}
 
Example 8
Source File: InputTypeUtil.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Chooses the preprocessor needed to feed the given input type into an RNN layer.
 *
 * <ul>
 *   <li>RNN input needs no preprocessor ({@code null}).</li>
 *   <li>CNNFlat input uses a feed-forward-to-RNN preprocessor in {@code rnnDataFormat}.</li>
 *   <li>FF input uses a feed-forward-to-RNN preprocessor. Its time-distributed format is used
 *       when that is an {@link RNNFormat}; otherwise {@code rnnDataFormat} is used.</li>
 *   <li>CNN input uses a CNN-to-RNN preprocessor.</li>
 * </ul>
 *
 * @throws IllegalStateException if {@code inputType} is null
 * @throws RuntimeException for any other input type
 */
public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, RNNFormat rnnDataFormat, String layerName) {
    if (inputType == null) {
        throw new IllegalStateException(
                        "Invalid input for RNN layer (layer name = \"" + layerName + "\"): input type is null");
    }

    switch (inputType.getType()) {
        case RNN:
            // Already recurrent: nothing to convert.
            return null;
        case FF: {
            // Prefer the time-distributed format when it is an RNNFormat (instanceof is false for null).
            InputType.InputTypeFeedForward feedForward = (InputType.InputTypeFeedForward) inputType;
            RNNFormat format = (feedForward.getTimeDistributedFormat() instanceof RNNFormat)
                    ? (RNNFormat) feedForward.getTimeDistributedFormat()
                    : rnnDataFormat;
            return new FeedForwardToRnnPreProcessor(format);
        }
        case CNNFlat:
            // Flattened CNN activations are one row vector per example, just like FF input.
            return new FeedForwardToRnnPreProcessor(rnnDataFormat);
        case CNN: {
            InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
            return new CnnToRnnPreProcessor(conv.getHeight(), conv.getWidth(), conv.getChannels(), rnnDataFormat);
        }
        default:
            throw new RuntimeException("Unknown input type: " + inputType);
    }
}
 
Example 9
Source File: KerasInput.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Get layer output type.
 *
 * <p>The output type is derived from this layer's configured input shape and dim ordering.
 * Any input types passed in are ignored with a warning.
 *
 * <ul>
 *   <li>Rank-1 shapes map to feed-forward.</li>
 *   <li>Rank-2 shapes map to recurrent.</li>
 *   <li>Rank-3 shapes map to 2D convolutional.</li>
 *   <li>Rank-4 shapes map to 3D convolutional.</li>
 * </ul>
 *
 * <p>A missing ({@code null}) dim ordering is handled like {@code NONE}. For rank-3 and rank-4
 * shapes, any ordering other than TENSORFLOW or THEANO is assumed to be Theano, and
 * {@code this.dimOrder} is set to {@code THEANO}.
 *
 * @param inputType Array of InputTypes (ignored)
 * @return output type as InputType
 * @throws InvalidKerasConfigurationException     Invalid Keras config
 * @throws UnsupportedKerasConfigurationException if the input shape has an unsupported number
 *                                                of dimensions
 */
@Override
public InputType getOutputType(InputType... inputType)
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    if (inputType.length > 0)
        log.warn("Keras Input layer does not accept inputs (received " + inputType.length + "). Ignoring.");

    // Previously only the rank-2 branch guarded against a null dim ordering. The rank-3/4
    // switches threw a NullPointerException instead of reaching their Theano fallback.
    DimOrder order = (this.dimOrder == null) ? DimOrder.NONE : this.dimOrder;

    InputType myInputType;
    switch (this.inputShape.length) {
        case 1:
            myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
            break;
        case 2:
            switch (order) {
                case TENSORFLOW:    //NWC == channels_last
                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                    break;
                case THEANO:        //NCW == channels_first
                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
                    break;
                case NONE:
                    //Assume RNN in [mb, seqLen, size] format (also used when no dim ordering is set)
                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                    break;
                default:
                    throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
            }
            break;
        case 3:
            switch (order) {
                case TENSORFLOW:
                    /* TensorFlow convolutional input: # rows, # cols, # channels */
                    myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
                            this.inputShape[2], CNN2DFormat.NHWC);
                    break;
                case THEANO:
                    /* Theano convolutional input:     # channels, # rows, # cols */
                    myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
                            this.inputShape[0], CNN2DFormat.NCHW);
                    break;
                default:
                    this.dimOrder = DimOrder.THEANO;
                    myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
                            this.inputShape[0], CNN2DFormat.NCHW);
                    log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
                            "versions may come without specified backend, in which case we assume the model was " +
                            "built with theano." );
            }
            break;
        case 4:
            switch (order) {
                case TENSORFLOW:
                    myInputType = new InputType.InputTypeConvolutional3D(Convolution3D.DataFormat.NDHWC,
                            this.inputShape[0], this.inputShape[1],
                            this.inputShape[2],this.inputShape[3]);
                    break;
                case THEANO:
                    // NOTE(review): this argument order differs from the rank-3 THEANO mapping above;
                    // confirm it against InputTypeConvolutional3D's constructor parameter order.
                    myInputType = new InputType.InputTypeConvolutional3D(Convolution3D.DataFormat.NCDHW,
                            this.inputShape[3], this.inputShape[0],
                            this.inputShape[1],this.inputShape[2]);
                    break;
                default:
                    this.dimOrder = DimOrder.THEANO;
                    myInputType = new InputType.InputTypeConvolutional3D(Convolution3D.DataFormat.NCDHW,
                            this.inputShape[3], this.inputShape[0],
                            this.inputShape[1],this.inputShape[2]);
                    log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
                            "versions may come without specified backend, in which case we assume the model was " +
                            "built with theano." );
            }
            break;
        default:
            throw new UnsupportedKerasConfigurationException(
                    "Inputs with " + this.inputShape.length + " dimensions not supported");
    }
    return myInputType;
}