# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions for MixNet model.

[1] Mingxing Tan, Quoc V. Le
  MixConv: Mixed Depthwise Convolutional Kernels.
  BMVC'19, https://arxiv.org/abs/1907.09595
"""
import re


class BlockArgs(object):
    """Hyper-parameters of one MixNet block, with a compact string codec.

    Encoding schema (tokens are joined with '_'):

        "rX_kX{.X}_aX{.X}_pX{.X}_sXX_eX_iX_oX{_se0.XX}{_noskip}{_sw}{_dilated}"

    - r: num_repeat
    - k: dw_kernel_size, one or more integers separated by '.'
    - a: expand_kernel_size, same format as k
    - p: project_kernel_size, same format as k
    - s: strides, exactly two digits
    - e: expand_ratio
    - i: input_filters
    - o: output_filters
    - se: se_ratio (optional)
    - noskip: sets identity_skip to False (optional)
    - sw: sets swish to True (optional)
    - dilated: sets dilated to True (optional)

    `{}` encapsulates optional parts. Use `BlockArgs.from_block_string` to
    build an instance from such a string.
    """

    def __init__(self,
                 input_filters=None,
                 output_filters=None,
                 dw_kernel_size=None,
                 expand_kernel_size=None,
                 project_kernel_size=None,
                 strides=None,
                 num_repeat=None,
                 se_ratio=None,
                 expand_ratio=None,
                 identity_skip=True,
                 swish=False,
                 dilated=False):
        """Stores the block arguments.

        A bare integer kernel size is wrapped into a one-element list; every
        other argument is stored as given.
        """
        self.input_filters = input_filters
        self.output_filters = output_filters
        self.dw_kernel_size = self._normalize_kernel_size(dw_kernel_size)
        self.expand_kernel_size = self._normalize_kernel_size(
            expand_kernel_size)
        self.project_kernel_size = self._normalize_kernel_size(
            project_kernel_size)
        self.strides = strides
        self.num_repeat = num_repeat
        self.se_ratio = se_ratio
        self.expand_ratio = expand_ratio
        self.identity_skip = identity_skip
        self.swish = swish
        self.dilated = dilated

    def decode_block_string(self, block_string):
        """Populates this instance from a string notation of arguments.

        Args:
            block_string: a string following the schema in the class docstring.

        Returns:
            This instance, with its attributes overwritten.

        Raises:
            ValueError: if the strides token is missing or is not two digits.
            KeyError: if any of the i/o/k/a/p/r/e tokens is missing.
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split('_'):
            # Split at the first digit: 'se0.25' -> ['se', '0.25', ''].
            # Tokens without a digit ('sw', 'noskip', 'dilated') yield a single
            # element and are handled as flags below.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        if 's' not in options or len(options['s']) != 2:
            raise ValueError('Strides options should be a pair of integers.')

        self.input_filters = int(options['i'])
        self.output_filters = int(options['o'])
        self.dw_kernel_size = self._parse_ksize(options['k'])
        self.expand_kernel_size = self._parse_ksize(options['a'])
        self.project_kernel_size = self._parse_ksize(options['p'])
        self.num_repeat = int(options['r'])
        self.identity_skip = ('noskip' not in block_string)
        self.se_ratio = float(options['se']) if 'se' in options else None
        self.expand_ratio = int(options['e'])
        self.strides = [int(options['s'][0]), int(options['s'][1])]
        self.swish = 'sw' in block_string
        self.dilated = 'dilated' in block_string
        return self

    def encode_block_string(self, block):
        """Encodes a block to a string.

        The output follows the schema in the class docstring and can be
        turned back into a block with:

        ```python
        BlockArgs.from_block_string(block_string)
        ```

        Args:
            block: a `BlockArgs`-like object providing the attributes set in
                `__init__`.

        Returns:
            The encoded block string.
        """
        args = [
            'r%d' % block.num_repeat,
            'k%s' % self._encode_ksize(block.dw_kernel_size),
            'a%s' % self._encode_ksize(block.expand_kernel_size),
            'p%s' % self._encode_ksize(block.project_kernel_size),
            's%d%d' % (block.strides[0], block.strides[1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]

        # se_ratio is None when the block has no 'se' token; comparing None
        # with a number raises TypeError on Python 3, so guard it explicitly.
        se_ratio = block.se_ratio
        if se_ratio is not None and 0 < se_ratio <= 1:
            args.append('se%s' % se_ratio)

        if block.identity_skip is False:
            args.append('noskip')

        if block.swish:
            args.append('sw')

        if block.dilated:
            args.append('dilated')

        return '_'.join(args)

    def _normalize_kernel_size(self, val):
        """Wraps a bare integer into a one-element list; returns others as is."""
        if isinstance(val, int):
            return [val]
        return val

    def _parse_ksize(self, ss):
        """Parses a '.'-separated kernel-size string, e.g. '3.5.7' -> [3, 5, 7]."""
        return [int(k) for k in ss.split('.')]

    def _encode_ksize(self, arr):
        """Joins kernel sizes with '.', e.g. [3, 5, 7] -> '3.5.7'."""
        return '.'.join([str(k) for k in arr])

    @classmethod
    def from_block_string(cls, block_string):
        """Creates a block from its string notation.

        Args:
            block_string: a string following the schema in the class docstring,
                e.g. 'r1_k3.5.7_a1_p1.1_s22_e6_i40_o80_se0.25_sw'.

        Returns:
            BlockArgs object initialized with the block string args.
        """
        block = cls()
        return block.decode_block_string(block_string)


def _decode_block_strings(block_strings):
    """Decodes every block string in the list into a `BlockArgs` object."""
    return [BlockArgs.from_block_string(s) for s in block_strings]


# Default list of blocks for MixNets
def get_mixnet_small(depth_multiplier=None):
    """Returns the list of `BlockArgs` for the small MixNet variant.

    Args:
        depth_multiplier: currently unused; kept for interface compatibility.
    """
    blocks_args = [
        'r1_k3_a1_p1_s11_e1_i16_o16',
        'r1_k3_a1.1_p1.1_s22_e6_i16_o24',
        'r1_k3_a1.1_p1.1_s11_e3_i24_o24',

        'r1_k3.5.7_a1_p1_s22_e6_i24_o40_se0.5_sw',
        'r3_k3.5_a1.1_p1.1_s11_e6_i40_o40_se0.5_sw',

        'r1_k3.5.7_a1_p1.1_s22_e6_i40_o80_se0.25_sw',
        'r2_k3.5_a1_p1.1_s11_e6_i80_o80_se0.25_sw',

        'r1_k3.5.7_a1.1_p1.1_s11_e6_i80_o120_se0.5_sw',
        'r2_k3.5.7.9_a1.1_p1.1_s11_e3_i120_o120_se0.5_sw',

        'r1_k3.5.7.9.11_a1_p1_s22_e6_i120_o200_se0.5_sw',
        'r2_k3.5.7.9_a1_p1.1_s11_e6_i200_o200_se0.5_sw',
    ]
    return _decode_block_strings(blocks_args)


def get_mixnet_medium(depth_multiplier=None):
    """Returns the list of `BlockArgs` for the medium MixNet variant.

    Args:
        depth_multiplier: currently unused; kept for interface compatibility.
    """
    blocks_args = [
        'r1_k3_a1_p1_s11_e1_i24_o24',
        'r1_k3.5.7_a1.1_p1.1_s22_e6_i24_o32',
        'r1_k3_a1.1_p1.1_s11_e3_i32_o32',

        'r1_k3.5.7.9_a1_p1_s22_e6_i32_o40_se0.5_sw',
        'r3_k3.5_a1.1_p1.1_s11_e6_i40_o40_se0.5_sw',

        'r1_k3.5.7_a1_p1_s22_e6_i40_o80_se0.25_sw',
        'r3_k3.5.7.9_a1.1_p1.1_s11_e6_i80_o80_se0.25_sw',

        'r1_k3_a1_p1_s11_e6_i80_o120_se0.5_sw',
        'r3_k3.5.7.9_a1.1_p1.1_s11_e3_i120_o120_se0.5_sw',

        'r1_k3.5.7.9_a1_p1_s22_e6_i120_o200_se0.5_sw',
        'r3_k3.5.7.9_a1_p1.1_s11_e6_i200_o200_se0.5_sw',
    ]
    return _decode_block_strings(blocks_args)


def get_mixnet_large(depth_multiplier=None):
    """Returns the block list for the large variant.

    This is the same list as `get_mixnet_medium`; `depth_multiplier` is passed
    through but is not used by that function.
    """
    return get_mixnet_medium(depth_multiplier)