"""Style-swap model.

Builds the style-swap layer on top of a pretrained loss network and trains an
inverse (decoder) network that maps the swapped features back to image space.
"""
import tensorflow as tf

from base.base_model import BaseModel
from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim


class StyleSwapModel(BaseModel):

    def __init__(self, config, data_loader):
        super(StyleSwapModel, self).__init__(config, data_loader)
        # Endpoint of the loss network whose activations are patch-swapped.
        self.style_layer = "/conv3/conv3_3"
        self.records_loader = data_loader[0]
        self.style_loader = data_loader[1]
        self.PREPROCESS_SIZE = 256
        # Spatial size of the style patches used as convolution filters.
        self.cell_size = 3

    def _build_train_model(self):
        """Builds the swap graph and the training ops for the inverse network."""
        preprocess_fn = preprocessing_factory.get_preprocessing(self.config.net_name, is_training=False)

        [image] = self.records_loader.get_data()
        preprocessed_image = preprocess_fn(image, self.PREPROCESS_SIZE, self.PREPROCESS_SIZE)
        images = self.records_loader.batch_data(preprocessed_image)

        style_image = self.style_loader.get_data()
        preprocessed_style_image = preprocess_fn(style_image, self.PREPROCESS_SIZE, self.PREPROCESS_SIZE)
        style_images = self.style_loader.batch_data(preprocessed_style_image)

        self.swaped_tensor = self._swap_net(images, style_images)
        self.generated = self._inverse_net(self.swaped_tensor)

        slim.summary.image("generated", self.generated)
        slim.summary.image("origin", images)
        slim.summary.image("style", style_images)

        self._train_inverse(self.generated, self.swaped_tensor)
        self.init_op = self._get_network_init_fn()

    def _build_evaluate_model(self):
        """Builds the inference graph for a single content/style image pair."""
        self.input_image = tf.placeholder(tf.float32, shape=[None, None, 3])
        self.style_image = tf.placeholder(tf.float32, shape=[None, None, 3])
        preprocess_fn = preprocessing_factory.get_preprocessing(self.config.net_name, is_training=False)

        # evaluate_height / evaluate_width are expected to be set by the caller
        # before this method is invoked; otherwise the default size is used.
        height = self.evaluate_height if self.evaluate_height else self.PREPROCESS_SIZE
        width = self.evaluate_width if self.evaluate_width else self.PREPROCESS_SIZE

        preprocessed_image = preprocess_fn(self.input_image, height, width, resize_side_min=min(height, width))
        images = tf.expand_dims(preprocessed_image, axis=0)
        style_images = tf.expand_dims(
            preprocess_fn(self.style_image, self.PREPROCESS_SIZE, self.PREPROCESS_SIZE), axis=0)

        self.swaped_tensor = self._swap_net(images, style_images)
        # Alternative: feed the content feature directly to the inverse network,
        # without swapping.
        # network_fn = nets_factory.get_network_fn(self.config.net_name, num_classes=1, is_training=False)
        # _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        # self.swaped_tensor = endpoints_dict[self.config.net_name + self.style_layer]

        self.generated = self._inverse_net(self.swaped_tensor)
        self.evaluate_op = tf.squeeze(self.generated, axis=0)
        self.init_op = self._get_network_init_fn()
        self.save_variables = [var for var in tf.trainable_variables() if var.name.startswith("inverse_net")]

    def _swap_net(self, content, style):
        """Extracts style patches as filters and swaps them into the content feature."""
        network_fn = nets_factory.get_network_fn(self.config.net_name, num_classes=1, is_training=False)
        style_amount = style.get_shape()[0].value

        # Content and style features come from two passes of the loss network;
        # the second pass reuses the variables of the first.
        _, endpoints_dict = network_fn(content, spatial_squeeze=False)
        content_feature = endpoints_dict[self.config.net_name + self.style_layer]
        with tf.variable_scope("", reuse=True):
            _, endpoints_dict = network_fn(style, spatial_squeeze=False)
            layer_names = list(endpoints_dict.keys())
            [layer_name] = [l_name for l_name in layer_names if self.style_layer in l_name]
            style_feature = endpoints_dict[layer_name]
        # Alternative: run a single pass on the concatenated batch and split it.
        # content_amount = content.get_shape()[0].value
        # images = tf.concat([content, style], axis=0)
        # content_feature, style_feature = tf.split(style_layer,
        #                                           num_or_size_splits=[content_amount, style_amount], axis=0)

        # Cut the style feature map into cell_size x cell_size patches; the trailing
        # remainder row/column is dropped.
        rows = tf.split(style_feature, num_or_size_splits=list(
            [self.cell_size] * (style_feature.get_shape()[1].value // self.cell_size) +
            [style_feature.get_shape()[1].value % self.cell_size]), axis=1)[:-1]
        cells = [tf.split(row, num_or_size_splits=list(
            [self.cell_size] * (style_feature.get_shape()[2].value // self.cell_size) +
            [style_feature.get_shape()[2].value % self.cell_size]), axis=2)[:-1] for row in rows]
        # Stack the patches into a filter bank of shape
        # [style_amount, cell_size, cell_size, channels, num_patches].
        stacked_cells = [tf.stack(row_cell, axis=4) for row_cell in cells]
        filters = tf.concat(stacked_cells, axis=-1)

        swaped_list = []
        for style_filter in tf.unstack(filters, axis=0, num=style_amount):
            swaped_list.append(self._swap_op(content_feature, style_filter))
        return tf.concat(swaped_list, axis=0)

    def _train_inverse(self, generated, swaped_tensor):
        """Adds the reconstruction loss and the training op for the inverse network."""
        preprocess_fn = preprocessing_factory.get_preprocessing(self.config.net_name, is_training=False)
        network_fn = nets_factory.get_network_fn(self.config.net_name, num_classes=1, is_training=False)

        with tf.variable_scope("", reuse=True):
            preprocessed_image = tf.stack([preprocess_fn(img, self.PREPROCESS_SIZE, self.PREPROCESS_SIZE)
                                           for img in tf.unstack(generated, axis=0)])
            _, inversed_endpoints_dict = network_fn(preprocessed_image, spatial_squeeze=False)
            layer_names = list(inversed_endpoints_dict.keys())
            [layer_name] = [l_name for l_name in layer_names if self.style_layer in l_name]
            inversed_style_layer = inversed_endpoints_dict[layer_name]

        # The inverse network is trained so that re-encoding its output reproduces
        # the swapped feature map.
        tf.losses.add_loss(tf.nn.l2_loss(swaped_tensor - inversed_style_layer))
        self.loss_op = tf.losses.get_total_loss()

        train_vars = [var for var in tf.trainable_variables() if var.name.startswith("inverse_net")]
        slim.summarize_tensor(self.loss_op, "loss")
        slim.summarize_tensors(train_vars)
        self.save_variables = train_vars

        learning_rate = tf.train.exponential_decay(self.config.learning_rate, self.global_step, 1000, 0.66,
                                                   name="learning_rate")
        self.train_op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss_op, self.global_step, train_vars)

    def _swap_op(self, content_feature, style_feature):
        """Replaces every content patch with its best-matching style patch."""
        # Normalized cross-correlation between content patches and style patches;
        # change the strides to see the difference.
        normalized_filters = tf.nn.l2_normalize(style_feature, dim=(0, 1, 2))
        similarity = tf.nn.conv2d(content_feature, normalized_filters, strides=[1, 1, 1, 1], padding="VALID")
        # Select the best-matching style patch at every location ...
        arg_max_filter = tf.argmax(similarity, axis=-1)
        one_hot_filter = tf.one_hot(arg_max_filter, depth=similarity.get_shape()[-1].value)
        # ... and paste the selected patches back with a transposed convolution.
        swap = tf.nn.conv2d_transpose(one_hot_filter, style_feature, output_shape=tf.shape(content_feature),
                                      strides=[1, 1, 1, 1], padding="VALID")
        # Each interior pixel is covered by cell_size ** 2 overlapping patches.
        return swap / float(self.cell_size ** 2)

    def _inverse_net(self, x):
        """Decoder that maps the swapped feature map back to image space."""
        regularizer = slim.l2_regularizer(self.config.weight_regulation_scale)
        with tf.variable_scope("inverse_net"):
            with tf.variable_scope("conv1"):
                x = slim.conv2d(x, num_outputs=256, kernel_size=3, stride=1, padding="SAME",
                                weights_regularizer=regularizer)
            with tf.variable_scope("residual1"):
                res = slim.repeat(x, 2, slim.conv2d, num_outputs=256, kernel_size=3, stride=1,
                                  weights_regularizer=regularizer)
                x = res + x
            with tf.variable_scope("residual2"):
                res = slim.repeat(x, 2, slim.conv2d, num_outputs=256, kernel_size=5, stride=1,
                                  weights_regularizer=regularizer)
                x = res + x
            # Model variant 2 only uses two residual blocks.
            with tf.variable_scope("residual3"):
                res = slim.repeat(x, 2, slim.conv2d, num_outputs=256, kernel_size=7, stride=1,
                                  weights_regularizer=regularizer)
                x = res + x
            with tf.variable_scope("deconv1"):
                x = self._deconv(x, num_outputs=128, kernel_size=5, stride=2, activation_fn=tf.nn.relu,
                                 weights_regularizer=regularizer)
            with tf.variable_scope("deconv2"):
                x = self._deconv(x, num_outputs=64, kernel_size=5, stride=2, activation_fn=tf.nn.relu,
                                 weights_regularizer=regularizer)
            with tf.variable_scope("conv2"):
                x = slim.conv2d(x, num_outputs=3, kernel_size=3, stride=1, padding="SAME", activation_fn=tf.nn.tanh,
                                weights_regularizer=regularizer)
                # Map the tanh output from [-1, 1] to [0, 255].
                x = (x + 1) * (255.0 / 2)
        return x

    def _instance_norm(self, x):
        """Instance normalization over the spatial dimensions."""
        epsilon = 1e-9
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))

    def _deconv(self, x, num_outputs, kernel_size, stride, activation_fn, weights_regularizer):
        """Resize-then-convolution upsampling followed by instance norm."""
        height = tf.shape(x)[1]
        width = tf.shape(x)[2]
        # Resize by 2 * stride, then convolve with the given stride,
        # for a net upsampling factor of 2.
        new_height = height * 2 * stride
        new_width = width * 2 * stride
        x = tf.image.resize_images(x, [new_height, new_width], method=tf.image.ResizeMethod.BILINEAR)
        x = slim.conv2d(x, num_outputs=num_outputs, kernel_size=kernel_size, stride=stride, padding="SAME",
                        activation_fn=None, weights_regularizer=weights_regularizer)
        x = self._instance_norm(x)
        x = activation_fn(x)
        return x

    def _get_network_init_fn(self):
        """Returns a function that restores the pretrained loss network weights."""
        tf.logging.info("Use pretrained model {}".format(self.config.net_name))
        exclusions = []
        if self.config.checkpoint_exclude_scopes:
            exclusions = [scope.strip() for scope in self.config.checkpoint_exclude_scopes.split(",")]
        variables_to_restore = []
        for var in slim.get_model_variables():
            excluded = False
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)
        return slim.assign_from_checkpoint_fn(
            self.config.loss_model_file,
            variables_to_restore,
            ignore_missing_vars=True
        )
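

# The block below is not part of the original module; it is a minimal, standalone
# sketch of the patch-swap idea implemented by _swap_op, run on toy tensors so it
# executes without the surrounding config, data loaders, or checkpoints.
# All shapes and values here are illustrative assumptions.
if __name__ == "__main__":
    import numpy as np

    cell = 3
    # Toy "features": a 1x8x8x4 content map and an 8x8x4 style map.
    content = tf.constant(np.random.rand(1, 8, 8, 4).astype(np.float32))
    style = tf.constant(np.random.rand(8, 8, 4).astype(np.float32))

    # Cut the style map into cell x cell patches and stack them into a filter
    # bank of shape [cell, cell, channels, num_patches], as _swap_net does.
    patches = [style[i:i + cell, j:j + cell, :]
               for i in range(0, 8 - cell + 1, cell)
               for j in range(0, 8 - cell + 1, cell)]
    filters = tf.stack(patches, axis=-1)

    # Normalized cross-correlation picks the best-matching patch per location ...
    normalized = tf.nn.l2_normalize(filters, dim=(0, 1, 2))
    similarity = tf.nn.conv2d(content, normalized, strides=[1, 1, 1, 1], padding="VALID")
    one_hot = tf.one_hot(tf.argmax(similarity, axis=-1), depth=len(patches))

    # ... and a transposed convolution pastes the selected patches back,
    # averaging the overlaps.
    swapped = tf.nn.conv2d_transpose(one_hot, filters, output_shape=tf.shape(content),
                                     strides=[1, 1, 1, 1], padding="VALID") / float(cell * cell)

    with tf.Session() as sess:
        print(sess.run(swapped).shape)  # (1, 8, 8, 4)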