Java Code Examples for org.deeplearning4j.nn.conf.inputs.InputType#InputTypeConvolutional

The following examples show how to use org.deeplearning4j.nn.conf.inputs.InputType#InputTypeConvolutional. You can vote up the examples you find useful or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: Upsampling2D.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Estimates memory use for this upsampling layer.
 *
 * <p>The layer has no parameters and caches nothing; the only working memory is one im2col-sized
 * buffer per example, plus (during training, when dropout is configured) a copy of the input.
 *
 * @param inputType the (convolutional) input type of this layer
 * @return the memory report
 */
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
    InputType.InputTypeConvolutional in = (InputType.InputTypeConvolutional) inputType;
    InputType.InputTypeConvolutional outputType =
                    (InputType.InputTypeConvolutional) getOutputType(-1, inputType);

    // Forward pass: im2col array + reduce. The reduce result counts as activations,
    // so the im2col buffer is the only working memory.
    val im2colPerExample = in.getChannels() * outputType.getHeight() * outputType.getWidth()
                    * size[0] * size[1];

    // im2col is not cached, so it is recomputed on every backward pass. With dropout, the input
    // is additionally duplicated before dropout is applied (training only).
    long dropoutCopyPerExample = (getIDropout() == null) ? 0L : inputType.arrayElementsPerExample();
    long trainingPerExample = im2colPerExample + dropoutCopyPerExample;

    return new LayerMemoryReport.Builder(layerName, Upsampling2D.class, inputType, outputType)
                    .standardMemory(0, 0) // no parameters
                    .workingMemory(0, im2colPerExample, 0, trainingPerExample)
                    .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // no caching
                    .build();
}
 
Example 2
Source File: GanCnnInputPreProcessor.java    From dl4j-tutorials with MIT License 5 votes vote down vote up
/**
 * Validates that the input is a CNN type matching the configured channels/height/width and
 * returns it unchanged.
 *
 * @param inputType the incoming input type
 * @return the same convolutional input type
 * @throws IllegalStateException if the input is not CNN, or its dimensions do not match
 */
@Override
public InputType getOutputType(InputType inputType) {
	switch (inputType.getType()) {
	case CNN:
		InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;

		if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
			// Both tuples are printed in (d,w,h) order so actual and expected line up
			throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels() + "," + c2.getWidth() + ","
					+ c2.getHeight() + ") but expected (" + numChannels + "," + inputWidth + "," + inputHeight + ")");
		}
		return c2;
	default:
		throw new IllegalStateException("Invalid input type: got " + inputType);
	}
}
 
Example 3
Source File: ConvolutionLayer.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Configures this convolution layer from a CNN input type: sets {@code nIn} to the input channel
 * count (if not already set, or if {@code override} is true) and records the input data format.
 *
 * @param inputType the input type; must be of CNN type
 * @param override  whether to overwrite an already-set {@code nIn}
 * @throws IllegalStateException if the input type is null or not CNN
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    boolean isCnnInput = inputType != null && inputType.getType() == InputType.Type.CNN;
    if (!isCnnInput) {
        throw new IllegalStateException("Invalid input for Convolution layer (layer name=\"" + getLayerName()
                        + "\"): Expected CNN input, got " + inputType);
    }

    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    if (override || nIn <= 0) {
        this.nIn = conv.getChannels();
    }
    // The data format is always taken from the input, regardless of override
    this.cnn2dDataFormat = conv.getFormat();
}
 
