from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import tensorflow as tf import warnings from contextlib import contextmanager from .version_utils import tf_later_than try: import cv2 except ImportError: cv2 = None __middles__ = 'middles' __outputs__ = 'outputs' if tf_later_than('1.14'): tf = tf.compat.v1 if tf_later_than('2'): from .contrib_framework import arg_scope from .contrib_layers.utils import collect_named_outputs else: from tensorflow.contrib.framework import arg_scope from tensorflow.contrib.layers.python.layers.utils import collect_named_outputs if tf_later_than('2.1'): from tensorflow.python.keras.applications.imagenet_utils \ import decode_predictions from tensorflow.python.keras.utils.data_utils import get_file elif tf_later_than('1.9'): from tensorflow.python.keras.applications.imagenet_utils \ import decode_predictions from tensorflow.python.keras.utils import get_file elif tf_later_than('1.4'): from tensorflow.python.keras._impl.keras.applications.imagenet_utils \ import decode_predictions from tensorflow.python.keras.utils import get_file else: from tensorflow.contrib.keras.python.keras.applications.imagenet_utils \ import decode_predictions from tensorflow.contrib.keras.python.keras.utils.data_utils \ import get_file def print_collection(collection, scope): if scope is not None: print("Scope: %s" % scope) for x in tf.get_collection(collection, scope=scope + '/'): name = x.name if scope is not None: name = name[len(scope)+1:] print("%s %s" % (name, x.shape)) def parse_scopes(inputs): if not isinstance(inputs, list): inputs = [inputs] outputs = [] for scope_or_tensor in inputs: if isinstance(scope_or_tensor, tf.Tensor): outputs.append(scope_or_tensor.aliases[0]) elif isinstance(scope_or_tensor, str): outputs.append(scope_or_tensor) else: outputs.append(None) return outputs def print_middles(scopes=None): scopes = parse_scopes(scopes) for scope in scopes: print_collection(__middles__, scope) def print_outputs(scopes=None): scopes = parse_scopes(scopes) for scope in scopes: print_collection(__outputs__, scope) def print_weights(scopes=None): scopes = parse_scopes(scopes) for scope in scopes: print_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) def print_summary(scopes=None): scopes = parse_scopes(scopes) for scope in scopes: if scope is not None: print("Scope: %s" % scope) weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/') names = [w.name for w in weights] starts = [n.rfind('/') + 1 for n in names] ends = [n.rfind(':') for n in names] layers = sum([n[s:e] == 'weights' for (n, s, e) in zip(names, starts, ends)]) parameters = sum([w.shape.num_elements() for w in weights]) print("Total layers: %d" % layers) print("Total weights: %d" % len(weights)) print("Total parameters: {:,}".format(parameters)) def get_collection(collection_name, scope=None, names=None): scope = parse_scopes(scope)[0] collection = tf.get_collection(collection_name, scope=scope + '/') if names is None: return collection else: if not isinstance(names, list): names = [names] _collection = [] for x in collection: if any([name in x.name for name in names]): _collection.append(x) return _collection def get_bottleneck(scope=None): return get_collection(__middles__, scope, names=None)[-1] def get_middles(scope=None, names=None): return get_collection(__middles__, scope, names) def get_outputs(scope=None, names=None): return get_collection(__outputs__, scope, names) def get_weights(scope=None, names=None): return get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope, names) def load(model, weights_path, sess): if sess is None: sess = tf.get_default_session() assert sess is not None, 'The default session should be given.' values = parse_weights(weights_path) sess.run(pretrained_initializer(model, values)) def save(model, weights_path, sess): if sess is None: sess = tf.get_default_session() assert sess is not None, 'The default session should be given.' weights = get_weights(model) names = [w.name for w in weights] values = sess.run(weights) np.savez(weights_path, names=names, values=values) def pad_info(s, symmetry=True): pads = [[0, 0], [s // 2, s // 2], [s // 2, s // 2], [0, 0]] if not symmetry: pads[1][0] -= 1 pads[2][0] -= 1 return pads def crop_idx(total_size, crop_size, crop_loc, crop_grid): if isinstance(total_size, int): total_size = (total_size, total_size) if isinstance(crop_size, int): crop_size = (crop_size, crop_size) if crop_loc > -1: row_loc = crop_loc // crop_grid[0] col_loc = crop_loc % crop_grid[1] row_start = row_loc * (total_size[0] - crop_size[0]) // 2 col_start = col_loc * (total_size[1] - crop_size[1]) // 2 else: row_start = np.random.randint(0, total_size[0] - crop_size[0], 1)[0] col_start = np.random.randint(0, total_size[1] - crop_size[1], 1)[0] return row_start, col_start def crop(img, crop_size, crop_loc=4, crop_grid=(3, 3)): if isinstance(crop_loc, list): imgs = np.zeros((img.shape[0], len(crop_loc), crop_size, crop_size, 3), np.float32) for (i, loc) in enumerate(crop_loc): r, c = crop_idx(img.shape[1:3], crop_size, loc, crop_grid) imgs[:, i] = img[:, r:r+crop_size, c:c+crop_size, :] return imgs elif crop_loc == np.prod(crop_grid) + 1: imgs = np.zeros((img.shape[0], crop_loc, crop_size, crop_size, 3), np.float32) r, c = crop_idx(img.shape[1:3], crop_size, 4, crop_grid) imgs[:, 0] = img[:, r:r+crop_size, c:c+crop_size, :] imgs[:, 1] = img[:, 0:crop_size, 0:crop_size, :] imgs[:, 2] = img[:, 0:crop_size, -crop_size:, :] imgs[:, 3] = img[:, -crop_size:, 0:crop_size, :] imgs[:, 4] = img[:, -crop_size:, -crop_size:, :] imgs[:, 5:] = np.flip(imgs[:, :5], axis=3) return imgs else: r, c = crop_idx(img.shape[1:3], crop_size, crop_loc, crop_grid) return img[:, r:r+crop_size, c:c+crop_size, :] def load_img(paths, grayscale=False, target_size=None, crop_size=None, interp=None): assert cv2 is not None, '`load_img` requires `cv2`.' if interp is None: interp = cv2.INTER_CUBIC if not isinstance(paths, list): paths = [paths] if len(paths) > 1 and (target_size is None or isinstance(target_size, int)): raise ValueError('A tuple `target_size` should be provided ' 'when loading multiple images.') def _load_img(path): img = cv2.imread(path) if target_size: if isinstance(target_size, int): hw_tuple = tuple([x * target_size // min(img.shape[:2]) for x in img.shape[1::-1]]) else: hw_tuple = (target_size[1], target_size[0]) if img.shape[1::-1] != hw_tuple: img = cv2.resize(img, hw_tuple, interpolation=interp) img = img[:, :, ::-1] if len(img.shape) == 2: img = np.expand_dims(img, -1) return img if len(paths) > 1: imgs = np.zeros((len(paths),) + target_size + (3,), dtype=np.float32) for (i, path) in enumerate(paths): imgs[i] = _load_img(path) else: imgs = np.array([_load_img(paths[0])], dtype=np.float32) if crop_size is not None: imgs = crop(imgs, crop_size) return imgs def init(scopes, sess): if sess is None: sess = tf.get_default_session() assert sess is not None, 'The default session should be given.' if not isinstance(scopes, list): scopes = [scopes] for scope in scopes: sess.run(tf.variables_initializer(get_weights(scope))) def var_scope(name): def decorator(func): def wrapper(*args, **kwargs): stem = kwargs.get('stem', False) scope = kwargs.get('scope', name) reuse = kwargs.get('reuse', None) with tf.variable_scope(scope, reuse=reuse): x = func(*args, **kwargs) if func.__name__ == 'wrapper': from .middles import direct as p0 from .preprocess import direct as p1 from .pretrained import direct as p2 _scope = tf.get_variable_scope().name if tf_later_than('1.2'): _name = tf.get_default_graph().get_name_scope() else: # Note that `get_middles` and `get_outputs` # may NOT work well for TensorFlow == 1.1.0. _name = _scope if tf_later_than('2'): _input_shape = tuple(args[0].shape[1:3]) else: _input_shape = tuple([i.value for i in args[0].shape[1:3]]) _outs = get_outputs(_name) for i in p0(name)[0]: collect_named_outputs(__middles__, _scope, _outs[i]) if stem: x.aliases.insert(0, _scope) x.p = get_middles(_name)[p0(name)[2]] else: x.logits = get_outputs(_name)[-2] setattr(x, 'preprocess', p1(name, _input_shape)) setattr(x, 'pretrained', p2(name, x)) setattr(x, 'get_bottleneck', lambda: get_bottleneck(_scope)) setattr(x, 'get_middles', lambda names=None: get_middles(_name, names)) setattr(x, 'get_outputs', lambda names=None: get_outputs(_name, names)) setattr(x, 'get_weights', lambda names=None: get_weights(_scope, names)) setattr(x, 'middles', lambda names=None: get_middles(_name, names)) setattr(x, 'outputs', lambda names=None: get_outputs(_name, names)) setattr(x, 'weights', lambda names=None: get_weights(_scope, names)) setattr(x, 'summary', lambda: print_summary(_scope)) setattr(x, 'print_middles', lambda: print_middles(_name)) setattr(x, 'print_outputs', lambda: print_outputs(_name)) setattr(x, 'print_weights', lambda: print_weights(_scope)) setattr(x, 'print_summary', lambda: print_summary(_scope)) setattr(x, 'init', lambda sess=None: init(_scope, sess)) setattr(x, 'load', lambda weights_path, sess=None: load(x, weights_path, sess)) setattr(x, 'save', lambda weights_path, sess=None: save(x, weights_path, sess)) return x return wrapper return decorator def ops_to_outputs(func): def wrapper(*args, **kwargs): x = func(*args, **kwargs) x = collect_named_outputs(__outputs__, tf.get_variable_scope().name, x) return x return wrapper @contextmanager def arg_scopes(l): for x in l: x.__enter__() yield def set_args(largs, conv_bias=True, weights_regularizer=None): from .layers import conv2d from .layers import fc from .layers import sconv2d def real_set_args(func): def wrapper(*args, **kwargs): is_training = kwargs.get('is_training', False) layers = sum([x for (x, y) in largs(is_training)], []) layers_args = [arg_scope(x, **y) for (x, y) in largs(is_training)] if not conv_bias: layers_args += [arg_scope([conv2d], biases_initializer=None)] if weights_regularizer is not None: layers_args += [arg_scope( [conv2d, fc, sconv2d], weights_regularizer=weights_regularizer)] with arg_scope(layers, outputs_collections=__outputs__): with arg_scopes(layers_args): x = func(*args, **kwargs) x.model_name = func.__name__ return x return wrapper return real_set_args def pretrained_initializer(scope, values): weights = get_weights(scope) if values is None: return tf.variables_initializer(weights) if len(weights) > len(values): # excluding weights in Optimizer weights = weights[:len(values)] if len(weights) != len(values): values = values[:len(weights)] warnings.warn('The sizes of symbolic and actual weights do not match. ' 'Never mind if you are trying to load stem layers only.') if scope.dtype == tf.float16: ops = [weights[0].assign(np.asarray(values[0], dtype=np.float16))] for (w, v) in zip(weights[1:-2], values[1:-2]): w.load(np.asarray(v, dtype=np.float16)) if weights[-1].shape != values[-1].shape: ops += [w.initializer for w in weights[-2:]] else: for (w, v) in zip(weights[-2:], values[-2:]): w.load(np.asarray(v, dtype=np.float16)) return ops ops = [w.assign(v) for (w, v) in zip(weights[:-2], values[:-2])] if weights[-1].shape != values[-1].shape: # for transfer learning ops += [w.initializer for w in weights[-2:]] else: # The logits layer can be either 1x1 conv or fc. In other words, # the weight shape is (1, 1, features, classes) for the former, # or (features, classes) the latter. if weights[-2].shape != values[-2].shape: values[-2] = values[-2].reshape(weights[-2].shape) warnings.warn('The weight has been reshaped because 1x1 conv and ' 'fc layers are interchangeable for a logits layer. ' 'But, the conversion may affect the precision.') ops += [w.assign(v) for (w, v) in zip(weights[-2:], values[-2:])] return ops def parse_weights(weights_path, move_rules=None): data = np.load(weights_path, encoding='bytes', allow_pickle=True) if isinstance(data, np.lib.npyio.NpzFile) or isinstance(data, dict): values = data['values'] if tf_later_than('1.4'): for (i, name) in enumerate(data['names']): if '/beta' in str(data['names'][i-1]) and '/gamma' in str(name): values[i], values[i-1] = values[i-1], values[i] else: values = data return values def parse_keras_weights(weights_path, move_rules=None): try: import h5py except ImportError: h5py = None assert h5py is not None, '`get_values_from_keras_file` requires `h5py`.' values = [] with h5py.File(weights_path, mode='r') as f: names = [n.decode('utf8') for n in f.attrs['layer_names'] if len(f[n.decode('utf8')].attrs['weight_names']) > 0] if move_rules is not None: if isinstance(move_rules, list): for (name, loc) in move_rules: idx = names.index(name) names.insert(idx + loc, names.pop(idx)) elif move_rules == 'ordered': bn_names, conv_names, other_names = [], [], [] for n in names: if 'batch' in n: bn_names.append(n) elif 'conv' in n: conv_names.append(n) else: other_names.append(n) names = [] for n in range(1, len(conv_names) + 1): names.append("conv2d_%d" % n) names.append("batch_normalization_%d" % n) names += other_names for name in names: g = f[name] w = [n.decode('utf8') for n in g.attrs['weight_names']] v = [np.asarray(g[n]) for n in w] if not tf_later_than('1.4'): if len(v) == 4: w[0], w[1] = w[1], w[0] v[0], v[1] = v[1], v[0] values += v return values def parse_torch_weights(weights_path, move_rules=None): try: import torch import torch.nn as nn import torch.nn.functional as F except ImportError: torch = None assert torch is not None, '`get_values_from_torch_file` requires `torch`.' model = torch.load(weights_path) names = list(model.keys()) if move_rules is not None: if isinstance(move_rules, list): for (name, loc) in move_rules: idx = names.index(name) names.insert(idx + loc, names.pop(idx)) if not tf_later_than('1.4'): for (i, name) in enumerate(names): if 'running_mean' in str(name): names[i-1], names[i-2] = names[i-2], names[i-1] values = [] for name in names: val = model[name].numpy() if val.ndim == 4: val = np.transpose(val, [2, 3, 1, 0]) if val.ndim == 2: val = np.transpose(val, [1, 0]) if val.ndim == 4: groups = val.shape[3] // val.shape[2] if (groups == 32) or (groups == 64): values += np.split(val, groups, axis=3) else: values.append(val) else: values.append(val) return values def remove_head(original_stem, name): _scope = "%s/stem" % tf.get_variable_scope().name g = tf.get_default_graph() for x in g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=_scope + '/')[::-1]: if name in x.name: break g.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES).pop() for x in g.get_collection(__outputs__, scope=_scope + '/')[::-1]: if name in x.name: break g.get_collection_ref(__outputs__).pop() x.model_name = original_stem.model_name return x def remove_utils(module_name, exceptions): import sys from . import utils module = sys.modules[module_name] for util in dir(utils): if not ((util.startswith('_')) or (util in exceptions)): try: delattr(module, util) except: None def remove_commons(module_name, exceptions=[]): import sys _commons = [ 'absolute_import', 'division' 'print_function', 'remove_commons', ] module = sys.modules[module_name] for _common in _commons: if _common not in exceptions: try: delattr(module, _common) except: None remove_commons(__name__, ['remove_commons'])