"""ShuffleNetV2 backbone with ThunderNet-style CEM / RPN / SAM heads.

Code borrowed from:
    https://github.com/mnicnc404/CartoonGan-tensorflow
    https://github.com/xiaohu2015/DeepLearning_tutorials/blob/master/CNNs/shufflenet_v2.py
    https://github.com/zengarden/light_head_rcnn
    https://github.com/geonseoks/Light_head_R_CNN_xception
    https://github.com/Stick-To/light-head-rcnn-tensorflow
"""
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, ReLU
from tensorflow.keras.layers import MaxPool2D, GlobalAveragePooling2D, Dense
from tensorflow.keras.layers import BatchNormalization, Activation

# Channel count produced by the Context Enhancement Module (and expected of
# the global feature fed into it).
CEM_FILTER = 245


@tf.function
def channle_shuffle(inputs, group):
    """Shuffle channels across ``group`` groups (ShuffleNet channel shuffle).

    Name keeps the historical typo ("channle") for backward compatibility
    with existing callers.

    Args:
        inputs: 4D tensor in NHWC layout with statically known H, W, C.
        group: int, number of groups; must evenly divide the channel count.

    Returns:
        4D tensor of the same shape with channels interleaved across groups.
    """
    h, w, in_channel = K.int_shape(inputs)[1:]
    assert in_channel % group == 0
    x = K.reshape(inputs, [-1, h, w, in_channel // group, group])
    # Swap the group axis with the per-group channel axis, then flatten back.
    x = K.permute_dimensions(x, [0, 1, 2, 4, 3])
    return K.reshape(x, [-1, h, w, in_channel])


class Conv2D_BN(Model):
    """Conv2D -> BatchNorm (no activation)."""

    def __init__(self, channel, kernel_size=1, stride=1):
        super(Conv2D_BN, self).__init__()
        self.conv = Conv2D(channel, kernel_size, strides=stride,
                           padding="SAME", use_bias=False)
        self.bn = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)

    def build(self, input_shape):
        super(Conv2D_BN, self).build(input_shape)

    @tf.function
    def call(self, inputs, training=False):
        x = self.conv(inputs)
        x = self.bn(x, training=training)
        return x


class Conv2D_BN_ReLU(Model):
    """Conv2D -> BatchNorm -> ReLU."""

    def __init__(self, channel, kernel_size=1, stride=1):
        super(Conv2D_BN_ReLU, self).__init__(name="Conv2D_BN_ReLU")
        self.conv = Conv2D(channel, kernel_size, strides=stride,
                           padding="SAME", use_bias=False)
        self.bn_ = BatchNormalization(momentum=0.9, epsilon=1e-5)
        self.relu = Activation("relu")
        # The three ops are wrapped in a Sequential so they execute (and are
        # checkpointed) as a single sub-model; kept for weight-name stability.
        self.model = tf.keras.models.Sequential()
        self.model.add(self.conv)
        self.model.add(self.bn_)
        self.model.add(self.relu)

    def build(self, input_shape):
        super(Conv2D_BN_ReLU, self).build(input_shape)

    @tf.function
    def call(self, inputs, training=False):
        return self.model(inputs, training=training)


class DepthwiseConv2D_BN(Model):
    """DepthwiseConv2D -> BatchNorm (no activation, depth_multiplier=1)."""

    def __init__(self, kernel_size=3, stride=1):
        super(DepthwiseConv2D_BN, self).__init__()
        self.dconv = DepthwiseConv2D(kernel_size, strides=stride,
                                     depth_multiplier=1,
                                     padding="SAME", use_bias=False)
        self.bn = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)

    @tf.function
    def call(self, inputs, training=False):
        x = self.dconv(inputs)
        x = self.bn(x, training=training)
        return x


class DepthwiseConv2D_BN_POINT(Model):
    """DepthwiseConv2D -> BatchNorm -> 1x1 pointwise Conv2D (with bias)."""

    def __init__(self, kernel_size=3, stride=1, out_channel=256):
        super(DepthwiseConv2D_BN_POINT, self).__init__()
        self.dconv = DepthwiseConv2D(kernel_size, strides=stride,
                                     depth_multiplier=1,
                                     padding="SAME", use_bias=False)
        self.bn = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)
        self.conv = Conv2D(out_channel, 1, strides=1,
                           padding="SAME", use_bias=True)

    @tf.function
    def call(self, inputs, training=False):
        x = self.dconv(inputs)
        x = self.bn(x, training=training)
        x = self.conv(x)
        return x