Example 4
Source File: LocallyConnected2D.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes the output type of this locally connected 2D layer.
 *
 * <p>As a side effect, records the input's spatial size in {@code inputSize} and calls
 * {@code computeOutputSize()}.
 *
 * @param layerIndex index of this layer (used for error reporting)
 * @param inputType  the input type; must be of CNN type
 * @return the output type
 * @throws IllegalArgumentException if the input type is null or not CNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    boolean isCnnInput = inputType != null && inputType.getType() == InputType.Type.CNN;
    if (!isCnnInput) {
        throw new IllegalArgumentException("Provided input type for locally connected 2D layers has to be "
                        + "of CNN type, got: " + inputType);
    }

    // The input spatial size is derived dynamically from the input type
    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    int height = (int) conv.getHeight();
    int width = (int) conv.getWidth();
    this.inputSize = new int[] {height, width};
    computeOutputSize();

    int[] dilation = {1, 1};
    return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, dilation, cm, nOut,
                    layerIndex, getLayerName(), format, LocallyConnected2D.class);
}
 
Example 5
Source File: KerasPermute.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Gets appropriate DL4J InputPreProcessor for given InputTypes.
 *
 * <p>Convolutional input gets a {@code PermutePreprocessor} chosen by the layer's dimension
 * ordering; recurrent input supports only the (2, 1) permutation; feed-forward input needs no
 * preprocessor.
 *
 * @param inputType Array of InputTypes; exactly one is expected
 * @return DL4J InputPreProcessor, or {@code null} for feed-forward input
 * @throws InvalidKerasConfigurationException Invalid Keras config (wrong number of inputs,
 *                                            unsupported RNN permutation, or unsupported input type)
 * @see InputPreProcessor
 */
@Override
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws
        InvalidKerasConfigurationException {
    if (inputType.length > 1)
        throw new InvalidKerasConfigurationException(
                "Keras Permute layer accepts only one input (received " + inputType.length + ")");
    if (inputType.length == 0)
        // Report a config error rather than failing on inputType[0] below
        throw new InvalidKerasConfigurationException(
                "Keras Permute layer requires one input (received 0)");
    InputPreProcessor preprocessor = null;
    if (inputType[0] instanceof InputType.InputTypeConvolutional) {
        switch (this.getDimOrder()) {
            case THEANO:
                preprocessor = new PermutePreprocessor(permutationIndices);
                break;
            case NONE: // TF by default
            case TENSORFLOW:
                // account for channels last
                // NOTE(review): this reassigns the permutationIndices field on every call, so repeated
                // calls rotate the indices again, while the preprocessor itself uses a fixed {1, 3, 2}
                // permutation. Confirm this is intended before relying on permutationIndices afterwards.
                permutationIndices = new int[] {permutationIndices[2], permutationIndices[0], permutationIndices[1]};
                preprocessor = new PermutePreprocessor(new int[]{1, 3, 2});
                break;
        }
    } else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
        if (Arrays.equals(permutationIndices, new int[] {2, 1}))
            preprocessor = new PermutePreprocessor(permutationIndices);
        else
            throw new InvalidKerasConfigurationException("For RNN type input data, permutation dims have to be " +
                    "(2, 1) in Permute layer, got " + Arrays.toString(permutationIndices));
    } else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
        // Feed-forward input needs no permutation
        preprocessor = null;
    } else {
        throw new InvalidKerasConfigurationException("Input type not supported: " + inputType[0]);
    }
    return preprocessor;
}
 
Example 6
Source File: CnnToFeedForwardPreProcessor.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Maps a CNN input type to the feed-forward type produced by flattening it.
 *
 * @param inputType the input type; must be of CNN type
 * @return a feed-forward type whose size is channels * height * width
 * @throws IllegalStateException if the input type is null or not CNN
 */
@Override
public InputType getOutputType(InputType inputType) {
    boolean isCnnInput = inputType != null && inputType.getType() == InputType.Type.CNN;
    if (!isCnnInput) {
        throw new IllegalStateException("Invalid input type: Expected input of type CNN, got " + inputType);
    }

    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    // One feature per element of the [channels, height, width] activation
    return InputType.feedForward(conv.getChannels() * conv.getHeight() * conv.getWidth());
}
 
Example 7
Source File: FeedForwardToCnnPreProcessor.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes the CNN output type for this feed-forward-to-CNN preprocessor.
 *
 * <p>Feed-forward input must have exactly inputHeight * inputWidth * numChannels elements; CNN and
 * flattened-CNN input must already match the configured dimensions.
 *
 * @param inputType the input type (FF, CNN or CNNFlat)
 * @return the convolutional output type
 * @throws IllegalStateException on a size/dimension mismatch or an unsupported input type
 */
