# Copyright 2019 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import re

from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib

# Process-wide cache of GPU device names; stays None until the first call to
# get_available_gpus() fills it in.
_GPUS = None
# Boolean command-line flag consulted by get_config() to enable TensorFlow's
# device-placement logging.
flags.DEFINE_bool('log_device_placement', False, 'For debugging purpose.')

def get_config():
    """Build the tf.ConfigProto used to create sessions.

    The returned config:
      * enables soft placement when more than one GPU is available,
      * enables device-placement logging when the `log_device_placement`
        flag is set,
      * always lets GPU memory grow on demand (`allow_growth`).

    Returns:
      tf.ConfigProto, the session configuration.
    """
    config = tf.ConfigProto()
    if len(get_available_gpus()) > 1:
        config.allow_soft_placement = True
    # Read the flag through absl's global registry: no module-level `FLAGS`
    # alias is defined in this file, so a bare `FLAGS` would be a NameError.
    if flags.FLAGS.log_device_placement:
        config.log_device_placement = True
    config.gpu_options.allow_growth = True
    return config

def setup_tf():
    """Set the TF_CPP_MIN_LOG_LEVEL environment variable to '2'.

    NOTE(review): presumably meant to quiet TensorFlow's native logging, in
    which case it has to run before TensorFlow reads the variable -- confirm.
    """
    env_key = 'TF_CPP_MIN_LOG_LEVEL'
    os.environ[env_key] = '2'

def smart_shape(x):
    """Return the first four dimensions of x, preferring static values.

    For each of the axes 0..3 the statically known dimension (`x.shape[i]`)
    is used when its `.value` is not None; otherwise the corresponding
    element of the dynamic `tf.shape(x)` tensor is used instead.

    Args:
      x: tensor; assumed to have at least four dimensions -- TODO confirm.

    Returns:
      list of 4 entries, each either a static dimension or a scalar tensor.
    """
    static_shape = x.shape
    dynamic_shape = tf.shape(x)
    dims = []
    for axis in range(4):
        static_dim = static_shape[axis]
        if static_dim.value is None:
            # Unknown at graph-construction time: defer to the runtime shape.
            dims.append(dynamic_shape[axis])
        else:
            dims.append(static_dim)
    return dims

def ilog2(x):
    """Integer log2: the base-2 logarithm of x, rounded up to an int."""
    exponent = np.log2(x)
    return int(np.ceil(exponent))

def find_latest_checkpoint(dir, glob_term='model.ckpt-*.meta'):
    """Replacement for tf.train.latest_checkpoint.

    It does not rely on the "checkpoint" file which sometimes contains
    absolute path and is generally hard to work with when sharing files
    between users / computers.

    Args:
      dir: string, directory that is searched for checkpoint ".meta" files.
      glob_term: string, glob pattern (relative to `dir`) selecting the
        candidate files.

    Returns:
      string, the path of the ".meta" file with the highest step number,
      with its trailing ".meta" removed.

    Raises:
      ValueError: if no candidate file with a numeric step is found.
    """
    # Raw string so that `\.` and `\d` reach the regex engine untouched.
    r_step = re.compile(r'.*model\.ckpt-(?P<step>\d+)\.meta')
    candidates = []
    for path in glob.glob(os.path.join(dir, glob_term)):
        found = r_step.match(path)
        # The glob's `*` also admits non-numeric steps; skip those instead of
        # failing on `None.group(...)`.
        if found is not None:
            candidates.append((int(found.group('step')), path))
    if not candidates:
        raise ValueError('No checkpoint matching %r found in %r' % (glob_term, dir))
    # Highest step wins; [:-5] strips the ".meta" suffix.
    return max(candidates)[1][:-5]

def get_latest_global_step(dir):
    """Loads the global step from the latest checkpoint in directory.

    Args:
      dir: string, path to the checkpoint directory.

    Returns:
      int, the global step of the latest checkpoint or 0 if none was found.
    """
    try:
        checkpoint_reader = tf.train.NewCheckpointReader(find_latest_checkpoint(dir))
        return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
    # Deliberately best-effort: any failure to locate or read a checkpoint is
    # reported as step 0. `Exception` (rather than a bare except) lets
    # KeyboardInterrupt / SystemExit propagate.
    except Exception:  # pylint: disable=broad-except
        return 0

