/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn.abstractnn

import com.intel.analytics.bigdl.nn.keras.{Input => KInput, Sequential => KSequential}
import com.intel.analytics.bigdl.nn.{Input => TInput}
import com.intel.analytics.bigdl.utils.Shape

import scala.language.existentials
import scala.reflect.ClassTag

/**
 * Thrown by [[InferShape.excludeInvalidLayers]] when Keras-style and non-Keras-style
 * layers are combined.
 *
 * @param msg human readable description of the offending layers
 */
class InvalidLayer(msg: String) extends RuntimeException(msg)

/**
 * Mixin that records the input and output [[Shape]] of a layer and offers helpers to
 * check that a layer is only combined with layers of the same definition style
 * (Keras style vs. non-Keras style, as reported by `isKerasStyle()`).
 *
 * Both shapes are `null` until they are assigned, either through the setters or
 * through `build`.
 */
trait InferShape {
  // Backing fields for the shape accessors below; `null` means "not set yet".
  private[bigdl] var _inputShapeValue: Shape = null

  private[bigdl] var _outputShapeValue: Shape = null

  /** Accessor for the recorded input shape (`null` if it has not been set). */
  private[bigdl] def inputShapeValue: Shape = _inputShapeValue

  /** Accessor for the recorded output shape (`null` if it has not been set). */
  private[bigdl] def outputShapeValue: Shape = _outputShapeValue

  // scalastyle:off
  /** Setter for the recorded input shape. */
  private[bigdl] def inputShapeValue_=(value: Shape): Unit = {
    _inputShapeValue = value
  }

  /** Setter for the recorded output shape. */
  private[bigdl] def outputShapeValue_=(value: Shape): Unit = {
    _outputShapeValue = value
  }
  // scalastyle:on

  /**
   * Return the inputShape for the current Layer; the first dim is batch.
   *
   * The backing field is read directly and no `isBuilt()` check is made, so the
   * result is `null` when no input shape has been recorded.
   *
   * @throws IllegalArgumentException if this layer is not Keras style
   */
  final def getInputShape(): Shape = {
    require(this.isKerasStyle(),
      "Torch style definition doesn't support getInputShape for now.")
    _inputShapeValue
  }

  /**
   * Return the outputShape for the current Layer; the first dim is batch.
   *
   * @throws IllegalArgumentException if this layer is not Keras style, or if it has
   *                                  not been built yet
   */
  final def getOutputShape(): Shape = {
    require(this.isKerasStyle(),
      "Torch style definition doesn't support getOutputShape for now.")
    require(this.isBuilt(), "This module hasn't been built.")
    outputShapeValue
  }

  /**
   * Execute building logic and return the outputShape for the given inputShape.
   * NB: the first dim of inputShape is batch.
   *
   * The output shape is computed first; only when that succeeds are the output and
   * then the input shape recorded through their setters.
   *
   * @param inputShape shape of the input, batch dimension first
   * @return the shape produced by `computeOutputShape(inputShape)`
   */
  private[bigdl] def build(inputShape: Shape): Shape = {
    val outputShape = computeOutputShape(inputShape)
    this.outputShapeValue = outputShape
    this.inputShapeValue = inputShape
    outputShape
  }

  /** A layer counts as built once an output shape has been recorded. */
  private[bigdl] def isBuilt(): Boolean = outputShapeValue != null

  /** Whether this layer uses the Keras-style definition; `false` unless overridden. */
  private[bigdl] def isKerasStyle(): Boolean = false

  /** Rebuild flag; `false` unless overridden. It is not consulted inside this trait. */
  private[bigdl] def allowRebuilt(): Boolean = false

  /**
   * Compute the output shape for the given input shape.
   * We suppose the first dim is batch.
   *
   * This default implementation always throws; layers that support shape inference
   * are expected to override it.
   *
   * @param inputShape shape of the input, batch dimension first
   * @throws RuntimeException always, in this default implementation
   */
  private[bigdl] def computeOutputShape(inputShape: Shape): Shape = {
    throw new RuntimeException("Haven't been implemented yet. Do not use it with Keras Layer")
  }

  /**
   * Reject every module whose definition style differs from the style of this layer.
   *
   * @param modules modules to check; a `null` sequence is treated as empty
   * @throws InvalidLayer if at least one module has a different `isKerasStyle()` value
   *                      than this layer
   */
  private[bigdl] def excludeInvalidLayers[T: ClassTag]
  (modules : Seq[AbstractModule[_, _, T]]): Unit = {
    // A null sequence has nothing to validate; without this guard `filter` below
    // would fail with a bare NullPointerException on the non-Keras path.
    val candidates: Seq[AbstractModule[_, _, T]] =
      if (modules == null) Seq.empty else modules
    val kerasStyle = this.isKerasStyle()
    // Same selection as "Keras keeps the non-Keras ones, non-Keras keeps the Keras ones".
    val invalidNodes = candidates.filter(_.isKerasStyle() != kerasStyle)
    invalidNodes.headOption.foreach { firstInvalid =>
      throw new InvalidLayer(
        s"Do not mix ${this}(isKerasStyle=${isKerasStyle()}) with Layer " +
          s"(isKerasStyle=${firstInvalid.isKerasStyle()}): ${invalidNodes.mkString(",")}")
    }
  }

  /**
   * Validate the modules that are about to be combined with this layer.
   *
   * A Keras-style layer requires a non-null, non-empty sequence. For every style the
   * modules must share the definition style of this layer.
   *
   * @param modules modules to validate
   * @throws IllegalArgumentException if this layer is Keras style and `modules` is
   *                                  null or empty
   * @throws InvalidLayer if definition styles are mixed
   */
  private[bigdl] def validateInput[T: ClassTag](modules : Seq[AbstractModule[_, _, T]]): Unit = {
    if (this.isKerasStyle()) {
      require(modules != null && !modules.isEmpty, "Empty input is not allowed")
    }
    excludeInvalidLayers(modules)
  }
}