@Override
public InputType getOutputType(InputType inputType) {

    switch (inputType.getType()) {
        case FF:
            InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
            val expSize = inputHeight * inputWidth * numChannels;
            if (c.getSize() != expSize) {
                throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
                                + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
                                + inputType);
            }
            return InputType.convolutional(inputHeight, inputWidth, numChannels);
        case CNN:
            InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;

            if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
                // Both tuples are printed in (d,w,h) order so actual and expected line up
                throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c2.getChannels()
                                + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
                                + "," + inputWidth + "," + inputHeight + ")");
            }
            return c2;
        case CNNFlat:
            InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
            if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
                throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
                                + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels
                                + "," + inputWidth + "," + inputHeight + ")");
            }
            return c3.getUnflattenedType();
        default:
            throw new IllegalStateException("Invalid input type: got " + inputType);
    }
}
 
Example 8
Source File: FeedForwardToCnn3DPreProcessor.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes the 3D CNN output type for this feed-forward-to-CNN3D preprocessor.
 *
 * <p>Feed-forward input must have exactly inputDepth * inputHeight * inputWidth * numChannels
 * elements. 2D CNN input must match channels/height/width and is mapped to a depth-1 3D type.
 * 3D CNN input must match all four configured dimensions and is returned unchanged.
 *
 * @param inputType the input type (FF, CNN or CNN3D)
 * @return the 3D convolutional output type
 * @throws IllegalStateException on a size/dimension mismatch or an unsupported input type
 */
@Override
public InputType getOutputType(InputType inputType) {

    switch (inputType.getType()) {
        case FF:
            InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
            // Computed in long arithmetic so large volumes cannot overflow int
            long expSize = (long) inputDepth * inputHeight * inputWidth * numChannels;
            if (c.getSize() != expSize) {
                throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
                        + " = (c=" + numChannels + " * d=" + inputDepth + " * w=" + inputWidth
                        + " * h=" + inputHeight + "), got " + inputType);
            }
            return InputType.convolutional3D(inputDepth, inputHeight, inputWidth, numChannels);
        case CNN:
            InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;

            if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
                // Both tuples are printed in (c,w,h) order so actual and expected line up
                throw new IllegalStateException("Invalid input: Got CNN input type with (c,w,h)=(" + c2.getChannels()
                        + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected (" + numChannels
                        + "," + inputWidth + "," + inputHeight + ")");
            }
            return InputType.convolutional3D(1, c2.getHeight(), c2.getWidth(), c2.getChannels());
        case CNN3D:
            InputType.InputTypeConvolutional3D c3 = (InputType.InputTypeConvolutional3D) inputType;

            if (c3.getChannels() != numChannels || c3.getDepth() != inputDepth ||
                    c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
                // Both tuples are printed in (c,d,w,h) order so actual and expected line up
                throw new IllegalStateException("Invalid input: Got CNN input type with (c, d,w,h)=("
                        + c3.getChannels() + "," + c3.getDepth() + "," + c3.getWidth() + "," + c3.getHeight()
                        + ") but expected (" + numChannels + "," + inputDepth + ","
                        + inputWidth + "," + inputHeight + ")");
            }
            return c3;
        default:
            throw new IllegalStateException("Invalid input type: got " + inputType);
    }
}
 
Example 9
Source File: SameDiffConv.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Sets {@code nIn} to the channel count of the (convolutional) input type, unless it is already
 * positive and {@code override} is false.
 *
 * @param inputType the input type; expected to be an {@code InputTypeConvolutional}
 * @param override  whether to overwrite an already-set {@code nIn}
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    boolean shouldSet = override || nIn <= 0;
    if (!shouldSet) {
        return;
    }
    nIn = ((InputType.InputTypeConvolutional) inputType).getChannels();
}
 
