# coding=utf-8
# Copyright 2018 The THUMT Authors

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def _get_loss_variable(graph=None):
    graph = graph or tf.get_default_graph()
    loss_tensors = tf.get_collection("loss")

    if len(loss_tensors) == 1:
        loss_tensor = loss_tensors[0]
    elif not loss_tensors:
        try:
            # Fall back to the tensor created by _create_loss_variable.
            loss_tensor = graph.get_tensor_by_name("loss:0")
        except KeyError:
            return None
    else:
        tf.logging.error("Multiple tensors in loss collection.")
        return None

    return loss_tensor


def _create_loss_variable(graph=None):
    graph = graph or tf.get_default_graph()

    if _get_loss_variable(graph) is not None:
        raise ValueError("'loss' already exists.")

    # Create in proper graph and base name_scope.
    with graph.as_default() as g, g.name_scope(None):
        tensor = tf.get_variable("loss", shape=[], dtype=tf.float32,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False,
                                 collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                                              "loss"])

    return tensor


def _get_or_create_loss_variable(graph=None):
    graph = graph or tf.get_default_graph()
    loss_tensor = _get_loss_variable(graph)

    if loss_tensor is None:
        loss_tensor = _create_loss_variable(graph)

    return loss_tensor


def _zero_variables(variables, name=None):
    ops = []

    for var in variables:
        with tf.device(var.device):
            op = var.assign(tf.zeros(var.shape.as_list()))
        ops.append(op)

    return tf.group(*ops, name=name or "zero_variables")


def _replicate_variables(variables, device=None):
    new_vars = []

    for var in variables:
        # Place each replica on the given device, or next to its source
        # variable when no device is specified.
        var_device = device or var.device
        with tf.device(var_device):
            name = var.name.split(":")[0].rstrip("/") + "/replica"
            new_vars.append(tf.Variable(tf.zeros(var.shape.as_list()),
                                        name=name, trainable=False))

    return new_vars


def _collect_gradients(gradients, variables):
    ops = []

    for grad, var in zip(gradients, variables):
        if isinstance(grad, tf.Tensor):
            ops.append(tf.assign_add(var, grad))
        else:
            # Sparse gradients (tf.IndexedSlices) are accumulated row-wise.
            ops.append(tf.scatter_add(var, grad.indices, grad.values))

    return tf.group(*ops, name="collect_gradients")


def _scale_variables(variables, scale):
    if not isinstance(variables, (list, tuple)):
        return tf.assign(variables, scale * variables)

    ops = []

    for var in variables:
        ops.append(tf.assign(var, scale * var))

    return tf.group(*ops, name="scale_variables")


def create_train_op(loss, optimizer, global_step, params):
    with tf.name_scope("create_train_op"):
        grads_and_vars = optimizer.compute_gradients(
            loss, colocate_gradients_with_ops=True)
        gradients = [item[0] for item in grads_and_vars]
        variables = [item[1] for item in grads_and_vars]

        if params.update_cycle == 1:
            zero_variables_op = tf.no_op("zero_variables")
            collect_op = tf.no_op("collect_op")
            scale_op = tf.no_op("scale_op")
        else:
            # Gradient accumulation: collect gradients and loss into
            # non-trainable replica variables over `update_cycle` steps.
            loss_tensor = _get_or_create_loss_variable()
            slot_variables = _replicate_variables(variables)
            zero_variables_op = _zero_variables(slot_variables +
                                                [loss_tensor])

            collect_grads_op = _collect_gradients(gradients, slot_variables)
            collect_loss_op = tf.assign_add(loss_tensor, loss)
            collect_op = tf.group(collect_loss_op, collect_grads_op,
                                  name="collect_op")

            # Average the accumulated gradients and loss.
            scale = 1.0 / params.update_cycle
            scale_grads_op = _scale_variables(slot_variables, scale)
            scale_loss_op = _scale_variables(loss_tensor, scale)
            scale_op = tf.group(scale_grads_op, scale_loss_op,
                                name="scale_op")
            gradients = slot_variables
            loss = tf.convert_to_tensor(loss_tensor)

        # Add summaries
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("global_norm/gradient_norm",
                          tf.global_norm(gradients))

        for gradient, variable in zip(gradients, variables):
            if isinstance(gradient, tf.IndexedSlices):
                grad_values = gradient.values
            else:
                grad_values = gradient

            if grad_values is not None:
                var_name = variable.name.replace(":", "_")
                tf.summary.histogram("gradients/%s" % var_name, grad_values)
                tf.summary.scalar("gradient_norm/%s" % var_name,
                                  tf.global_norm([grad_values]))

        # Gradient clipping
        if isinstance(params.clip_grad_norm or None, float):
            gradients, _ = tf.clip_by_global_norm(gradients,
                                                  params.clip_grad_norm)

        # Update variables
        grads_and_vars = list(zip(gradients, tf.trainable_variables()))
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)

        ops = {
            "zero_op": zero_variables_op,
            "collect_op": collect_op,
            "scale_op": scale_op,
            "train_op": train_op
        }

        return loss, ops
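

# Illustrative usage sketch (an assumption, not part of the original module):
# when params.update_cycle > 1, the returned ops are intended to be run in
# the order zero_op -> collect_op (update_cycle times) -> scale_op ->
# train_op. `sess` and `params` below are hypothetical placeholders.
#
#     loss_var, ops = create_train_op(loss, optimizer, global_step, params)
#
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         sess.run(ops["zero_op"])            # reset gradient/loss accumulators
#         for _ in range(params.update_cycle):
#             sess.run(ops["collect_op"])     # accumulate one mini-batch
#         sess.run(ops["scale_op"])           # average the accumulated values
#         sess.run(ops["train_op"])           # apply the averaged update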