def get_latest_global_step_in_subdir(dir):
    """Loads the global step from the latest checkpoint in sub-directories.

    Args:
      dir: string, parent of the checkpoint directories.

    Returns:
      int, the global step of the latest checkpoint or 0 if none was found.
    """
    sub_dirs = (x for x in glob.glob(os.path.join(dir, '*')) if os.path.isdir(x))
    # Seed with 0 so that a parent without any sub-directory yields 0.
    return max([0] + [get_latest_global_step(x) for x in sub_dirs])

def getter_ema(ema, getter, name, *args, **kwargs):
    """Exponential moving average getter for variable scopes.

    Args:
        ema: ExponentialMovingAverage object, where to get variable moving averages.
        getter: default variable scope getter.
        name: variable name.
        *args: extra args passed to default getter.
        **kwargs: extra args passed to default getter.

    Returns:
        If found the moving average variable, otherwise the default variable.
    """
    var = getter(name, *args, **kwargs)
    ema_var = ema.average(var)
    # Compare against None explicitly: "no average maintained" is signalled by
    # None, and evaluating a variable's truthiness is not a reliable test.
    return ema_var if ema_var is not None else var

def model_vars(scope=None):
    """Return the TRAINABLE_VARIABLES graph collection.

    Args:
      scope: optional scope forwarded to tf.get_collection to filter the
        collection; None requests the whole collection.

    Returns:
      list, as returned by tf.get_collection.
    """
    collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
    return tf.get_collection(collection_key, scope)

def gpu(x):
    """Return the device string '/gpu:K' for index x.

    K is x modulo the number of available GPUs; when no GPU is available the
    modulus is clamped to 1, so the result is always '/gpu:0'.
    """
    device_count = len(get_available_gpus())
    slot = x % max(1, device_count)
    return '/gpu:%d' % slot

def get_available_gpus():
    """Return the names of the local GPU devices as a tuple.

    The device list is queried once and cached in the module-level `_GPUS`;
    later calls return the cached tuple.
    """
    global _GPUS
    if _GPUS is not None:
        return _GPUS
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    devices = device_lib.list_local_devices(session_config=session_config)
    _GPUS = tuple(d.name for d in devices if d.device_type == 'GPU')
    return _GPUS