Example 10
Source File: ZeroPaddingLayer.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Computes the output type after zero padding.
 *
 * <p>Height grows by {@code padding[0] + padding[1]} and width by {@code padding[2] + padding[3]};
 * the channel count and data format are carried over from the input.
 *
 * @param layerIndex index of this layer (unused here)
 * @param inputType  the (convolutional) input type
 * @return the padded convolutional output type
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    int[] heightWidthDepth = ConvolutionUtils.getHWDFromInputType(inputType);
    int paddedHeight = padding[0] + padding[1] + heightWidthDepth[0];
    int paddedWidth = padding[2] + padding[3] + heightWidthDepth[1];

    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    return InputType.convolutional(paddedHeight, paddedWidth, heightWidthDepth[2], conv.getFormat());
}
 
Example 11
Source File: ConvolutionLayer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Estimates parameter, working and cache memory for this convolution layer.
 *
 * <p>Working memory is dominated by the im2col buffer. In training, with {@code CacheMode.NONE}
 * the buffer is needed twice (im2col and the same-sized eps6d array). In every other cache mode,
 * im2col is counted as cached memory and only one buffer is counted as working memory. Dropout
 * adds one input-sized copy to training working memory.
 *
 * @param inputType the (convolutional) input type of this layer
 * @return the memory report
 */
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
    val paramSize = initializer().numParams(this);
    val updaterStateSize = (int) getIUpdater().stateSize(paramSize);

    InputType.InputTypeConvolutional in = (InputType.InputTypeConvolutional) inputType;
    InputType.InputTypeConvolutional outputType =
                    (InputType.InputTypeConvolutional) getOutputType(-1, inputType);

    //TODO convolution helper memory use... (CuDNN etc)

    // Forward pass: im2col array, mmul (result activations), in-place broadcast add
    val im2colSizePerEx = in.getChannels() * outputType.getHeight() * outputType.getWidth()
                    * kernelSize[0] * kernelSize[1];

    // Dropout duplicates the input before applying dropout (training only)
    long dropoutCopyPerEx = (getIDropout() != null) ? inputType.arrayElementsPerExample() : 0L;

    Map<CacheMode, Long> trainWorkingMemoryPerEx = new HashMap<>();
    Map<CacheMode, Long> cachedPerEx = new HashMap<>();
    for (CacheMode mode : CacheMode.values()) {
        // Any mode other than NONE caches im2col; the eps6d array (same size) is never cached
        boolean im2colCached = mode != CacheMode.NONE;
        long cachePerEx = im2colCached ? im2colSizePerEx : 0;
        long workingPerEx = im2colCached ? im2colSizePerEx : 2 * im2colSizePerEx;

        trainWorkingMemoryPerEx.put(mode, workingPerEx + dropoutCopyPerEx);
        cachedPerEx.put(mode, cachePerEx);
    }

    return new LayerMemoryReport.Builder(layerName, ConvolutionLayer.class, inputType, outputType)
                    .standardMemory(paramSize, updaterStateSize)
                    // im2col caching -> only variable size caching
                    .workingMemory(0, im2colSizePerEx, MemoryReport.CACHE_MODE_ALL_ZEROS, trainWorkingMemoryPerEx)
                    .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cachedPerEx)
                    .build();
}
 
Example 12
Source File: ConvDataFormatTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Returns a convolutional type with the input's height, width and channels, always tagged as
 * channels-first ({@code NCHW}).
 *
 * @param inputType the (convolutional) input type
 * @return the NCHW convolutional output type
 */
@Override
public InputType getOutputType(InputType inputType) {
    InputType.InputTypeConvolutional asConv = (InputType.InputTypeConvolutional) inputType;
    return InputType.convolutional(
            asConv.getHeight(),
            asConv.getWidth(),
            asConv.getChannels(),
            CNN2DFormat.NCHW);
}
 
