/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.conf.layers;

import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.regularization.Regularization;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * Batch normalization layer.<br>
 * See: Ioffe and Szegedy, 2015, <i>Batch Normalization: Accelerating Deep Network Training by Reducing Internal
 * Covariate Shift</i>, <a href="https://arxiv.org/abs/1502.03167">https://arxiv.org/abs/1502.03167</a>
 */
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@Builder
public class BatchNormalization extends FeedForwardLayer {
    //Note: need to set defaults here in addition to the builder, in case the user uses the no-arg constructor...
    protected double decay = 0.9;
    protected double eps = 1e-5;
    protected boolean isMinibatch = true;
    protected double gamma = 1.0;
    protected double beta = 0.0;
    protected boolean lockGammaBeta = false;
    protected boolean cudnnAllowFallback = true;
    protected boolean useLogStd = false; //Default for deserialized models (1.0.0-beta3 and earlier): store variance as variance. Post 1.0.0-beta3: use log stdev instead
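
    /*
     * For reference, the normalization this layer configures (Ioffe and Szegedy, 2015, Algorithm 1), written as a
     * minimal per-activation sketch in terms of the fields above; "mean" and "var" are the minibatch statistics
     * during training, or the global (moving-average) estimates at test time:
     *
     *   double xHat = (x - mean) / Math.sqrt(var + eps);
     *   double out  = gamma * xHat + beta;
     *
     * With lockGammaBeta == false (the default) gamma and beta are learned per activation; with
     * lockGammaBeta == true they are fixed to the configured values.
     */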

    protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier

    private BatchNormalization(Builder builder) {
        super(builder);
        this.decay = builder.decay;
        this.eps = builder.eps;
        this.isMinibatch = builder.isMinibatch;
        this.gamma = builder.gamma;
        this.beta = builder.beta;
        this.lockGammaBeta = builder.lockGammaBeta;
        this.cudnnAllowFallback = builder.cudnnAllowFallback;
        this.useLogStd = builder.useLogStd;
        this.cnn2DFormat = builder.cnn2DFormat;
        initializeConstraints(builder);
    }

    public BatchNormalization() {
        this(new Builder()); //Defaults from builder
    }

    @Override
    public BatchNormalization clone() {
        BatchNormalization clone = (BatchNormalization) super.clone();
        return clone;
    }

    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
                    int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        LayerValidation.assertNOutSet("BatchNormalization", getLayerName(), layerIndex, getNOut());

        org.deeplearning4j.nn.layers.normalization.BatchNormalization ret =
                        new org.deeplearning4j.nn.layers.normalization.BatchNormalization(conf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        return BatchNormalizationParamInitializer.getInstance();
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException(
                            "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat, CNN 3D or FF, got null for layer \""
                                            + getLayerName() + "\"");
        }

        //Can handle CNN, flat CNN, CNN3D or FF input formats only
        switch (inputType.getType()) {
            case FF:
            case CNN:
            case CNNFlat:
            case CNN3D:
                return inputType; //OK
            default:
                throw new IllegalStateException(
                                "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat, CNN 3D or FF, got "
                                                + inputType + " for layer index " + layerIndex + ", layer name = " + getLayerName());
        }
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (nIn <= 0 || override) {
            switch (inputType.getType()) {
                case FF:
                    nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
                    break;
                case CNN:
                    nIn = ((InputType.InputTypeConvolutional) inputType).getChannels();
                    cnn2DFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
                    break;
                case CNN3D:
                    nIn = ((InputType.InputTypeConvolutional3D) inputType).getChannels();
                    break;
                case CNNFlat:
                    nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getDepth();
                    break;
                default:
                    throw new IllegalStateException(
                                    "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat, CNN 3D or FF, got "
                                                    + inputType + " for layer \"" + getLayerName() + "\"");
            }
            nOut = nIn;
        }
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        if (inputType.getType() == InputType.Type.CNNFlat) {
            InputType.InputTypeConvolutionalFlat i = (InputType.InputTypeConvolutionalFlat) inputType;
            return new FeedForwardToCnnPreProcessor(i.getHeight(), i.getWidth(), i.getDepth());
        } else if (inputType.getType() == InputType.Type.RNN) {
            return new RnnToFeedForwardPreProcessor();
        }

        return null;
    }
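
    /*
     * A minimal configuration sketch (layer sizes, updater, and input shape are illustrative only). Because
     * setNIn(...) and getPreProcessorForInputType(...) above infer nIn and any required preprocessor from the
     * network's input type, nIn does not normally need to be set explicitly on this layer:
     *
     *   MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
     *                   .updater(new Adam(1e-3))
     *                   .list()
     *                   .layer(new ConvolutionLayer.Builder(3, 3).nOut(32).activation(Activation.IDENTITY).build())
     *                   .layer(new BatchNormalization.Builder().build())
     *                   .layer(new ActivationLayer.Builder().activation(Activation.RELU).build())
     *                   .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
     *                                   .nOut(10).activation(Activation.SOFTMAX).build())
     *                   .setInputType(InputType.convolutional(28, 28, 1))
     *                   .build();
     */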

    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        //Don't regularize batch norm params: similar to biases in the sense that there are not many of them...
        return null;
    }

    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        switch (paramName) {
            case BatchNormalizationParamInitializer.BETA:
            case BatchNormalizationParamInitializer.GAMMA:
                return iUpdater;
            case BatchNormalizationParamInitializer.GLOBAL_MEAN:
            case BatchNormalizationParamInitializer.GLOBAL_VAR:
            case BatchNormalizationParamInitializer.GLOBAL_LOG_STD:
                return new NoOp();
            default:
                throw new IllegalArgumentException("Unknown parameter: \"" + paramName + "\"");
        }
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        InputType outputType = getOutputType(-1, inputType);

        //TODO CuDNN helper etc

        val numParams = initializer().numParams(this);
        int updaterStateSize = 0;

        for (String s : BatchNormalizationParamInitializer.getInstance().paramKeys(this)) {
            updaterStateSize += getUpdaterByParam(s).stateSize(nOut);
        }

        //During forward pass: working memory size approx. equal to 2x input size (copy ops, etc)
        val inferenceWorkingSize = 2 * inputType.arrayElementsPerExample();

        //During training: we calculate mean and variance... result is equal to nOut, and INDEPENDENT of minibatch size
        val trainWorkFixed = 2 * nOut;
        //During backprop: multiple working arrays... output size, 2 * output size (indep. of example size)
        val trainWorkingSizePerExample = inferenceWorkingSize //Inference during backprop
                        + (outputType.arrayElementsPerExample() + 2 * nOut); //Backprop gradient calculation

        return new LayerMemoryReport.Builder(layerName, BatchNormalization.class, inputType, outputType)
                        .standardMemory(numParams, updaterStateSize)
                        .workingMemory(0, 0, trainWorkFixed, trainWorkingSizePerExample) //No additional memory (beyond activations) for inference
                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
                        .build();
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        return false; //No pretrain params in BN
    }

    @AllArgsConstructor
    @Getter
    @Setter
    public static class Builder extends FeedForwardLayer.Builder<Builder> {

        /**
         * At test time: we can use a global estimate of the mean and variance, calculated using a moving average
         * of the batch means/variances. This moving average is implemented as:<br>
         * globalMeanEstimate = decay * globalMeanEstimate + (1-decay) * batchMean<br>
         * globalVarianceEstimate = decay * globalVarianceEstimate + (1-decay) * batchVariance<br>
         */
        protected double decay = 0.9;

        /**
         * Epsilon value for batch normalization; small floating point value added to variance (algorithm 1 in <a
         * href="https://arxiv.org/pdf/1502.03167v3.pdf">https://arxiv.org/pdf/1502.03167v3.pdf</a>) to reduce/avoid
         * underflow issues.<br> Default: 1e-5
         */
        protected double eps = 1e-5;

        /**
         * If doing minibatch training or not. Default: true. Under most circumstances, this should be set to true.
         * If doing full batch training (i.e., all examples in a single DataSet object - very small data sets) then
         * this should be set to false. Affects how global mean/variance estimates are calculated.
         */
        protected boolean isMinibatch = true; // TODO auto set this if layer conf is batch

        /**
         * If set to true: lock the gamma and beta parameters to the values for each activation, specified by
         * {@link #gamma(double)} and {@link #beta(double)}. Default: false -> learn gamma and beta parameter
         * values during network training.
         */
        protected boolean lockGammaBeta = false;
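
        /*
         * A sketch of the moving-average update described for decay above (per-activation pseudocode; decay
         * defaults to 0.9). At test time these global estimates are used in place of the minibatch statistics:
         *
         *   globalMeanEstimate     = decay * globalMeanEstimate     + (1 - decay) * batchMean;
         *   globalVarianceEstimate = decay * globalVarianceEstimate + (1 - decay) * batchVariance;
         */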

        /**
         * Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br>
         * Default: 1.0
         */
        protected double gamma = 1.0;

        /**
         * Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br>
         * Default: 0.0
         */
        protected double beta = 0.0;

        /**
         * Set constraints to be applied to the beta parameter of this batch normalisation layer. Default: no
         * constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
         * max-norm regularization, etc). These constraints are applied at each iteration, after the parameters
         * have been updated.
         */
        protected List<LayerConstraint> betaConstraints;

        /**
         * Set constraints to be applied to the gamma parameter of this batch normalisation layer. Default: no
         * constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
         * max-norm regularization, etc). These constraints are applied at each iteration, after the parameters
         * have been updated.
         */
        protected List<LayerConstraint> gammaConstraints;

        /**
         * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementation be allowed?
         * If set to false, an exception in CuDNN will be propagated back to the user. If true, the built-in
         * (non-CuDNN) implementation for BatchNormalization will be used
         */
        protected boolean cudnnAllowFallback = true;

        /**
         * How should the moving average of variance be stored? Two different parameterizations are supported.<br>
         * useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is stored directly as a
         * variable<br> useLogStd(true): (Default) variance is stored as log10(stdev)<br> The motivation here is
         * numerical stability (FP16 etc) and also distributed training: storing the variance directly can cause
         * numerical issues. For example, a standard deviation of 1e-3 (something that could be encountered in
         * practice) gives a variance of 1e-6, which can be problematic for 16-bit floating point
         */
        protected boolean useLogStd = true;

        protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier

        public Builder(double decay, boolean isMinibatch) {
            this.setDecay(decay);
            this.setMinibatch(isMinibatch);
        }

        public Builder(double gamma, double beta) {
            this.setGamma(gamma);
            this.setBeta(beta);
        }

        public Builder(double gamma, double beta, boolean lockGammaBeta) {
            this.setGamma(gamma);
            this.setBeta(beta);
            this.setLockGammaBeta(lockGammaBeta);
        }

        public Builder(boolean lockGammaBeta) {
            this.setLockGammaBeta(lockGammaBeta);
        }

        public Builder() {}

        /**
         * Set the input and output array data format. Defaults to NCHW format - i.e., channels first.
         * See {@link CNN2DFormat} for more details
         *
         * @param format Format to use
         */
        public Builder dataFormat(CNN2DFormat format) {
            this.cnn2DFormat = format;
            return this;
        }

        /**
         * If doing minibatch training or not. Default: true. Under most circumstances, this should be set to true.
         * If doing full batch training (i.e., all examples in a single DataSet object - very small data sets) then
         * this should be set to false. Affects how global mean/variance estimates are calculated.
         *
         * @param minibatch Minibatch parameter
         */
        public Builder minibatch(boolean minibatch) {
            this.setMinibatch(minibatch);
            return this;
        }
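
        /*
         * For reference, the relationship between the two variance parameterizations described for useLogStd
         * above (a sketch; "variance" is the moving-average variance estimate):
         *
         *   double logStd    = Math.log10(Math.sqrt(variance));   //Stored when useLogStd == true
         *   double recovered = Math.pow(10.0, 2.0 * logStd);      //Variance recovered from the stored value
         *
         * E.g., a standard deviation of 1e-3 gives a variance of 1e-6, but logStd is simply -3.0, which is much
         * better behaved in FP16.
         */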

        /**
         * Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br>
         * Default: 1.0
         *
         * @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode
         */
        public Builder gamma(double gamma) {
            this.setGamma(gamma);
            return this;
        }

        /**
         * Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br>
         * Default: 0.0
         *
         * @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode
         */
        public Builder beta(double beta) {
            this.setBeta(beta);
            return this;
        }

        /**
         * Epsilon value for batch normalization; small floating point value added to variance (algorithm 1 in <a
         * href="https://arxiv.org/pdf/1502.03167v3.pdf">https://arxiv.org/pdf/1502.03167v3.pdf</a>) to reduce/avoid
         * underflow issues.<br> Default: 1e-5
         *
         * @param eps Epsilon value to use
         */
        public Builder eps(double eps) {
            this.setEps(eps);
            return this;
        }

        /**
         * At test time: we can use a global estimate of the mean and variance, calculated using a moving average
         * of the batch means/variances. This moving average is implemented as:<br>
         * globalMeanEstimate = decay * globalMeanEstimate + (1-decay) * batchMean<br>
         * globalVarianceEstimate = decay * globalVarianceEstimate + (1-decay) * batchVariance<br>
         *
         * @param decay Decay value to use for global stats calculation
         */
        public Builder decay(double decay) {
            this.setDecay(decay);
            return this;
        }

        /**
         * If set to true: lock the gamma and beta parameters to the values for each activation, specified by
         * {@link #gamma(double)} and {@link #beta(double)}. Default: false -> learn gamma and beta parameter
         * values during network training.
         *
         * @param lockGammaBeta If true: use fixed beta/gamma values; if false: learn them during training
         */
        public Builder lockGammaBeta(boolean lockGammaBeta) {
            this.setLockGammaBeta(lockGammaBeta);
            return this;
        }

        /**
         * Set constraints to be applied to the beta parameter of this batch normalisation layer. Default: no
         * constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
         * max-norm regularization, etc). These constraints are applied at each iteration, after the parameters
         * have been updated.
         *
         * @param constraints Constraints to apply to the beta parameter of this layer
         */
        public Builder constrainBeta(LayerConstraint... constraints) {
            this.setBetaConstraints(Arrays.asList(constraints));
            return this;
        }

        /**
         * Set constraints to be applied to the gamma parameter of this batch normalisation layer. Default: no
         * constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
         * max-norm regularization, etc). These constraints are applied at each iteration, after the parameters
         * have been updated.
         *
         * @param constraints Constraints to apply to the gamma parameter of this layer
         */
        public Builder constrainGamma(LayerConstraint... constraints) {
            this.setGammaConstraints(Arrays.asList(constraints));
            return this;
        }
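
        /*
         * A usage sketch for the locked gamma/beta mode and the parameter constraints configured above (the
         * constraint class and values are illustrative only):
         *
         *   BatchNormalization locked = new BatchNormalization.Builder()
         *                   .lockGammaBeta(true) //Fix gamma/beta rather than learning them
         *                   .gamma(1.0)
         *                   .beta(0.0)
         *                   .build();
         *
         *   BatchNormalization constrained = new BatchNormalization.Builder()
         *                   .constrainGamma(new NonNegativeConstraint())
         *                   .build();
         */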

        /**
         * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementation be allowed?
         * If set to false, an exception in CuDNN will be propagated back to the user. If true, the built-in
         * (non-CuDNN) implementation for BatchNormalization will be used
         *
         * @param allowFallback Whether fallback to non-CuDNN implementation should be used
         * @deprecated Use {@link #helperAllowFallback(boolean)}
         */
        @Deprecated
        public Builder cudnnAllowFallback(boolean allowFallback) {
            this.setCudnnAllowFallback(allowFallback);
            return this;
        }

        /**
         * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation
         * be allowed? If set to false, an exception in the helper will be propagated back to the user. If true,
         * the built-in (non-MKL/CuDNN) implementation for BatchNormalization will be used
         *
         * @param allowFallback Whether fallback to the non-helper implementation should be used
         */
        public Builder helperAllowFallback(boolean allowFallback) {
            this.cudnnAllowFallback = allowFallback;
            return this;
        }

        /**
         * How should the moving average of variance be stored? Two different parameterizations are supported.<br>
         * useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is stored directly as a
         * variable<br> useLogStd(true): (Default) variance is stored as log10(stdev)<br> The motivation here is
         * numerical stability (FP16 etc) and also distributed training: storing the variance directly can cause
         * numerical issues. For example, a standard deviation of 1e-3 (something that could be encountered in
         * practice) gives a variance of 1e-6, which can be problematic for 16-bit floating point
         */
        public Builder useLogStd(boolean useLogStd) {
            this.setUseLogStd(useLogStd);
            return this;
        }

        @Override
        public BatchNormalization build() {
            return new BatchNormalization(this);
        }
    }
}