Python tensorflow.compat.v1.tensordot() Examples

The following are 18 code examples of tensorflow.compat.v1.tensordot(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
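
Before the project examples, here is a minimal sketch of the call's axes convention. This illustration is added here and is not taken from any of the projects below; it assumes TensorFlow 2.x, where tensorflow.compat.v1 runs eagerly by default.

import numpy as np
import tensorflow.compat.v1 as tf

a = tf.constant(np.arange(6.0).reshape(2, 3))   # shape (2, 3)
b = tf.constant(np.arange(12.0).reshape(3, 4))  # shape (3, 4)

# axes=1 contracts the last axis of `a` with the first axis of `b`,
# which for 2-D inputs is an ordinary matrix product of shape (2, 4).
print(tf.tensordot(a, b, axes=1).shape)

# The explicit list form names the axes to contract on each side.
print(tf.tensordot(a, b, axes=[[1], [0]]).shape)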
Example #1
Source File: ops.py    From language with Apache License 2.0
def last_dim_weighted_sum(x, weight_name, weight_init=None, keepdims=False):
  """Computes a weighted sum of the last dimension of `x`.

  Args:
    x: <float32>[..., x_dim]
    weight_name: Name of the weight variable to use
    weight_init: Initializer of the weight variable
    keepdims: Whether the output should have a trailing size-one dim

  Returns:
    summed: <float32>[...], or <float32>[..., 1] if `keepdims` is True
  """

  dim = x.shape.as_list()[-1]
  w = tf.get_variable(weight_name, dim, initializer=weight_init)
  out = tf.tensordot(x, w, [[len(x.shape) - 1], [0]])
  if keepdims:
    return tf.expand_dims(out, len(out.shape))
  else:
    return out 
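
A quick way to see what the `[[len(x.shape) - 1], [0]]` contraction does: numpy's tensordot follows the same axes convention, so a sketch with hypothetical shapes (not part of the project) can check that only the trailing dimension is summed out.

import numpy as np

x = np.random.rand(2, 5, 7)   # <float32>[..., x_dim] with x_dim = 7
w = np.random.rand(7)         # the learned weight vector

# Contract the last axis of x with the only axis of w; leading dims survive.
out = np.tensordot(x, w, axes=[[x.ndim - 1], [0]])
assert out.shape == (2, 5)
assert np.allclose(out, (x * w).sum(axis=-1))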
Example #2
Source File: common_layers.py    From tensor2tensor with Apache License 2.0
def cumsum(x, axis=0, exclusive=False):
  """TPU hack for tf.cumsum.

  This is equivalent to tf.cumsum and is faster on TPU as of 04/2018 unless
  the axis dimension is very large.

  Args:
    x: a Tensor
    axis: an integer
    exclusive: a boolean

  Returns:
    Tensor of the same shape as x.
  """
  if not is_xla_compiled():
    return tf.cumsum(x, axis=axis, exclusive=exclusive)
  x_shape = shape_list(x)
  rank = len(x_shape)
  length = x_shape[axis]
  my_range = tf.range(length)
  comparator = tf.less if exclusive else tf.less_equal
  mask = tf.cast(
      comparator(tf.expand_dims(my_range, 1), tf.expand_dims(my_range, 0)),
      x.dtype)
  ret = tf.tensordot(x, mask, axes=[[axis], [0]])
  if axis != rank - 1:
    ret = tf.transpose(
        ret,
        list(range(axis)) + [rank - 1] + list(range(axis, rank - 1)))
  return ret 
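
The trick here is that an inclusive cumulative sum along an axis is a contraction with an upper-triangular 0/1 mask. A numpy sketch of the same identity, using hypothetical sizes and the last axis so the transpose branch is skipped:

import numpy as np

x = np.random.rand(3, 4)             # cumsum along axis=1
r = np.arange(x.shape[1])

# mask[i, j] = 1 where i <= j; contracting x's axis with the mask's first
# axis sums every position up to and including j.
mask = (r[:, None] <= r[None, :]).astype(x.dtype)
ret = np.tensordot(x, mask, axes=[[1], [0]])
assert np.allclose(ret, np.cumsum(x, axis=1))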
Example #3
Source File: specgrams_helper.py    From magenta with Apache License 2.0
def specgrams_to_melspecgrams(self, specgrams):
    """Converts specgrams to melspecgrams.

    Args:
      specgrams: Tensor of log magnitudes and instantaneous frequencies,
        shape [batch, time, freq, 2].

    Returns:
      melspecgrams: Tensor of log magnitudes and instantaneous frequencies,
        shape [batch, time, freq, 2], mel scaling of frequencies.
    """
    if self._mel_downscale is None:
      return specgrams

    logmag = specgrams[:, :, :, 0]
    p = specgrams[:, :, :, 1]

    mag2 = tf.exp(2.0 * logmag)
    phase_angle = tf.cumsum(p * np.pi, axis=-2)

    l2mel = tf.to_float(self._linear_to_mel_matrix())
    logmelmag2 = self._safe_log(tf.tensordot(mag2, l2mel, 1))
    mel_phase_angle = tf.tensordot(phase_angle, l2mel, 1)
    mel_p = spectral_ops.instantaneous_frequency(mel_phase_angle)

    return tf.concat(
        [logmelmag2[:, :, :, tf.newaxis], mel_p[:, :, :, tf.newaxis]], axis=-1) 
Example #4
Source File: specgrams_helper.py    From magenta with Apache License 2.0
def melspecgrams_to_specgrams(self, melspecgrams):
    """Converts melspecgrams to specgrams.

    Args:
      melspecgrams: Tensor of log magnitudes and instantaneous frequencies,
        shape [batch, time, freq, 2], mel scaling of frequencies.

    Returns:
      specgrams: Tensor of log magnitudes and instantaneous frequencies,
        shape [batch, time, freq, 2].
    """
    if self._mel_downscale is None:
      return melspecgrams

    logmelmag2 = melspecgrams[:, :, :, 0]
    mel_p = melspecgrams[:, :, :, 1]

    mel2l = tf.to_float(self._mel_to_linear_matrix())
    mag2 = tf.tensordot(tf.exp(logmelmag2), mel2l, 1)
    logmag = 0.5 * self._safe_log(mag2)
    mel_phase_angle = tf.cumsum(mel_p * np.pi, axis=-2)
    phase_angle = tf.tensordot(mel_phase_angle, mel2l, 1)
    p = spectral_ops.instantaneous_frequency(phase_angle)

    return tf.concat(
        [logmag[:, :, :, tf.newaxis], p[:, :, :, tf.newaxis]], axis=-1) 
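
In both of these magenta examples, tf.tensordot(..., matrix, 1) applies a (freq x mel) or (mel x freq) matrix along the trailing frequency axis of a [batch, time, freq] tensor. A numpy sketch with hypothetical sizes (not the project's real mel matrices) shows the shape behaviour:

import numpy as np

mag2 = np.random.rand(2, 10, 8)      # [batch, time, freq]
l2mel = np.random.rand(8, 4)         # hypothetical linear-to-mel matrix

# axes=1 contracts mag2's last (freq) axis with l2mel's first axis,
# giving [batch, time, mel] -- the same as a matmul over the last dim.
melmag2 = np.tensordot(mag2, l2mel, 1)
assert melmag2.shape == (2, 10, 4)
assert np.allclose(melmag2, mag2 @ l2mel)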
Example #5
Source File: atari_helpers.py    From batch_rl with Apache License 2.0
def combine_q_functions(q_functions, transform_strategy, **kwargs):
  """Utility function for combining multiple Q functions.

  Args:
    q_functions: Multiple Q-functions concatenated.
    transform_strategy: str, Possible options include (1) 'IDENTITY' for no
      transformation (2) 'STOCHASTIC' for random convex combination.
    **kwargs: Arbitrary keyword arguments. Used for passing `transform_matrix`,
      the matrix for transforming the Q-values if the passed
      `transform_strategy` is `STOCHASTIC`.

  Returns:
    q_functions: Modified Q-functions.
    q_values: Q-values based on combining the multiple heads.
  """
  # Create q_values before reordering the heads for training
  q_values = tf.reduce_mean(q_functions, axis=-1)

  if transform_strategy == 'STOCHASTIC':
    left_stochastic_matrix = kwargs.get('transform_matrix')
    if left_stochastic_matrix is None:
      raise ValueError('None value provided for stochastic matrix')
    q_functions = tf.tensordot(
        q_functions, left_stochastic_matrix, axes=[[2], [0]])
  elif transform_strategy == 'IDENTITY':
    tf.logging.info('Identity transformation Q-function heads')
  else:
    raise ValueError(
        '{} is not a valid reordering strategy'.format(transform_strategy))
  return q_functions, q_values 
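
For the 'STOCHASTIC' strategy, the [batch, num_actions, num_heads] tensor of Q-values is contracted over the head axis with a left-stochastic mixing matrix, so every output head is a convex combination of the inputs. A numpy sketch with hypothetical sizes:

import numpy as np

q_functions = np.random.rand(4, 6, 3)        # [batch, num_actions, num_heads]
m = np.random.rand(3, 2)
left_stochastic_matrix = m / m.sum(axis=0)   # each column sums to 1

# Contract the head axis (2) with the matrix's first axis; each output head
# is a convex combination of the original heads.
mixed = np.tensordot(q_functions, left_stochastic_matrix, axes=[[2], [0]])
assert mixed.shape == (4, 6, 2)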
Example #6
Source File: multi_network_dqn_agent.py    From batch_rl with Apache License 2.0
def _build_networks(self):
    super(MultiNetworkDQNAgent, self)._build_networks()
    # q_argmax is only used for picking an action
    self._q_argmax_eval = tf.argmax(self._net_outputs.q_values, axis=1)[0]
    if self.use_deep_exploration:
      if self.transform_strategy.endswith('STOCHASTIC'):
        q_transform = atari_helpers.random_stochastic_matrix(
            self.num_networks, num_cols=1)
        self._q_episode_transform = tf.get_variable(
            trainable=False,
            dtype=tf.float32,
            shape=q_transform.get_shape().as_list(),
            name='q_episode_transform')
        self._update_episode_q_function = self._q_episode_transform.assign(
            q_transform)
        episode_q_function = tf.tensordot(
            self._net_outputs.unordered_q_networks,
            self._q_episode_transform, axes=[[2], [0]])
        self._q_argmax_train = tf.argmax(episode_q_function[:, :, 0], axis=1)[0]
      elif self.transform_strategy == 'IDENTITY':
        self._q_function_index = tf.Variable(
            initial_value=0,
            trainable=False,
            dtype=tf.int32,
            shape=(),
            name='q_head_episode')
        self._update_episode_q_function = self._q_function_index.assign(
            tf.random.uniform(
                shape=(), maxval=self.num_networks, dtype=tf.int32))
        q_function = self._net_outputs.unordered_q_networks[
            :, :, self._q_function_index]
        # This is only used for picking an action
        self._q_argmax_train = tf.argmax(q_function, axis=1)[0]
    else:
      self._q_argmax_train = self._q_argmax_eval 
Example #7
Source File: crown.py    From interval-bound-propagation with Apache License 2.0
def apply_linear(self, wrapper, w, b):
    """Propagate CROWN bounds backward through a linear layer."""
    def _linear_propagate(bound):
      """Propagate one side of the bound."""
      new_bound_w = tf.einsum('nsk,lk->nsl', bound.w, w)
      bias = 0
      if b is not None:
        bias = tf.tensordot(bound.w, b, axes=1)
      return fastlin.LinearExpression(w=new_bound_w, b=bias + bound.b,
                                      lower=wrapper.input_bounds.lower,
                                      upper=wrapper.input_bounds.upper)
    ub_expr = _linear_propagate(self.upper) if self.upper else None
    lb_expr = _linear_propagate(self.lower) if self.lower else None
    return BackwardBounds(lb_expr, ub_expr) 
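
Here bound.w has shape (batch_size, num_specs, k) and b has shape (k,), so tensordot with axes=1 contracts the trailing axis and yields one bias term per specification. A numpy shape sketch with hypothetical sizes:

import numpy as np

bound_w = np.random.rand(2, 3, 5)   # (batch_size, num_specs, k)
b = np.random.rand(5)               # layer bias

bias = np.tensordot(bound_w, b, axes=1)   # contracts the trailing axis
assert bias.shape == (2, 3)
assert np.allclose(bias, bound_w @ b)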
Example #8
Source File: crown.py    From interval-bound-propagation with Apache License 2.0
def apply_conv2d(self, wrapper, w, b, padding, strides):
    """Propagate CROWN bounds backward through a convolution layer."""
    def _conv2d_propagate(bound):
      """Propagate one side of the bound."""
      s = tf.shape(bound.w)
      # Variable bound.w has shape (batch_size, num_specs, H, W, C),
      # reshape it to (batch_size * num_specs, H, W, C) for batch processing.
      effective_batch_size = tf.reshape(s[0] * s[1], [1])
      batched_shape = tf.concat([effective_batch_size, s[2:]], 0)
      # The output of a deconvolution is the input shape of the corresponding
      # convolution.
      output_shape = wrapper.input_bounds.lower.shape
      batched_output_shape = tf.concat([effective_batch_size, output_shape[1:]],
                                       0)
      # Batched transpose convolution for efficiency.
      bound_batch = tf.nn.conv2d_transpose(tf.reshape(bound.w, batched_shape),
                                           filter=w,
                                           output_shape=batched_output_shape,
                                           strides=[1] + list(strides) + [1],
                                           padding=padding)
      # Reshape results to (batch_size, num_specs, new_H, new_W, new_C).
      new_shape = tf.concat(
          [tf.reshape(s[0], [1]), tf.reshape(s[1], [1]), output_shape[1:]], 0)
      new_bound_w = tf.reshape(bound_batch, new_shape)
      # If this convolution has a bias, contract it with the current w.
      bias = 0
      if b is not None:
        # Variable bound.w has dimension (batch_size, num_specs, H, W, C),
        # accumulate H and W, and do a dot product for each channel C.
        bias = tf.tensordot(tf.reduce_sum(bound.w, [2, 3]), b, axes=1)
      return fastlin.LinearExpression(w=new_bound_w, b=bias + bound.b,
                                      lower=wrapper.input_bounds.lower,
                                      upper=wrapper.input_bounds.upper)
    ub_expr = _conv2d_propagate(self.upper) if self.upper else None
    lb_expr = _conv2d_propagate(self.lower) if self.lower else None
    return BackwardBounds(lb_expr, ub_expr) 
Example #9
Source File: simplex_bounds.py    From interval-bound-propagation with Apache License 2.0
def apply_linear(self, wrapper, w, b):
    mapped_centres = tf.matmul(self.nominal, w)
    mapped_vertices = tf.tensordot(self.vertices, w, axes=1)

    lb, ub = _simplex_bounds(mapped_vertices, mapped_centres, self.r, -2)

    nominal_out = tf.matmul(self.nominal, w)
    if b is not None:
      nominal_out += b

    return relative_bounds.RelativeIntervalBounds(lb, ub, nominal_out) 
Example #10
Source File: fastlin.py    From interval-bound-propagation with Apache License 2.0
def _scale_expression(expr, w):
    """Scale a linear expression by w."""
    b = tf.matmul(expr.b, w)
    w = tf.tensordot(expr.w, w, axes=1)
    return LinearExpression(w=w, b=b, lower=expr.lower, upper=expr.upper) 
Example #11
Source File: test_case_test.py    From models with Apache License 2.0
def test_simple(self):
    def graph_fn(tensora, tensorb):
      return tf.tensordot(tensora, tensorb, axes=1)

    tensora_np = np.ones(20)
    tensorb_np = tensora_np * 2
    output = self.execute(graph_fn, [tensora_np, tensorb_np])
    self.assertAllClose(output, 40.0) 
Example #12
Source File: model_fns.py    From language with Apache License 2.0
def _get_bert_embeddings(model, layers_to_use, aggregation_fn, name="bert"):
  """Extract embeddings from BERT model."""
  all_hidden = model.get_all_encoder_layers()
  layers_hidden = [all_hidden[i] for i in layers_to_use]
  hidden_shapes = [
      modeling.get_shape_list(hid, expected_rank=3) for hid in all_hidden
  ]

  if len(layers_hidden) == 1:
    hidden_emb = layers_hidden[0]
    hidden_size = hidden_shapes[0][2]
  elif aggregation_fn == "concat":
    hidden_emb = tf.concat(layers_hidden, 2)
    hidden_size = sum([hidden_shapes[i][2] for i in layers_to_use])
  elif aggregation_fn == "average":
    hidden_size = hidden_shapes[0][2]
    assert all([shape[2] == hidden_size for shape in hidden_shapes
               ]), hidden_shapes
    hidden_emb = tf.add_n(layers_hidden) / len(layers_hidden)
  elif aggregation_fn == "attention":
    hidden_size = hidden_shapes[0][2]
    mixing_weights = tf.get_variable(
        name + "/mixing/weights", [len(layers_hidden)],
        initializer=tf.zeros_initializer())
    mixing_scores = tf.nn.softmax(mixing_weights)
    hidden_emb = tf.tensordot(
        tf.stack(layers_hidden, axis=-1), mixing_scores, [[-1], [0]])
  else:
    raise ValueError("Unrecognized aggregation function %s." % aggregation_fn)

  return hidden_emb, hidden_size 
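
The "attention" branch stacks the selected layers on a new trailing axis and contracts that axis with softmax-normalized mixing weights (the same helper appears again in Examples #13 and #14). A numpy sketch with hypothetical sizes; with zero-initialized weights the softmax is uniform, so the result is just the layer average.

import numpy as np

# Three hypothetical [batch, seq_len, hidden] layer outputs.
layers_hidden = [np.random.rand(2, 5, 8) for _ in range(3)]
mixing_scores = np.full(3, 1.0 / 3)   # softmax of zero-initialized weights

stacked = np.stack(layers_hidden, axis=-1)                      # [2, 5, 8, 3]
hidden_emb = np.tensordot(stacked, mixing_scores, [[-1], [0]])  # [2, 5, 8]
assert hidden_emb.shape == (2, 5, 8)
assert np.allclose(hidden_emb, np.mean(layers_hidden, axis=0))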
Example #13
Source File: run_dualencoder_qa.py    From language with Apache License 2.0
def _get_bert_embeddings(model, layers_to_use, aggregation_fn, name="bert"):
  """Extract embeddings from BERT model."""
  all_hidden = model.get_all_encoder_layers()
  layers_hidden = [all_hidden[i] for i in layers_to_use]
  hidden_shapes = [
      modeling.get_shape_list(hid, expected_rank=3) for hid in all_hidden
  ]

  if len(layers_hidden) == 1:
    hidden_emb = layers_hidden[0]
    hidden_size = hidden_shapes[0][2]
  elif aggregation_fn == "concat":
    hidden_emb = tf.concat(layers_hidden, 2)
    hidden_size = sum([hidden_shapes[i][2] for i in layers_to_use])
  elif aggregation_fn == "average":
    hidden_size = hidden_shapes[0][2]
    assert all([shape[2] == hidden_size for shape in hidden_shapes
               ]), hidden_shapes
    hidden_emb = tf.add_n(layers_hidden) / len(layers_hidden)
  elif aggregation_fn == "attention":
    hidden_size = hidden_shapes[0][2]
    mixing_weights = tf.get_variable(
        name + "/mixing/weights", [len(layers_hidden)],
        initializer=tf.zeros_initializer())
    mixing_scores = tf.nn.softmax(mixing_weights)
    hidden_emb = tf.tensordot(
        tf.stack(layers_hidden, axis=-1), mixing_scores, [[-1], [0]])
  else:
    raise ValueError("Unrecognized aggregation function %s." % aggregation_fn)

  return hidden_emb, hidden_size 
Example #14
Source File: run_dualencoder_lsf.py    From language with Apache License 2.0
def _get_bert_embeddings(model, layers_to_use, aggregation_fn, name="bert"):
  """Extract embeddings from BERT model."""
  all_hidden = model.get_all_encoder_layers()
  layers_hidden = [all_hidden[i] for i in layers_to_use]
  hidden_shapes = [
      modeling.get_shape_list(hid, expected_rank=3) for hid in all_hidden
  ]

  if len(layers_hidden) == 1:
    hidden_emb = layers_hidden[0]
    hidden_size = hidden_shapes[0][2]
  elif aggregation_fn == "concat":
    hidden_emb = tf.concat(layers_hidden, 2)
    hidden_size = sum([hidden_shapes[i][2] for i in layers_to_use])
  elif aggregation_fn == "average":
    hidden_size = hidden_shapes[0][2]
    assert all([shape[2] == hidden_size for shape in hidden_shapes
               ]), hidden_shapes
    hidden_emb = tf.add_n(layers_hidden) / len(layers_hidden)
  elif aggregation_fn == "attention":
    hidden_size = hidden_shapes[0][2]
    mixing_weights = tf.get_variable(
        name + "/mixing/weights", [len(layers_hidden)],
        initializer=tf.zeros_initializer())
    mixing_scores = tf.nn.softmax(mixing_weights)
    hidden_emb = tf.tensordot(
        tf.stack(layers_hidden, axis=-1), mixing_scores, [[-1], [0]])
  else:
    raise ValueError("Unrecognized aggregation function %s." % aggregation_fn)

  return hidden_emb, hidden_size 
Example #15
Source File: ops.py    From language with Apache License 2.0
def affine(x, output_size, weight_name, bias_name=None, weight_init=None):
  """Affine transformation of the input `x`.

  Args:
    x: <float32>[..., x_dim]
    output_size: size of the last output dimension
    weight_name: Name of the weight variable to use
    bias_name: Name for the bias variable, if one should be used
    weight_init: Initializer of the weight variable

  Returns:
    transformed <float32>[..., `output_size`]
  """
  dim = x.shape.as_list()[-1]
  w = tf.get_variable(
      weight_name, (dim, output_size), tf.float32, initializer=weight_init)
  out = tf.tensordot(x, w, [[len(x.shape) - 1], [0]])
  if bias_name:
    b = tf.get_variable(
        bias_name, (output_size,),
        tf.float32,
        initializer=tf.zeros_initializer())
    for _ in range(len(out.shape) - 1):
      b = tf.expand_dims(b, 0)
    out += b
  return out 
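
As in Example #1, the `[[len(x.shape) - 1], [0]]` contraction acts like a dense layer on the trailing dimension while every leading dimension is preserved. A numpy shape check with hypothetical sizes; broadcasting handles the bias here, which matches the expand_dims loop above.

import numpy as np

x = np.random.rand(2, 3, 5)   # <float32>[..., x_dim] with x_dim = 5
w = np.random.rand(5, 4)      # (dim, output_size)
b = np.random.rand(4)

out = np.tensordot(x, w, [[x.ndim - 1], [0]]) + b
assert out.shape == (2, 3, 4)
assert np.allclose(out, x @ w + b)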
Example #16
Source File: box_list_ops.py    From models with Apache License 2.0
def boolean_mask(boxlist, indicator, fields=None, scope=None,
                 use_static_shapes=False, indicator_sum=None):
  """Select boxes from BoxList according to indicator and return new BoxList.

  `boolean_mask` returns the subset of boxes that are marked as "True" by the
  indicator tensor. By default, `boolean_mask` returns the selected boxes as
  well as all additional fields stored in the boxlist (indexing into the first
  dimension). However, one can optionally draw from only a subset of fields.

  Args:
    boxlist: BoxList holding N boxes
    indicator: a rank-1 boolean tensor
    fields: (optional) list of fields to also gather from.  If None (default),
      all fields are gathered from.  Pass an empty fields list to only gather
      the box coordinates.
    scope: name scope.
    use_static_shapes: Whether to use an implementation with static shape
      guarantees.
    indicator_sum: An integer containing the sum of the `indicator` vector.
      Only required if `use_static_shapes` is True.

  Returns:
    subboxlist: a BoxList corresponding to the subset of the input BoxList
      specified by indicator
  Raises:
    ValueError: if `indicator` is not a rank-1 boolean tensor.
  """
  with tf.name_scope(scope, 'BooleanMask'):
    if indicator.shape.ndims != 1:
      raise ValueError('indicator should have rank 1')
    if indicator.dtype != tf.bool:
      raise ValueError('indicator should be a boolean tensor')
    if use_static_shapes:
      if not (indicator_sum and isinstance(indicator_sum, int)):
        raise ValueError('`indicator_sum` must be of type int')
      selected_positions = tf.cast(indicator, dtype=tf.float32)
      indexed_positions = tf.cast(
          tf.multiply(
              tf.cumsum(selected_positions), selected_positions),
          dtype=tf.int32)
      one_hot_selector = tf.one_hot(
          indexed_positions - 1, indicator_sum, dtype=tf.float32)
      sampled_indices = tf.cast(
          tf.tensordot(
              tf.cast(tf.range(tf.shape(indicator)[0]), dtype=tf.float32),
              one_hot_selector,
              axes=[0, 0]),
          dtype=tf.int32)
      return gather(boxlist, sampled_indices, use_static_shapes=True)
    else:
      subboxlist = box_list.BoxList(tf.boolean_mask(boxlist.get(), indicator))
      if fields is None:
        fields = boxlist.get_extra_fields()
      for field in fields:
        if not boxlist.has_field(field):
          raise ValueError('boxlist must contain all specified fields')
        subfieldlist = tf.boolean_mask(boxlist.get_field(field), indicator)
        subboxlist.add_field(field, subfieldlist)
      return subboxlist 
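
In the static-shape branch, the indicator is turned into a one-hot selector and tensordot with axes=[0, 0] contracts the position vector against it, which gathers the indices of the True entries with a fixed output size. A numpy sketch with a hypothetical indicator:

import numpy as np

indicator = np.array([True, False, True, True, False])

# Row i of the selector is one-hot at slot k if element i is the (k+1)-th
# True entry, and all zeros otherwise.
slots = np.cumsum(indicator) * indicator        # [1, 0, 2, 3, 0]
one_hot = (slots[:, None] == np.arange(1, indicator.sum() + 1)).astype(np.float32)

positions = np.arange(len(indicator), dtype=np.float32)
# axes=[0, 0] contracts the two leading axes: one dot product per output slot.
sampled = np.tensordot(positions, one_hot, axes=[0, 0]).astype(np.int32)
assert sampled.tolist() == [0, 2, 3]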
Example #17
Source File: tf_mittens.py    From mittens with Apache License 2.0
def _build_graph(self, vocab, initial_embedding_dict):
        """Builds the computatation graph.

        Parameters
        ------------
        vocab : Iterable
        initial_embedding_dict : dict
        """
        # Constants
        self.ones = tf.ones([self.n_words, 1])

        # Parameters:
        if initial_embedding_dict is None:
            # Ordinary GloVe
            self.W = self._weight_init(self.n_words, self.n, 'W')
            self.C = self._weight_init(self.n_words, self.n, 'C')
        else:
            # This is the case where we have values to use as a
            # "warm start":
            self.n = len(next(iter(initial_embedding_dict.values())))
            W = randmatrix(len(vocab), self.n)
            C = randmatrix(len(vocab), self.n)
            self.original_embedding = np.zeros((len(vocab), self.n))
            self.has_embedding = np.zeros(len(vocab))
            for i, w in enumerate(vocab):
                if w in initial_embedding_dict:
                    self.has_embedding[i] = 1.0
                    embedding = np.array(initial_embedding_dict[w])
                    self.original_embedding[i] = embedding
                    # Divide the original embedding into W and C,
                    # plus some noise to break the symmetry that would
                    # otherwise cause both gradient updates to be
                    # identical.
                    W[i] = 0.5 * embedding + noise(self.n)
                    C[i] = 0.5 * embedding + noise(self.n)
            self.W = tf.Variable(W, name='W', dtype=tf.float32)
            self.C = tf.Variable(C, name='C', dtype=tf.float32)
            self.original_embedding = tf.constant(self.original_embedding,
                                                  dtype=tf.float32)
            self.has_embedding = tf.constant(self.has_embedding,
                                             dtype=tf.float32)
            # This is for testing. It differs from
            # `self.original_embedding` only in that it includes the
            # random noise we added above to break the symmetry.
            self.G_start = W + C

        self.bw = self._weight_init(self.n_words, 1, 'bw')
        self.bc = self._weight_init(self.n_words, 1, 'bc')

        self.model = tf.tensordot(self.W, tf.transpose(self.C), axes=1) + \
                     tf.tensordot(self.bw, tf.transpose(self.ones), axes=1) + \
                     tf.tensordot(self.ones, tf.transpose(self.bc), axes=1) 
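
For 2-D operands, tensordot with axes=1 is just a matrix product, so the model above is W C^T plus broadcast row and column bias terms. A numpy sketch with hypothetical sizes:

import numpy as np

n_words, n = 6, 4
W = np.random.rand(n_words, n)
C = np.random.rand(n_words, n)
bw = np.random.rand(n_words, 1)
bc = np.random.rand(n_words, 1)
ones = np.ones((n_words, 1))

model = (np.tensordot(W, C.T, axes=1)         # W @ C.T
         + np.tensordot(bw, ones.T, axes=1)   # row biases, broadcast to (n_words, n_words)
         + np.tensordot(ones, bc.T, axes=1))  # column biases
assert np.allclose(model, W @ C.T + bw + bc.T)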
Example #18
Source File: svg_decoder_loss.py    From magenta with Apache License 2.0
def real_svg_loss(top_out, targets, model_hparams, unused_vocab_size,
                  unused_weights_fn):
  """Computes loss for svg decoder model."""
  # targets already come in 10-dim mode, so there is no need to do any mdn
  # processing on them here.
  targets_commands_rel = targets[..., :4]
  targets_args_rel = targets[..., 4:]

  with tf.variable_scope('full_command_loss'):
    num_mix = model_hparams.num_mixture
    commands = top_out[:, :, :, :4]
    args = top_out[:, :, :, 4:]
    # args are [batch, seq, 1, 6*3*num_mix]. want [batch * seq * 6, 3*num_mix]
    args = tf.reshape(args, [-1, 3 * num_mix])
    out_logmix, out_mean, out_logstd = _get_mdn_coef(args)

    # Before we compute mdn_args_loss, we need to create a mask for the
    # elements to ignore in it.
    masktemplate = tf.constant([[0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 1., 1.],
                                [0., 0., 0., 0., 1., 1.],
                                [1., 1., 1., 1., 1., 1.]])
    mask = tf.tensordot(targets_commands_rel, masktemplate, [[-1], [-2]])

    # calculate mdn loss, which auto masks it out
    targs_flat = tf.reshape(targets_args_rel, [-1, 1])
    mdn_loss = _get_mdn_loss(out_logmix, out_mean, out_logstd, targs_flat, mask,
                             model_hparams.dont_reduce_loss)

    # We don't have to manually mask out the softmax xent loss because,
    # internally, each dimension of the xent loss is multiplied by the
    # given probability in the label for that dim. So for a one-hot label
    # [0, 1, 0] the xent loss between logit[0] and label[0] is multiplied by
    # 0, whereas the loss between logit[1] and label[1] is multiplied by 1.
    # Because our targets_commands_rel is all 0s for the padding,
    # softmax_xent_loss is 0 for those elements as well.
    softmax_xent_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=targets_commands_rel, logits=commands)

    # Accumulate losses
    if model_hparams.dont_reduce_loss:
      softmax_xent_loss = tf.reduce_mean(softmax_xent_loss, [1, 2])
    else:
      softmax_xent_loss = tf.reduce_mean(softmax_xent_loss)
    loss = (model_hparams.mdn_k  * mdn_loss +
            model_hparams.soft_k * softmax_xent_loss)

  global _summarized_losses
  if not _summarized_losses:
    with tf.name_scope(None), tf.name_scope('losses_command'):
      tf.summary.scalar('mdn_loss', mdn_loss)
      tf.summary.scalar('softmax_xent_loss', softmax_xent_loss)

  # this tells us not to re-create the summary ops
  _summarized_losses = True

  return loss, tf.constant(1.0)
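
The mask construction above contracts the one-hot command targets against masktemplate, so each timestep picks up the 6-dim argument mask row that belongs to its command. A numpy sketch with a hypothetical target shape:

import numpy as np

masktemplate = np.array([[0., 0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 1., 1.],
                         [0., 0., 0., 0., 1., 1.],
                         [1., 1., 1., 1., 1., 1.]])

# One-hot command targets with hypothetical shape [batch, seq, 1, 4].
targets_commands_rel = np.zeros((2, 3, 1, 4))
targets_commands_rel[..., 1] = 1.0   # pretend every step is command #1

# Contract the one-hot axis with the template's row axis to pick the mask.
mask = np.tensordot(targets_commands_rel, masktemplate, [[-1], [-2]])
assert mask.shape == (2, 3, 1, 6)
assert np.allclose(mask, masktemplate[1])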