Example 13
Source File: CnnLossLayer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Records the input's CNN data format when the input is convolutional; otherwise does nothing.
 *
 * @param inputType the input type
 * @param override  unused
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    if (!(inputType instanceof InputType.InputTypeConvolutional)) {
        return;
    }
    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    this.format = conv.getFormat();
}
 
Example 14
Source File: SameDiffConv.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Computes the output type of this SameDiff convolution layer.
 *
 * @param layerIndex index of this layer (used for error reporting)
 * @param inputType  the input type; must be of CNN type
 * @return the convolutional output type
 * @throws IllegalArgumentException if the input type is null or not CNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    // Explicit validation replaces an unused cast that was only acting as an implicit type check
    if (inputType == null || inputType.getType() != InputType.Type.CNN) {
        throw new IllegalArgumentException("Provided input type for SameDiffConv layer has to be "
                + "of CNN type, got: " + inputType);
    }
    return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, new int[]{1, 1},
            cm, nOut, layerIndex, getLayerName(), SameDiffConv.class);
}
 
Example 15
Source File: Yolo2OutputLayer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Records the CNN data format of the (convolutional) input type.
 *
 * @param inputType the input type; expected to be an {@code InputTypeConvolutional}
 * @param override  unused
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    this.format = ((InputType.InputTypeConvolutional) inputType).getFormat();
}
 
Example 16
Source File: KerasInput.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Get layer output type.
 *
 * <p>The output type is derived from this input layer's configured {@code inputShape} (rank 1 to
 * 4) and its dimension ordering. For rank 3/4 shapes, when the ordering is neither TENSORFLOW nor
 * THEANO (including when it is unset/null), Theano (channels-first) ordering is assumed,
 * {@code dimOrder} is set to THEANO, and a warning is logged.
 *
 * @param inputType Array of InputTypes (ignored; an input layer takes no inputs)
 * @return output type as InputType
 * @throws InvalidKerasConfigurationException     Invalid Keras config
 * @throws UnsupportedKerasConfigurationException Unsupported Keras config (inputShape rank > 4)
 */
@Override
public InputType getOutputType(InputType... inputType)
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    if (inputType.length > 0)
        log.warn("Keras Input layer does not accept inputs (received " + inputType.length + "). Ignoring.");
    InputType myInputType;
    switch (this.inputShape.length) {
        case 1:
            myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
            break;
        case 2:
            if(this.dimOrder != null) {
                switch (this.dimOrder) {
                    case TENSORFLOW:    //NWC == channels_last
                        myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                        break;
                    case THEANO:        //NCW == channels_first
                        myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
                        break;
                    case NONE:
                        //Assume RNN in [mb, seqLen, size] format
                        myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                        break;
                    default:
                        throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
                }
            } else {
                //Assume RNN in [mb, seqLen, size] format
                myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
            }

            break;
        case 3:
            // if/else instead of switch: a switch on a null dimOrder would throw NullPointerException
            // instead of reaching the "assume theano" fallback
            if (this.dimOrder == DimOrder.TENSORFLOW) {
                /* TensorFlow convolutional input: # rows, # cols, # channels */
                myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
                        this.inputShape[2], CNN2DFormat.NHWC);
            } else if (this.dimOrder == DimOrder.THEANO) {
                /* Theano convolutional input:     # channels, # rows, # cols */
                myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
                        this.inputShape[0], CNN2DFormat.NCHW);
            } else {
                this.dimOrder = DimOrder.THEANO;
                myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
                        this.inputShape[0], CNN2DFormat.NCHW);
                log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
                        "versions may come without specified backend, in which case we assume the model was " +
                        "built with theano." );
            }
            break;
        case 4:
            // Same null-safe dispatch as the rank-3 case
            if (this.dimOrder == DimOrder.TENSORFLOW) {
                myInputType = new InputType.InputTypeConvolutional3D(Convolution3D.DataFormat.NDHWC,
                        this.inputShape[0], this.inputShape[1],
                        this.inputShape[2],this.inputShape[3]);
            } else if (this.dimOrder == DimOrder.THEANO) {
                myInputType = new InputType.InputTypeConvolutional3D(Convolution3D.DataFormat.NCDHW,
                        this.inputShape[3], this.inputShape[0],
                        this.inputShape[1],this.inputShape[2]);
            } else {
                this.dimOrder = DimOrder.THEANO;
                myInputType = new InputType.InputTypeConvolutional3D(Convolution3D.DataFormat.NCDHW,
                        this.inputShape[3], this.inputShape[0],
                        this.inputShape[1],this.inputShape[2]);
                log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
                        "versions may come without specified backend, in which case we assume the model was " +
                        "built with theano." );
            }
            break;
        default:
            throw new UnsupportedKerasConfigurationException(
                    "Inputs with " + this.inputShape.length + " dimensions not supported");
    }
    return myInputType;
}
 