def average_gradients(tower_grads):
    # Adapted from:
    #  https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
      tower_grads: List of lists of (gradient, variable) tuples. For each tower, a list of its gradients.

    Returns:
       List of pairs of (gradient, variable) where the gradient has been averaged
       across all towers.

    Raises:
      IndexError: if tower_grads is empty.
    """
    # A single tower needs no averaging; hand its list back unchanged.
    if len(tower_grads) <= 1:
        return tower_grads[0]

    # zip(*tower_grads) groups the i-th (gradient, variable) pair of every
    # tower; the variable is taken from the first tower.
    return [
        (tf.reduce_mean([gv[0] for gv in grads_and_vars], 0), grads_and_vars[0][1])
        for grads_and_vars in zip(*tower_grads)
    ]

def para_list(fn, *args):
    """Run on multiple GPUs in parallel and return list of results.

    Args:
      fn: callable building the per-tower computation; it is called with one
        shard of every element of `args` and must return an iterable.
      *args: tensors, each split into one shard per GPU with tf.split.

    Returns:
      A zip over the per-tower results, i.e. the i-th item groups the i-th
      output of `fn` from every tower. With at most one GPU, `fn` is called
      once on the unsplit inputs and each output is wrapped in a 1-tuple.
    """
    gpus = len(get_available_gpus())
    if gpus <= 1:
        return zip(*[fn(*args)])
    splitted = [tf.split(x, gpus) for x in args]
    outputs = []
    # `tower` (not `gpu`) so the module-level gpu() helper is not shadowed.
    for tower, shard in enumerate(zip(*splitted)):
        with tf.name_scope('tower%d' % tower):
            with tf.device(tf.train.replica_device_setter(
                    worker_device='/gpu:%d' % tower, ps_device='/cpu:0', ps_tasks=1)):
                outputs.append(fn(*shard))
    return zip(*outputs)

def para_mean(fn, *args):
    """Run on multiple GPUs in parallel and return means.

    Args:
      fn: callable building the per-tower computation; it is called with one
        shard of every element of `args`.
      *args: tensors, each split into one shard per GPU with tf.split.

    Returns:
      With at most one GPU, `fn(*args)` unchanged. Otherwise the per-tower
      outputs averaged over the tower axis: a list of means when `fn` returns
      a tuple/list, a single mean tensor otherwise.
    """
    gpus = len(get_available_gpus())
    if gpus <= 1:
        return fn(*args)
    splitted = [tf.split(x, gpus) for x in args]
    outputs = []
    # `tower` (not `gpu`) so the module-level gpu() helper is not shadowed.
    for tower, shard in enumerate(zip(*splitted)):
        with tf.name_scope('tower%d' % tower):
            with tf.device(tf.train.replica_device_setter(
                    worker_device='/gpu:%d' % tower, ps_device='/cpu:0', ps_tasks=1)):
                outputs.append(fn(*shard))
    if isinstance(outputs[0], (tuple, list)):
        return [tf.reduce_mean(x, 0) for x in zip(*outputs)]
    return tf.reduce_mean(outputs, 0)

def para_cat(fn, *args):
    """Run on multiple GPUs in parallel and return concatenated outputs.

    Args:
      fn: callable building the per-tower computation; it is called with one
        shard of every element of `args`.
      *args: tensors, each split into one shard per GPU with tf.split.

    Returns:
      With at most one GPU, `fn(*args)` unchanged. Otherwise the per-tower
      outputs concatenated along axis 0: a list of tensors when `fn` returns
      a tuple/list, a single tensor otherwise.
    """
    gpus = len(get_available_gpus())
    if gpus <= 1:
        return fn(*args)
    splitted = [tf.split(x, gpus) for x in args]
    outputs = []
    # `tower` (not `gpu`) so the module-level gpu() helper is not shadowed.
    for tower, shard in enumerate(zip(*splitted)):
        with tf.name_scope('tower%d' % tower):
            with tf.device(tf.train.replica_device_setter(
                    worker_device='/gpu:%d' % tower, ps_device='/cpu:0', ps_tasks=1)):
                outputs.append(fn(*shard))
    if isinstance(outputs[0], (tuple, list)):
        return [tf.concat(x, axis=0) for x in zip(*outputs)]
    return tf.concat(outputs, axis=0)

## Adapted from brain-research/realistic-ssl-evaluation/master/lib/tf_utils.py
## Used for out-of-distribution example robustness experiments.
def hash_float(x, big_num=1000 * 1000):
    """Hash a tensor 'x' into a floating point number in the range [0, 1).

    Args:
      x: tensor passed to tf.string_to_hash_bucket_fast (presumably a string
        tensor -- confirm against callers).
      big_num: int, number of hash buckets; also the divisor that rescales
        the bucket index.

    Returns:
      float32 tensor: the bucket index divided by `big_num`.
    """
    bucket = tf.string_to_hash_bucket_fast(x, big_num)
    bucket_as_float = tf.cast(bucket, tf.float32)
    return bucket_as_float / tf.constant(float(big_num))

def make_set_filter_fn(elements):
    """Constructs a TensorFlow "set" data structure.

    Note that sets returned by this function are uninitialized. Initialize them
    by calling `sess.run(tf.tables_initializer())`

    Args:
        elements: A list of non-Tensor elements.

    Returns:
        A function that when called with a single tensor argument, returns
        a boolean tensor if the argument is in the set.
    """
    # The set is modelled as a hash table mapping every element to 1; keys
    # that are absent resolve to the default 0, so "lookup == 1" is the
    # membership test used below.
    table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.KeyValueTensorInitializer(
            elements, tf.tile([1], [len(elements)])
        ),
        default_value=0,
    )

    # Queries are cast to int32 before the lookup.
    return lambda x: tf.equal(table.lookup(tf.dtypes.cast(x, tf.int32)), 1)

def filter_fn_from_comma_delimited(string):
    """Parses a string of comma delimited numbers, returning a filter function.

    This utility function is useful for parsing flags that represent a set of
    options, where the default is "all options".

    Args:
        string: e.g. "1,2,3", or empty string for "set of all elements"

    Returns:
        A function that when called with a single tensor argument, returns
        a boolean tensor that evaluates to True if the argument is in the set.
        If 'string' argument is None, the set is understood to contain all
        elements and the function always returns a True tensor.

    Raises:
        ValueError: if a comma separated item is not a valid integer.
    """
    if string:
        return make_set_filter_fn([int(item) for item in string.split(",")])
    # Empty or None means "everything passes". This return must sit outside
    # the `if`, otherwise it is unreachable and None is returned instead.
    return lambda x: tf.constant(True)