class ShufflenetUnit1(Model):
    """ShuffleNetV2 basic unit for stride=1.

    Splits the input channels in half, transforms one half with
    1x1 -> 5x5 depthwise -> 1x1, concatenates with the untouched half,
    then channel-shuffles.
    """

    def __init__(self, out_channel):
        """
        Args:
            out_channel: int, number of output channels (must be even).
        """
        super(ShufflenetUnit1, self).__init__()
        assert out_channel % 2 == 0
        self.out_channel = out_channel
        self.conv1_bn_relu = Conv2D_BN_ReLU(out_channel // 2, 1, 1)
        # 5x5 depthwise (instead of the paper's 3x3) for a larger receptive
        # field, as in ThunderNet's SNet backbone.
        self.dconv_bn = DepthwiseConv2D_BN(5, 1)
        self.conv2_bn_relu = Conv2D_BN_ReLU(out_channel // 2, 1, 1)

    def build(self, input_shape):
        super(ShufflenetUnit1, self).build(input_shape)

    @tf.function
    def call(self, inputs, training=False):
        # Half the channels pass through untouched (the "shortcut" branch).
        shortcut, x = tf.split(inputs, 2, axis=3)
        x = self.conv1_bn_relu(x, training=training)
        x = self.dconv_bn(x, training=training)
        x = self.conv2_bn_relu(x, training=training)
        x = tf.concat([shortcut, x], axis=3)
        return channle_shuffle(x, 2)


class ShufflenetUnit2(tf.keras.Model):
    """ShuffleNetV2 down-sampling unit (stride=2).

    Both branches stride by 2; their concatenation grows the channel count
    from ``in_channel`` to ``out_channel``.
    """

    def __init__(self, in_channel, out_channel):
        super(ShufflenetUnit2, self).__init__()
        assert out_channel % 2 == 0
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.conv1_bn_relu = Conv2D_BN_ReLU(out_channel // 2, 1, 1)
        self.dconv_bn = DepthwiseConv2D_BN(5, 2)
        self.conv2_bn = Conv2D_BN(out_channel - in_channel, 1, 1)
        # Shortcut branch: 3x3 depthwise stride-2 followed by 1x1 projection.
        self.shortcut_dconv_bn = DepthwiseConv2D_BN(3, 2)
        self.shortcut_conv_bn = Conv2D_BN(in_channel, 1, 1)

    @tf.function
    def call(self, inputs, training=False):
        shortcut, x = inputs, inputs
        x = self.conv1_bn_relu(x, training=training)
        x = self.dconv_bn(x, training=training)
        x = self.conv2_bn(x, training=training)
        shortcut = self.shortcut_dconv_bn(shortcut, training=training)
        shortcut = self.shortcut_conv_bn(shortcut, training=training)
        x = tf.concat([shortcut, x], axis=3)
        x = ReLU()(x)
        return channle_shuffle(x, 2)


class ShufflenetStage(Model):
    """A ShuffleNetV2 stage: one stride-2 unit followed by stride-1 units."""

    def __init__(self, in_channel, out_channel, num_blocks):
        """
        Args:
            in_channel: int, channels entering the stage.
            out_channel: int, channels produced by every unit in the stage.
            num_blocks: int, total units (first is the down-sampling unit).
        """
        super(ShufflenetStage, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.ops = []
        for i in range(num_blocks):
            if i == 0:
                op = ShufflenetUnit2(in_channel, out_channel)
            else:
                op = ShufflenetUnit1(out_channel)
            self.ops.append(op)

    @tf.function
    def call(self, inputs, training=False):
        x = inputs
        for op in self.ops:
            x = op(x, training=training)
        return x


class CEM(Model):
    """Context Enhancement Module.

    Fuses three feature levels into a single CEM_FILTER-channel map:
    C4 (1x1 conv), C5 (1x1 conv + 2x bilinear upsample) and the global
    average feature (broadcast over the spatial dims).
    """

    def __init__(self):
        super(CEM, self).__init__()
        self.conv4 = Conv2D(CEM_FILTER, 1, strides=1,
                            padding="SAME", use_bias=True)
        self.conv5 = Conv2D(CEM_FILTER, 1, strides=1,
                            padding="SAME", use_bias=True)

    @tf.function
    def call(self, inputs, training=False):
        """inputs = [C4, C5, global_feature]; global_feature must have
        CEM_FILTER elements per batch item."""
        C4_lat = self.conv4(inputs[0])
        C5_lat = self.conv5(inputs[1])
        # Upsample C5 by 2x so it matches C4's spatial resolution.
        C5_lat = tf.keras.backend.resize_images(
            C5_lat, 2, 2, "channels_last", "bilinear")
        # Reshape the global vector to 1x1 so addition broadcasts spatially.
        Cglb_lat = K.reshape(inputs[2], [-1, 1, 1, CEM_FILTER])
        return C4_lat + C5_lat + Cglb_lat


class ShuffleNetv2(Model):
    """ShuffleNetV2 classifier backbone.

    Returns the logits together with the intermediate C4/C5 feature maps and
    the global pooled feature, so detection heads (CEM/RPN/SAM) can reuse
    the backbone's activations.
    """

    def __init__(self, num_classes, first_channel=24,
                 channels_per_stage=(132, 264, 528)):
        super(ShuffleNetv2, self).__init__(name="ShuffleNetv2")
        self.num_classes = num_classes
        self.conv1_bn_relu = Conv2D_BN_ReLU(first_channel, 3, 2)
        self.pool1 = MaxPool2D(3, strides=2, padding="SAME")
        self.stage2 = ShufflenetStage(first_channel, channels_per_stage[0], 4)
        self.stage3 = ShufflenetStage(channels_per_stage[0],
                                      channels_per_stage[1], 8)
        self.stage4 = ShufflenetStage(channels_per_stage[1],
                                      channels_per_stage[2], 4)
        self.gap = GlobalAveragePooling2D()
        self.linear = Dense(num_classes)

    def build(self, input_shape):
        super(ShuffleNetv2, self).build(input_shape)

    @tf.function
    def call(self, inputs, training=False):
        """Returns (logits, C4, C5, Cglb)."""
        x = self.conv1_bn_relu(inputs, training=training)
        x = self.pool1(x)
        x = self.stage2(x, training=training)
        C4 = self.stage3(x, training=training)          # stride-16 features
        C5 = self.stage4(C4, training=training)         # stride-32 features
        Cglb = self.gap(C5)                             # global pooled vector
        x = self.linear(Cglb)
        return x, C4, C5, Cglb


class RPNROI(Model):
    """ROI selection from RPN outputs (not yet implemented)."""

    def __init__(self):
        # BUGFIX: original called super().__init() (missing "__init__"),
        # which raised AttributeError on construction.
        super(RPNROI, self).__init__()

    @tf.function
    def call(self, inputs, training=False):
        """inputs = [SAM, rpn_conf, rpn_pbbox]"""
        pass  # TODO: implement proposal selection / NMS


class RPN(Model):
    """Region proposal network head.

    A 6x6 depthwise + pointwise conv produces a shared feature, from which
    two 1x1 convs predict per-anchor objectness (2 values) and box deltas
    (4 values). 25 anchors per location (5 scales x 5 ratios).
    """

    def __init__(self, filter=256):
        super(RPN, self).__init__()
        self.num_anchors = 5 * 5
        self.rpn = DepthwiseConv2D_BN_POINT(kernel_size=6, stride=1,
                                            out_channel=filter)
        self.rpn_cls_score = Conv2D(2 * self.num_anchors, 1, strides=1,
                                    padding="VALID", use_bias=True)
        # NOTE: despite the name, this predicts 4 bbox-regression values per
        # anchor; name kept for backward compatibility with checkpoints.
        self.rpn_cls_pred = Conv2D(4 * self.num_anchors, 1, strides=1,
                                   padding="VALID", use_bias=True)

    @tf.function
    def call(self, inputs, training=False):
        """Returns (shared_feature, objectness_scores, bbox_predictions)."""
        rpn = self.rpn(inputs, training=training)
        rpn_conf = self.rpn_cls_score(rpn)
        rpn_pbbox = self.rpn_cls_pred(rpn)
        return rpn, rpn_conf, rpn_pbbox


class SAM(Model):
    """Spatial attention module.

    Uses the RPN's shared feature to re-weight the CEM feature map:
    1x1 conv -> BN -> channel softmax, multiplied elementwise onto CEM.
    """

    def __init__(self):
        super(SAM, self).__init__()
        self.point = Conv2D(CEM_FILTER, 1, strides=1,
                            padding="VALID", use_bias=False)
        self.bn = BatchNormalization()

    @tf.function
    def call(self, inputs, training=False):
        """inputs = [RPN feature, CEM feature]"""
        x = self.point(inputs[0])
        # BUGFIX: propagate `training` so BN updates its statistics during
        # training (the original always ran BN in inference mode).
        x = self.bn(x, training=training)
        x = tf.keras.activations.softmax(x, axis=-1)
        return tf.math.multiply(x, inputs[1])


if __name__ == "__main__":
    # Smoke-test the individual units on random data.
    s = (10, 320, 320, 12)
    nx = np.random.rand(*s).astype(np.float32)
    custom_layers = [
        ShufflenetUnit1(out_channel=s[-1]),
        ShufflenetUnit2(in_channel=24, out_channel=116),
        ShufflenetStage(in_channel=24, out_channel=116, num_blocks=5),
    ]
    for layer in custom_layers:
        tf.keras.backend.clear_session()
        out = layer(nx)
        layer.summary()
        print(f"Input Shape: {nx.shape}")
        print(f"Output Shape: {out.shape}")
        print("\n" * 2)

    # Smoke-test the full backbone + CEM/RPN/SAM pipeline.
    tf.keras.backend.clear_session()
    g = ShuffleNetv2(CEM_FILTER)
    shape = (10, 320, 320, 3)
    nx = np.random.rand(*shape).astype(np.float32)
    x, C4, C5, Cglb = g(nx, training=False)
    cem = CEM()
    sam = SAM()
    rpn = RPN()
    re = cem([C4, C5, x], training=False)
    rpn_result, rpn_cls_score, rpn_cls_pred = rpn(re, training=False)
    sam_result = sam([rpn_result, re], training=False)
    print('cem rsult: ', K.int_shape(re))
    g.summary()
    sam.summary()
    cem.summary()
    rpn.summary()