Example 17
Source File: ConvDataFormatTests.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Converts the (convolutional) input type into an equivalent type with the same height, width
 * and channels, but reported in channels-first ({@code NCHW}) layout.
 *
 * @param inputType the (convolutional) input type
 * @return the NCHW convolutional output type
 */
@Override
public InputType getOutputType(InputType inputType) {
    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
    return InputType.convolutional(conv.getHeight(), conv.getWidth(), conv.getChannels(),
            CNN2DFormat.NCHW);
}
 
Example 18
Source File: PoolHelperVertex.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Computes the output type of this vertex from its input types.
 *
 * <p>A single input is passed through unchanged. Multiple feed-forward or recurrent inputs (which
 * must all be of the same type) yield a type whose size is the sum of the input sizes, or -1 if
 * any input size is undefined. Multiple CNN inputs must share width and height and yield a CNN
 * type with the summed channel count. Flattened CNN input is rejected.
 *
 * @param layerIndex   index of this vertex (unused here)
 * @param vertexInputs input types feeding this vertex; at least one is required
 * @return the output type
 * @throws InvalidInputTypeException if there are no inputs, the inputs are of mixed types, CNN
 *                                   inputs differ in width/height, or the first input is flattened
 *                                   CNN data
 */
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
    if (vertexInputs == null || vertexInputs.length == 0) {
        throw new InvalidInputTypeException("Invalid input: PoolHelperVertex requires at least one input");
    }
    if (vertexInputs.length == 1)
        return vertexInputs[0];
    InputType first = vertexInputs[0];
    if (first.getType() == InputType.Type.CNNFlat) {
        //TODO
        //Merging flattened CNN format data could be messy?
        throw new InvalidInputTypeException(
                        "Invalid input: PoolHelperVertex cannot currently merge CNN data in flattened format. Got: "
                                        + java.util.Arrays.toString(vertexInputs));
    } else if (first.getType() != InputType.Type.CNN) {
        //FF or RNN data inputs
        long size = 0;
        // Once any input size is undefined, the combined size stays undefined (previously a -1
        // sentinel could be added to by later inputs and turn positive again)
        boolean sizeUnknown = false;
        InputType.Type type = null;
        for (int i = 0; i < vertexInputs.length; i++) {
            if (vertexInputs[i].getType() != first.getType()) {
                throw new InvalidInputTypeException(
                                "Invalid input: PoolHelperVertex cannot merge activations of different types:"
                                                + " first type = " + first.getType() + ", input type " + (i + 1)
                                                + " = " + vertexInputs[i].getType());
            }

            long thisSize;
            switch (vertexInputs[i].getType()) {
                case FF:
                    thisSize = ((InputType.InputTypeFeedForward) vertexInputs[i]).getSize();
                    type = InputType.Type.FF;
                    break;
                case RNN:
                    thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize();
                    type = InputType.Type.RNN;
                    break;
                default:
                    throw new IllegalStateException("Unknown input type: " + vertexInputs[i]); //Should never happen
            }
            if (thisSize <= 0) {
                sizeUnknown = true;
            } else {
                size += thisSize;
            }
        }

        // -1 signals an unknown size
        long outSize = (sizeUnknown || size <= 0) ? -1 : size;
        if (type == InputType.Type.FF)
            return InputType.feedForward(outSize);
        else
            return InputType.recurrent(outSize);
    } else {
        //CNN inputs... also check that the channels, width and heights match:
        InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;

        val fd = firstConv.getChannels();
        val fw = firstConv.getWidth();
        val fh = firstConv.getHeight();

        long depthSum = fd;

        for (int i = 1; i < vertexInputs.length; i++) {
            if (vertexInputs[i].getType() != InputType.Type.CNN) {
                throw new InvalidInputTypeException(
                                "Invalid input: PoolHelperVertex cannot process activations of different types:"
                                                + " first type = " + InputType.Type.CNN + ", input type " + (i + 1)
                                                + " = " + vertexInputs[i].getType());
            }

            InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];

            long od = otherConv.getChannels();
            long ow = otherConv.getWidth();
            long oh = otherConv.getHeight();

            if (fw != ow || fh != oh) {
                throw new InvalidInputTypeException(
                                "Invalid input: PoolHelperVertex cannot merge CNN activations of different width/heights:"
                                                + " first [channels,width,height] = [" + fd + "," + fw + "," + fh
                                                + "], input " + i + " = [" + od + "," + ow + "," + oh + "]");
            }

            depthSum += od;
        }

        return InputType.convolutional(fh, fw, depthSum);
    }
}
 
