#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from ..utils.naming import *
import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager

# Public API of this module: these are the only names exported by
# ``from <this module> import *``.
__all__ = ['get_default_sess_config',
           'get_global_step',
           'get_global_step_var',
           'get_op_var_name',
           'get_vars_by_names',
           'backup_collection',
           'restore_collection',
           'clear_collection',
           'freeze_collection']

def get_default_sess_config(mem_fraction=0.9):
    """
    Build a session config that is lighter on resources than TensorFlow's default.

    Enables soft device placement, uses the BFC allocator, turns on
    ``allow_growth`` and caps per-process GPU memory at ``mem_fraction``.

    :param mem_fraction: value for ``gpu_options.per_process_gpu_memory_fraction``.
    :returns: a `tf.ConfigProto` object.
    """
    config = tf.ConfigProto()
    # let ops without a kernel on the requested device fall back to another one
    config.allow_soft_placement = True

    gpu_opts = config.gpu_options
    gpu_opts.per_process_gpu_memory_fraction = mem_fraction
    gpu_opts.allocator_type = 'BFC'
    gpu_opts.allow_growth = True
    return config

def get_global_step_var():
    """
    Fetch the global_step of the default graph, creating it if it is missing.

    If a tensor named ``GLOBAL_STEP_VAR_NAME`` already exists in the default
    graph, that tensor (as returned by ``Graph.get_tensor_by_name``) is
    returned. Otherwise a new non-trainable ``tf.Variable`` initialised to 0
    and named ``GLOBAL_STEP_OP_NAME`` is created and returned.

    Creation is only allowed at the root variable scope; an assertion fails
    otherwise.

    :returns: the global_step tensor / variable of the current graph.
    """
    graph = tf.get_default_graph()
    try:
        return graph.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
    except KeyError:
        pass  # not created yet -- fall through and create it

    scope_name = tf.get_variable_scope().name
    assert scope_name == '', \
            "Creating global_step_var under a variable scope would cause problems!"
    return tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)

def get_global_step():
    """
    Evaluate the global step using the default session.

    :returns: the result of ``tf.train.global_step`` on the default session
        and the global_step of the default graph (created if missing, see
        ``get_global_step_var``).
    """
    sess = tf.get_default_session()
    step_tensor = get_global_step_var()
    return tf.train.global_step(sess, step_tensor)

def get_op_var_name(name):
    """
    Split a name into its op name and its tensor (variable) name.

    A tensor name has the form ``op_name + ':' + output_index``. If ``name``
    already ends with a numeric output suffix (e.g. ``'W:0'`` or
    ``'split:1'``), it is treated as a tensor name and the suffix is stripped
    to get the op name. Otherwise ``name`` is treated as an op name and its
    first output, ``':0'``, is assumed (which is what variables use).

    :param name: an op name or a tensor/variable name
    :returns: (op_name, variable_name)
    """
    op_name, sep, index = name.rpartition(':')
    if sep and index.isdigit():
        return op_name, name
    # No numeric output suffix: treat the whole string as an op name.
    return name, name + ':0'

def get_vars_by_names(names):
    """
    Look up tensors in the default graph by name.

    Each entry may be an op name (``'W'``) or a tensor name (``'W:0'``). It
    is first normalized to a tensor name with ``get_op_var_name``.

    :param names: an iterable of op or variable names.
    :returns: a list of the matching tensors, in the same order as ``names``.
        These are the objects returned by ``Graph.get_tensor_by_name``, not
        ``tf.Variable`` instances.
    :raises KeyError: (from ``Graph.get_tensor_by_name``) if a name is not
        present in the graph.
    """
    graph = tf.get_default_graph()
    # get_op_var_name returns (op_name, tensor_name); only the tensor name is needed
    return [graph.get_tensor_by_name(get_op_var_name(n)[1]) for n in names]

def backup_collection(keys):
    """
    Snapshot the current contents of several graph collections.

    :param keys: names of the collections to back up.
    :returns: a dict mapping each key to a shallow copy of that collection's
        current list of values, suitable for ``restore_collection``.
    """
    return {key: copy(tf.get_collection(key)) for key in keys}

def restore_collection(backup):
    """
    Reset graph collections to previously saved contents, in place.

    :param backup: a dict mapping a collection key to the values that
        collection should contain afterwards (e.g. the result of
        ``backup_collection``). Collections not named in ``backup`` are
        left untouched.
    """
    for key, saved_values in six.iteritems(backup):
        coll = tf.get_collection_ref(key)
        del coll[:]
        coll.extend(saved_values)

def clear_collection(keys):
    """
    Empty each of the named graph collections, in place.

    :param keys: names of the collections to clear.
    """
    for key in keys:
        coll = tf.get_collection_ref(key)
        del coll[:]

@contextmanager
def freeze_collection(keys):
    """
    Context manager that snapshots the given collections and restores them on exit.

    Changes made to these collections inside the ``with`` block are undone
    when the block exits, including when it exits by raising an exception.

    :param keys: names of the collections to freeze.
    """
    backup = backup_collection(keys)
    try:
        yield
    finally:
        # Restore even if the with-body raised, so a failure does not leave
        # the graph collections in a half-modified state.
        restore_collection(backup)