"""Utility functions for handling TPU graphs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib2
import tensorflow.compat.v1 as tf

_run_on_cpu = False


@contextlib2.contextmanager
def run_on_cpu():
  """Provide a context for the code that needs to run on CPU.

  Not thread-safe.

  Yields:
    None.
  """
  global _run_on_cpu
  original_run_on_cpu = _run_on_cpu
  _run_on_cpu = True
  try:
    yield
  finally:
    _run_on_cpu = original_run_on_cpu


def is_on_cpu():
  return _run_on_cpu

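# Illustrative usage sketch (the tensor name is hypothetical): CPU-side graph
# construction can wrap itself in run_on_cpu() so that the helpers below
# (maybe_convert_to_variable, read_from_variable) resolve tensors to the
# variables that back them instead of to the TPU-side tensors.
#
#   with run_on_cpu():
#     assert is_on_cpu()
#     host_tensor = maybe_convert_to_variable(some_tpu_tensor)
#   assert not is_on_cpu()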

def get_variable_name(read_variable_op):
  """Returns the name of the variable read by a ReadVariableOp.

  Follows the chain of single-input ops from `read_variable_op` back to the
  VarHandleOp that owns the resource and returns that op's name, which is the
  variable's name.
  """
  assert read_variable_op.type == 'ReadVariableOp'
  op = read_variable_op
  while op.type != 'VarHandleOp':
    assert len(op.inputs) == 1
    op = op.inputs[0].op
  return op.name

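# Sketch of the op chain get_variable_name() walks (names are illustrative):
# a tensor read from a resource variable 'conv1/weights' is typically produced
# as
#
#   conv1/weights (VarHandleOp) --> .../ReadVariableOp --> [Identity ops]
#
# so, starting from the ReadVariableOp, the walk ends at the VarHandleOp and
# returns 'conv1/weights', the name the variable was created under.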

def maybe_convert_to_variable(tensor):
  """Read value of a tensor from a variable when possible.

  This function is intended to make tensors from inside the TPU while loop
  available on the CPU by reading it from the variable to which the tensor was
  written earlier. Note that the read may not reflect any writes that happened
  in the same session.run(), unless control dependencies are added.

  Args:
    tensor: A tf.Tensor.

  Returns:
    A tf.Tensor. If input tensor is an output of reading a ResourceVariable, we
    return an equivalent tensor produced in the current context. Otherwise, we
    return the original input tensor.
  """
  op = tensor.op
  if is_on_cpu() and tensor in var_store:
    return var_store[tensor]
  while op.type == 'Identity':
    assert len(op.inputs) == 1
    op = op.inputs[0].op
  if op.type != 'ReadVariableOp':
    # No need to convert.
    return tensor
  with tf.variable_scope(
      # Reset the scope because variable_name contains all the scopes we need.
      name_or_scope=tf.VariableScope(''),
      # We are looking for a reference to an existing variable, so we want to
      # raise an exception if variable is not found.
      reuse=True,
  ):
    variable_name = get_variable_name(op)
    tf.logging.info('Converting tensor %s --> variable %s',
                    tensor, variable_name)
    try:
      return tf.get_variable(variable_name)
    except ValueError:
      tf.logging.info(
          'Variable %s was not created with tf.get_variable(). '
          'Attempting to find it in GLOBAL_VARIABLES collection.',
          variable_name)
    global_vars = tensor.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    matched_vars = [v for v in global_vars if v.name == variable_name + ':0']
    if not matched_vars:
      raise ValueError('Variable %s is in GraphDef but not in the live graph.'
                       % variable_name)
    assert len(matched_vars) == 1
    return matched_vars[0]

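# Illustrative usage sketch (hypothetical tensor name): when building the CPU
# side of the graph, a tensor that originated as a variable read inside the
# TPU computation can be swapped for the variable itself; tensors that are not
# variable reads pass through unchanged.
#
#   with run_on_cpu():
#     cpu_tensor = maybe_convert_to_variable(weights_tensor)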

# Maps tensors written by write_to_variable() (and their control-dependent
# copies) to the variables that back them; consulted by
# maybe_convert_to_variable() and read_from_variable() when running on CPU.
var_store = {}
# Top-level scope captured at import time; write_to_variable() creates its
# stash variables under it rather than under the caller's scope.
top_level_scope = tf.get_variable_scope()


def write_to_variable(tensor, fail_if_exists=True):
  """Saves a tensor for later retrieval on CPU."""
  # Only relevant for debugging.
  debug_name = 'tpu_util__' + tensor.name.split(':')[0]

  reuse = False if fail_if_exists else tf.AUTO_REUSE
  with tf.variable_scope(top_level_scope, reuse=reuse):
    variable = tf.get_variable(
        name=debug_name,
        shape=tensor.shape,
        dtype=tensor.dtype,
        trainable=False,
        use_resource=True)

  var_store[tensor] = variable
  with tf.control_dependencies([variable.assign(tensor)]):
    tensor_copy = tf.identity(tensor)
  var_store[tensor_copy] = variable
  return tensor_copy

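# Illustrative sketch (function and tensor names are hypothetical): inside the
# TPU computation, stash a value so it can be read on the CPU later;
# downstream ops should consume the returned copy so the write is guaranteed
# to happen before they run.
#
#   def tpu_step(loss):
#     loss = write_to_variable(loss, fail_if_exists=False)
#     return do_more_tpu_work(loss)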

def read_from_variable(tensor):
  """Retrieves (a possibly stale copy of) the previously stored tensor."""
  if is_on_cpu():
    # Stale read, but on CPU that's all we can do without adding to loop vars.
    return var_store[tensor]
  else:
    # Current read, but only works on TPU.
    return tensor

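# Illustrative sketch: the same call site can serve both TPU and CPU graph
# construction. Under run_on_cpu() the (possibly stale) backing variable is
# returned; otherwise the live tensor is passed through.
#
#   stashed = write_to_variable(some_tensor)   # hypothetical tensor
#   ...
#   with run_on_cpu():
#     host_copy = read_from_variable(stashed)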

def is_intermediate_var(v):
  """Returns True if `v` was created by `write_to_variable` above."""
  return v in var_store.values()