Example 19
Source File: CustomBroadcast.java    From wekaDeeplearning4j with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Returns a channels-last ({@code NHWC}) convolutional type of size {@code width x width}, keeping
 * the channel count of the (convolutional) input.
 *
 * @param layerIndex index of this layer (unused here)
 * @param inputType  the (convolutional) input type
 * @return the broadcast output type
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    long inputChannels = ((InputType.InputTypeConvolutional) inputType).getChannels();
    // Both spatial dimensions are set from the configured width field
    return InputType.convolutional(width, width, inputChannels, CNN2DFormat.NHWC);
}
 
Example 20
Source File: GlobalPoolingLayer.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Computes the output type of global pooling.
 *
 * <p>With {@code collapseDimensions} the pooled axes are removed, giving a feed-forward type of
 * size channels (or RNN size). Without it, the pooled axes are kept with size 1 (recurrent input
 * is returned unchanged). Feed-forward input is rejected.
 *
 * @param layerIndex index of this layer (unused here)
 * @param inputType  the input type (RNN, CNN, CNN3D or CNNFlat)
 * @return the pooled output type
 * @throws UnsupportedOperationException for feed-forward or unknown input types
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {

    switch (inputType.getType()) {
        case RNN: {
            InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
            // Not collapsed: 3d activations, with shape [minibatch, timeStepSize, 1]
            return collapseDimensions ? InputType.feedForward(rnn.getSize()) : rnn;
        }
        case CNN: {
            InputType.InputTypeConvolutional cnn = (InputType.InputTypeConvolutional) inputType;
            return collapseDimensions
                            ? InputType.feedForward(cnn.getChannels())
                            : InputType.convolutional(1, 1, cnn.getChannels(), cnn.getFormat());
        }
        case CNN3D: {
            InputType.InputTypeConvolutional3D cnn3d = (InputType.InputTypeConvolutional3D) inputType;
            return collapseDimensions
                            ? InputType.feedForward(cnn3d.getChannels())
                            : InputType.convolutional3D(1, 1, 1, cnn3d.getChannels());
        }
        case CNNFlat: {
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            return collapseDimensions
                            ? InputType.feedForward(flat.getDepth())
                            : InputType.convolutional(1, 1, flat.getDepth());
        }
        case FF:
            throw new UnsupportedOperationException(
                            "Global max pooling cannot be applied to feed-forward input type. Got input type = "
                                            + inputType);
        default:
            throw new UnsupportedOperationException("Unknown or not supported input type: " + inputType);
    }
}