"""Utilities."""

import collections
import contextlib
import inspect
import os
import platform
import shutil
import subprocess
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from typing import Any, Dict, Tuple, Optional

import numpy as np
import tensorflow as tf
from mpi4py import MPI
from tensorflow.contrib import summary

try:
    import horovod.tensorflow as hvd
    hvd.init()
except ImportError:
    hvd = None


nest = tf.contrib.framework.nest


def nvidia_gpu_count():
    """
    Count the GPUs on this machine.
    """
    if shutil.which('nvidia-smi') is None:
        return 0
    try:
        output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    except subprocess.CalledProcessError:
        # Probably no GPUs / no driver running.
        return 0
    return max(0, len(output.split(b'\n')) - 2)


def get_local_rank_size(comm):
    """
    Returns the rank of each process on its machine
    The processes on a given machine will be assigned ranks
        0, 1, 2, ..., N-1,
    where N is the number of processes on this machine.
    Useful if you want to assign one gpu per machine
    """
    this_node = platform.node()
    ranks_nodes = comm.allgather((comm.Get_rank(), this_node))
    node2rankssofar = collections.defaultdict(int)
    local_rank = None
    for (rank, node) in ranks_nodes:
        if rank == comm.Get_rank():
            local_rank = node2rankssofar[node]
        node2rankssofar[node] += 1
    assert local_rank is not None
    return local_rank, node2rankssofar[this_node]
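
# Example (illustrative): with 4 processes spread over 2 machines (2 per machine),
# get_local_rank_size(MPI.COMM_WORLD) returns (local_rank, local_size), i.e.
# (0, 2) on one process of each machine and (1, 2) on the other.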


@lru_cache()
def gpu_devices():
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        raise ValueError('CUDA_VISIBLE_DEVICES should not be set (it will cause nccl slowdowns).  Use VISIBLE_DEVICES instead!')
    devices_str = os.environ.get('VISIBLE_DEVICES')
    if devices_str is not None:
        return list(map(int, filter(len, devices_str.split(','))))
    else:
        return list(range(nvidia_gpu_count()))

@lru_cache()
def gpu_count():
    return len(gpu_devices()) or None


@lru_cache()
def _our_gpu():
    """Figure out which GPU we should be using in an MPI context."""
    gpus = gpu_devices()
    if not gpus:
        return None
    rank = MPI.COMM_WORLD.Get_rank()
    local_rank, local_size = get_local_rank_size(MPI.COMM_WORLD)
    if gpu_count() not in (0, local_size):
        raise ValueError('Expected one GPU per rank, got gpus %s, local size %d' % (gpus, local_size))
    gpu = gpus[local_rank]
    print('rank %d: gpus = %s, our gpu = %d' % (rank, gpus, gpu))
    return gpu


def mpi_session_config():
    """Make a tf.ConfigProto to use only the GPU assigned to this MPI session."""
    config = tf.ConfigProto()
    gpu = _our_gpu()
    if gpu is not None:
        config.gpu_options.visible_device_list = str(gpu)
    config.gpu_options.allow_growth = True
    return config


def mpi_session():
    """Create a session using only the GPU assigned to this MPI process."""
    return tf.Session(config=mpi_session_config())


def set_mpi_seed(seed: Optional[int]):
    if seed is not None:
        rank = MPI.COMM_WORLD.Get_rank()
        seed = seed + rank * 100003  # Offset by rank so each process gets a different seed (100003 is prime)
    np.random.seed(seed)
    tf.set_random_seed(seed)


def exact_div(a, b):
    q = a // b
    if tf.contrib.framework.is_tensor(q):
        with tf.control_dependencies([tf.debugging.Assert(tf.equal(a, q * b), [a, b])]):
            return tf.identity(q)
    else:
        if a != q * b:
            raise ValueError('Inexact division: %s / %s = %s' % (a, b, a / b))
        return q


def ceil_div(a, b):
    return (a - 1) // b + 1
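
# Examples (illustrative):
#   exact_div(12, 4)  -> 3
#   exact_div(12, 5)  -> raises ValueError (inexact division)
#   ceil_div(12, 5)   -> 3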


def expand_tile(value, size, *, axis, name=None):
    """Add a new axis of given size."""
    with tf.name_scope(name, 'expand_tile', [value, size, axis]) as scope:
        value = tf.convert_to_tensor(value, name='value')
        size = tf.convert_to_tensor(size, name='size')
        ndims = value.shape.rank
        if axis < 0:
            axis += ndims + 1
        return tf.tile(tf.expand_dims(value, axis=axis), [1]*axis + [size] + [1]*(ndims - axis), name=scope)
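
# Example (illustrative): tiling a [batch, d] tensor along a new middle axis.
#   x = tf.zeros([8, 5])
#   expand_tile(x, 3, axis=1)  # -> tensor of shape [8, 3, 5]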


def index_each(a, ix):
    """Do a batched indexing operation: index row i of a by ix[i]

    In the simple case (a is >=2D and ix is 1D), returns [row[i] for row, i in zip(a, ix)].

    If ix has more dimensions, multiple lookups will be done at each batch index.
    For instance, if ix is 2D, returns [[row[i] for i in ix_row] for row, ix_row in zip(a, ix)].

    Always indexes into dimension 1 of a.
    """
    a = tf.convert_to_tensor(a, name='a')
    ix = tf.convert_to_tensor(ix, name='ix', dtype=tf.int32)
    with tf.name_scope('index_each', values=[a, ix]) as scope:
        a.shape[:1].assert_is_compatible_with(ix.shape[:1])
        i0 = tf.range(tf.shape(a)[0], dtype=ix.dtype)
        if ix.shape.rank > 1:
            i0 = tf.tile(tf.reshape(i0, (-1,) + (1,)*(ix.shape.rank - 1)), tf.concat([[1], tf.shape(ix)[1:]], axis=0))
        return tf.gather_nd(a, tf.stack([i0, ix], axis=-1), name=scope)
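
# Example (illustrative): pick one element from each row.
#   a  = tf.constant([[10., 11.], [20., 21.]])
#   ix = tf.constant([1, 0])
#   index_each(a, ix)  # -> [11., 20.]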

def cumulative_max(x):
    """Takes the (inclusive) cumulative maximum along the last axis of x. (Not efficient.)"""
    x = tf.convert_to_tensor(x)
    with tf.name_scope('cumulative_max', values=[x]) as scope:
        repeated = tf.tile(
            tf.expand_dims(x, axis=-1),
            tf.concat([tf.ones(x.shape.rank, dtype=tf.int32), tf.shape(x)[-1:]], axis=0))
        trues = tf.ones_like(repeated, dtype=tf.bool)
        upper_triangle = tf.matrix_band_part(trues, 0, -1)
        neg_inf = tf.ones_like(repeated) * tf.dtypes.saturate_cast(-np.inf, dtype=x.dtype)
        prefixes = tf.where(upper_triangle, repeated, neg_inf)
        return tf.math.reduce_max(prefixes, axis=-2, name=scope)
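
# Example (illustrative):
#   cumulative_max(tf.constant([1., 3., 2., 5., 4.]))  # -> [1., 3., 3., 5., 5.]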


def flatten_dict(nested, sep='.'):
    def rec(nest, prefix, into):
        for k, v in nest.items():
            if sep in k:
                raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
            if isinstance(v, collections.abc.Mapping):
                rec(v, prefix + k + sep, into)
            else:
                into[prefix + k] = v
    flat = {}
    rec(nested, '', flat)
    return flat
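
# Example (illustrative):
#   flatten_dict({'a': {'b': 1, 'c': 2}, 'd': 3})  # -> {'a.b': 1, 'a.c': 2, 'd': 3}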

@dataclass
class Schema:
    dtype: Any
    shape: Tuple[Optional[int],...]


def add_batch_dim(schemas, batch_size=None):
    def add_dim(schema):
        return Schema(dtype=schema.dtype, shape=(batch_size,)+schema.shape)
    return nest.map_structure(add_dim, schemas)


class SampleBuffer:
    """A circular buffer for storing and sampling data.

    Data can be added to the buffer with `add`, and old data will be dropped.  If you need to
    control where the buffer is stored, wrap the constructor call in a `with tf.device` block:

        with tf.device('cpu:0'):
            buffer = SampleBuffer(...)
    """

    def __init__(self, *, capacity: int, schemas: Dict[str,Schema], name=None) -> None:
        with tf.variable_scope(name, 'buffer', use_resource=True, initializer=tf.zeros_initializer):
            self._capacity = tf.constant(capacity, dtype=tf.int32, name='capacity')
            self._total = tf.get_variable(
                'total', dtype=tf.int32, shape=(), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES],
            )
            self._vars = {
                n: tf.get_variable(
                    n, dtype=s.dtype, shape=(capacity,) + s.shape, trainable=False,
                    collections=[tf.GraphKeys.LOCAL_VARIABLES],
                )
                for n,s in schemas.items()
            }

    def add(self, **data):
        """Add new data to the end of the buffer, dropping old data if we exceed capacity."""
        # Check input shapes
        if data.keys() != self._vars.keys():
            raise ValueError('data.keys() = %s != %s' % (sorted(data.keys()), sorted(self._vars.keys())))
        first = next(iter(data.values()))
        pre = first.shape[:1]
        for k, d in data.items():
            try:
                d.shape.assert_is_compatible_with(pre.concatenate(self._vars[k].shape[1:]))
            except ValueError as e:
                raise ValueError('%s, key %s' % (e, k))
        # Enqueue
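        # The new rows occupy positions [i0, i0 + n) modulo capacity.  If the write
        # wraps past the end of the buffer, it is split into two slices: [i0, i1)
        # at the tail and [i2, i3) at the head (the second slice is empty when
        # there is no wraparound).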
        n = tf.shape(first)[0]
        capacity = self._capacity
        i0 = (self._total.assign_add(n) - n) % capacity
        i0n = i0 + n
        i1 = tf.minimum(i0n, capacity)
        i2 = i1 % capacity
        i3 = i0n % capacity
        slices = slice(i0, i1), slice(i2, i3)
        sizes = tf.stack([i1 - i0, i3 - i2])
        assigns = [self._vars[k][s].assign(part)
                   for k,d in data.items()
                   for s, part in zip(slices, tf.split(d, sizes))]
        return tf.group(assigns)

    def total(self):
        """Total number of entries ever added, including those already discarded."""
        return self._total.read_value()

    def size(self):
        """Current number of entries."""
        return tf.minimum(self.total(), self._capacity)

    def read(self, indices):
        """indices: A 1-D Tensor of indices to read from. Each index must be less than
        capacity."""
        return {k: v.sparse_read(indices) for k,v in self._vars.items()}

    def data(self):
        return {k: v[:self.size()] for k,v in self._vars.items()}

    def sample(self, n, seed=None):
        """Sample n entries with replacement."""
        size = self.size()
        indices = tf.random_uniform([n], maxval=size, dtype=tf.int32, seed=seed)
        return self.read(indices)

    def write(self, indices, updates):
        """
        indices: A 1-D Tensor of indices to write to. Each index must be less than `capacity`.
        updates: A dictionary of new values, where each entry is a tensor with the same length as `indices`.
        """
        ops = []
        for k, v in updates.items():
            ops.append(self._vars[k].scatter_update(tf.IndexedSlices(v, tf.cast(indices, dtype=tf.int32))))
        return tf.group(*ops)

    def write_add(self, indices, deltas):
        ops = []
        for k, d in deltas.items():
            ops.append(self._vars[k].scatter_add(tf.IndexedSlices(d, tf.cast(indices, dtype=tf.int32))))
        return tf.group(*ops)
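
# Example usage (sketch; names are illustrative, requires a TF1 graph/session):
#   buf = SampleBuffer(capacity=1000, schemas=dict(x=Schema(tf.float32, (8,))))
#   add_op = buf.add(x=tf.zeros([16, 8]))   # run this op to enqueue 16 rows
#   batch = buf.sample(4)                   # dict with key 'x' of shape [4, 8]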


def entropy_from_logits(logits):
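    # With p = softmax(logits), log p_i = logits_i - logsumexp(logits), so the entropy
    # H = -sum_i p_i log p_i reduces to logsumexp(logits) - sum_i p_i logits_i.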
    pd = tf.nn.softmax(logits, axis=-1)
    return tf.math.reduce_logsumexp(logits, axis=-1) - tf.reduce_sum(pd*logits, axis=-1)


def logprobs_from_logits(*, logits, labels):
    return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)


def sample_from_logits(logits, dtype=tf.int32):
    with tf.name_scope('sample_from_logits', values=[logits]) as scope:
        shape = tf.shape(logits)
        flat_logits = tf.reshape(logits, [-1, shape[-1]])
        flat_samples = tf.random.categorical(flat_logits, num_samples=1, dtype=dtype)
        return tf.reshape(flat_samples, shape[:-1], name=scope)


def take_top_k_logits(logits, k):
    """Top-k filtering: mask all but the k highest logits at each position to -1e10."""
    values, _ = tf.nn.top_k(logits, k=k)
    min_values = values[:, :, -1, tf.newaxis]
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )


def take_top_p_logits(logits, p):
    """Nucleus sampling"""
    batch, sequence, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch)[:, tf.newaxis],
        tf.range(0, sequence)[tf.newaxis, :],
        # number of indices to include
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )
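
# Example (illustrative): with per-position probabilities [0.5, 0.3, 0.2] and p = 0.8,
# the cumulative probabilities are [0.5, 0.8, 1.0]; the top two tokens are kept and
# the third token's logit is masked to -1e10.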


def whiten(values, shift_mean=True):
    mean, var = tf.nn.moments(values, axes=list(range(values.shape.rank)))
    whitened = (values - mean) * tf.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened



def where(cond, true, false, name=None):
    """Similar to tf.where, but broadcasts scalar values."""
    with tf.name_scope(name, 'where', [cond, true, false]) as name:
        cond = tf.convert_to_tensor(cond, name='cond', dtype=tf.bool)
        true = tf.convert_to_tensor(true, name='true',
                                    dtype=false.dtype if isinstance(false, tf.Tensor) else None)
        false = tf.convert_to_tensor(false, name='false', dtype=true.dtype)
        if true.shape.rank == false.shape.rank == 0:
            shape = tf.shape(cond)
            true = tf.fill(shape, true)
            false = tf.fill(shape, false)
        elif true.shape.rank == 0:
            true = tf.fill(tf.shape(false), true)
        elif false.shape.rank == 0:
            false = tf.fill(tf.shape(true), false)
        return tf.where(cond, true, false, name=name)


def map_flat(f, values):
    """Apply the function f to flattened, concatenated values, then split and reshape back to original shapes."""
    values = tuple(values)
    for v in values:
        assert not isinstance(v, tf.IndexedSlices)
    values = [tf.convert_to_tensor(v) for v in values]
    flat = tf.concat([tf.reshape(v, [-1]) for v in values], axis=0)
    flat = f(flat)
    parts = tf.split(flat, [tf.size(v) for v in values])
    return [tf.reshape(p, tf.shape(v)) for p, v in zip(parts, values)]


def map_flat_chunked(f, values, *, limit=1<<29):
    """
    Apply the function f to chunked, flattened, concatenated values, then split and reshape back to original shapes.
    """
    values = tuple(values)
    for v in values:
        assert not isinstance(v, tf.IndexedSlices)
    values = [tf.convert_to_tensor(v) for v in values]
    chunks = chunk_tensors(values, limit=limit)
    mapped_values = [v for chunk in chunks for v in map_flat(f, chunk)]
    return mapped_values


def map_flat_bits(f, values):
    """Apply the function f to bit-concatenated values, then convert back to original shapes and dtypes."""
    values = [tf.convert_to_tensor(v) for v in values]
    def maybe_bitcast(v, dtype):
        cast = tf.cast if tf.bool in (v.dtype, dtype) else tf.bitcast
        return cast(v, dtype)
    bits = [maybe_bitcast(v, tf.uint8) for v in values]
    flat = tf.concat([tf.reshape(b, [-1]) for b in bits], axis=0)
    flat = f(flat)
    parts = tf.split(flat, [tf.size(b) for b in bits])
    return [maybe_bitcast(tf.reshape(p, tf.shape(b)), v.dtype)
            for p, v, b in zip(parts, values, bits)]

def mpi_bcast_tensor_dict(d, comm):
    sorted_keys = sorted(d.keys())
    values = map_flat_bits(partial(mpi_bcast, comm), [d[k] for k in sorted_keys])
    return {k: v for k, v in zip(sorted_keys, values)}

def mpi_bcast(comm, value, root=0):
    """Broadcast value from root to other processes via a TensorFlow py_func."""
    value = tf.convert_to_tensor(value)
    if comm.Get_size() == 1:
        return value
    comm = comm.Dup()  # Allow parallelism at graph execution time
    if comm.Get_rank() == root:
        out = tf.py_func(partial(comm.bcast, root=root), [value], value.dtype)
    else:
        out = tf.py_func(partial(comm.bcast, None, root=root), [], value.dtype)
    out.set_shape(value.shape)
    return out


def chunk_tensors(tensors, *, limit=1 << 28):
    """Chunk the list of tensors into groups of size at most `limit` bytes.

    The tensors must have a static shape.
    """
    total = 0
    batches = []
    for v in tensors:
        size = v.dtype.size * v.shape.num_elements()
        if not batches or total + size > limit:
            total = 0
            batches.append([])
        total += size
        batches[-1].append(v)
    return batches
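
# Example (illustrative): with limit=64, four float32 tensors of 8 elements each
# (32 bytes apiece) are grouped as [[t1, t2], [t3, t4]] -- two chunks of 64 bytes.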


def variable_synchronizer(comm, vars, *, limit=1<<28):
    """Synchronize `vars` from the root to other processs"""
    if comm.Get_size() == 1:
        return tf.no_op()

    # Split vars into chunks so that no chunk is over limit bytes
    batches = chunk_tensors(sorted(vars, key=lambda v: v.name), limit=limit)

    # Synchronize each batch, using a separate communicator to ensure safety
    prev = tf.no_op()
    for batch in batches:
        with tf.control_dependencies([prev]):
            assigns = []
            values = map_flat_bits(partial(mpi_bcast, comm), batch)
            for var, value in zip(batch, values):
                assigns.append(var.assign(value))
            prev = tf.group(*assigns)
    return prev


def mpi_read_file(comm, path):
    """Read a file on rank 0 and broadcast the contents to all machines."""
    if comm.Get_rank() == 0:
        with tf.gfile.Open(path, 'rb') as fh:
            data = fh.read()
        comm.bcast(data)
    else:
        data = comm.bcast(None)
    return data


def mpi_allreduce_sum(values, *, comm):
    if comm.Get_size() == 1:
        return values
    orig_dtype = values.dtype
    if hvd is None:
        orig_shape = values.shape
        def _allreduce(vals):
            buf = np.zeros(vals.shape, np.float32)
            comm.Allreduce(vals, buf, op=MPI.SUM)
            return buf
        values = tf.py_func(_allreduce, [values], tf.float32)
        values.set_shape(orig_shape)
    else:
        values = hvd.mpi_ops._allreduce(values)
    return tf.cast(values, dtype=orig_dtype)


def mpi_allreduce_mean(values, *, comm):
    scale = 1 / comm.Get_size()
    values = mpi_allreduce_sum(values, comm=comm)
    return values if scale == 1 else scale * values


class FlatStats:
    """A bunch of statistics stored as a single flat tensor."""

    def __init__(self, keys, flat):
        keys = tuple(keys)
        flat = tf.convert_to_tensor(flat, dtype=tf.float32, name='flat')
        assert [len(keys)] == flat.shape.as_list()
        self.keys = keys
        self.flat = flat

    @staticmethod
    def from_dict(stats):
        for k, v in stats.items():
            if v.dtype != tf.float32:
                raise ValueError('Statistic %s has dtype %r, expected %r' % (k, v.dtype, tf.float32))
        keys = tuple(sorted(stats.keys()))
        flat = tf.stack([stats[k] for k in keys])
        return FlatStats(keys, flat)

    def concat(self, more):
        dups = set(self.keys) & set(more.keys)
        if dups:
            raise ValueError('Duplicate statistics: %s' % ', '.join(dups))
        return FlatStats(self.keys + more.keys, tf.concat([self.flat, more.flat], axis=0))

    def as_dict(self):
        flat = tf.unstack(self.flat, num=len(self.keys))
        return dict(safe_zip(self.keys, flat))

    def with_values(self, flat):
        return FlatStats(self.keys, flat)

    def map_flat(self, f):
        return FlatStats(self.keys, f(self.flat))
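
# Example usage (sketch; loss, entropy, and comm are placeholders): pack scalar
# statistics into one tensor, e.g. so a single allreduce can average all of them:
#   stats = FlatStats.from_dict(dict(loss=loss, entropy=entropy))
#   stats = stats.map_flat(partial(mpi_allreduce_mean, comm=comm))
#   averaged = stats.as_dict()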


def find_trainable_variables(key):
    return [v for v in tf.trainable_variables() if v.op.name.startswith(key + '/')]


def variables_on_gpu():
    """Prevent variables from accidentally being placed on the CPU.

    This dodges an obscure bug in tf.train.init_from_checkpoint.
    """
    if _our_gpu() is None:
        return contextlib.suppress()
    def device(op):
        return '/gpu:0' if op.type == 'VarHandleOp' else ''
    return tf.device(device)



def graph_function(**schemas: Schema):
    def decorate(make_op):
        def make_ph(path, schema):
            return tf.placeholder(name=f'arg_{make_op.__name__}_{path}', shape=schema.shape, dtype=schema.dtype)
        phs = nest.map_structure_with_paths(make_ph, schemas)
        op = make_op(**phs)
        sig = inspect.signature(make_op)
        @wraps(make_op)
        def run(*args, **kwargs):
            bound: inspect.BoundArguments = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            arg_dict = bound.arguments
            for name, param in sig.parameters.items():
                if param.kind == inspect.Parameter.VAR_KEYWORD:
                    kwargs = arg_dict[name]
                    arg_dict.update(kwargs)
                    del arg_dict[name]
            flat_phs = nest.flatten(phs)
            flat_arguments = nest.flatten_up_to(phs, bound.arguments)
            feed = {ph: arg for ph, arg in zip(flat_phs, flat_arguments)}
            run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)

            return tf.get_default_session().run(op, feed_dict=feed, options=run_options, run_metadata=None)
        return run
    return decorate
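
# Example usage (sketch; requires a default tf.Session):
#   @graph_function(x=Schema(tf.float32, (None,)))
#   def double(x):
#       return 2 * x
#   double(np.arange(3, dtype=np.float32))  # -> array([0., 2., 4.], dtype=float32)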



def pearson_r(x: tf.Tensor, y: tf.Tensor):
    assert x.shape.rank == 1
    assert y.shape.rank == 1
    x_mean, x_var = tf.nn.moments(x, axes=[0])
    y_mean, y_var = tf.nn.moments(y, axes=[0])
    cov = tf.reduce_mean((x - x_mean)*(y - y_mean), axis=0)
    return cov / tf.sqrt(x_var * y_var)

def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

def safe_zip(*args):
    """Zip, but require all sequences to be the same length."""
    args = tuple(map(tuple, args))
    for a in args[1:]:
        if len(args[0]) != len(a):
            raise ValueError(f'Lengths do not match: {[len(a) for a in args]}')
    return zip(*args)


def get_summary_writer(save_dir, subdir='', comm=MPI.COMM_WORLD):
    if comm.Get_rank() != 0:
        return None
    if save_dir is None:
        return None
    with tf.init_scope():
        return summary.create_file_writer(os.path.join(save_dir, 'tb', subdir))


def record_stats(*, stats, summary_writer, step, log_interval, name=None, comm=MPI.COMM_WORLD):
    def log_stats(step, *stat_values):
        if comm.Get_rank() != 0 or step % log_interval != 0:
            return

        for k, v in safe_zip(stats.keys(), stat_values):
            print('k = ', k, ', v = ', v)

    summary_ops = [tf.py_func(log_stats, [step] + list(stats.values()), [])]
    if summary_writer:
        with summary_writer.as_default(), summary.always_record_summaries():
            for key, value in stats.items():
                summary_ops.append(summary.scalar(key, value, step=step))
    return tf.group(*summary_ops, name=name)


def minimize(*, loss, params, lr, name=None, comm=MPI.COMM_WORLD):
    with tf.name_scope(name, 'minimize'):
        with tf.name_scope('grads'):
            grads = tf.gradients(loss, params)
        grads, params = zip(*[(g, v) for g, v in zip(grads, params) if g is not None])
        grads = map_flat_chunked(partial(mpi_allreduce_mean, comm=comm), grads)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-5, name='adam')
        opt_op = optimizer.apply_gradients(zip(grads, params), name=name)
        return opt_op