import numpy as np
import scipy.io as sio
import argparse
import shutil
import os
import platform
from torch.autograd import Variable
import torch
import config
import glob
import sys
import time
import torchvision.utils as vutils

def tboard_add_img(handle, img, title, niter):
    """Log a batch of images to TensorBoard as one tiled grid.

    Args:
        handle: writer object providing ``add_image(tag, image, step)``;
            presumably a ``SummaryWriter`` -- confirm at the call sites.
        img: image batch, passed unchanged to ``torchvision.utils.make_grid``.
        title: tag the grid is logged under.
        niter: global step recorded with the grid.
    """
    grid_kwargs = dict(normalize=False, scale_each=True)
    tiled = vutils.make_grid(img, **grid_kwargs)
    handle.add_image(title, tiled, niter)

def debug():
    """Open an interactive ``ipdb`` session stopped inside this function (development aid)."""
    from ipdb import set_trace
    set_trace()

def env():
    """Guess which known machine this process runs on.

    Returns:
        'eldar' when a ``/scratch`` directory exists; otherwise the host name
        mapped from the current login name ('qhgroup-desktopv' for 'zhenpei',
        'graphicsai01' for 'yzp12'), or None for any other user.
    """
    import getpass
    if os.path.exists('/scratch'):
        return 'eldar'
    # Only consulted when /scratch is absent, matching the lookup order above.
    host_by_user = {
        'zhenpei': 'qhgroup-desktopv',
        'yzp12': 'graphicsai01',
    }
    return host_by_user.get(getpass.getuser())

def env_display():
    """Return True if a ``DISPLAY`` environment variable is set (a display may be available)."""
    display = os.environ.get('DISPLAY')
    return display is not None

def import_matplotlib():
    """Import matplotlib, switching to the 'Agg' backend when no display is set.

    ``env_display()`` decides whether a display is present; without one the
    non-interactive 'Agg' backend is selected before ``pyplot`` is imported.

    Returns:
        The ``matplotlib.pyplot`` module. Previously it was bound only to a
        local name and discarded, so callers could not use it; returning it
        is backward-compatible with callers that ignore the result.
    """
    import matplotlib
    if not env_display():
        # Headless machine: pick a backend that needs no GUI/display.
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    return plt


def adjust_learning_rate(optimizer, epoch, baseLR, dropLR, DECAY_LIMIT, gamma=0.5):
    """Apply a step-decay learning rate to every parameter group of ``optimizer``.

    The rate is ``baseLR * gamma ** (epoch // dropLR)``. With the default
    ``gamma=0.5`` it halves every ``dropLR`` epochs. It never goes below
    ``DECAY_LIMIT``.

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch
            optimizer); each group's ``'lr'`` entry is overwritten.
        epoch: current epoch index.
        baseLR: learning rate at epoch 0.
        dropLR: number of epochs between successive decays.
        DECAY_LIMIT: lower bound for the learning rate.
        gamma: multiplicative factor applied at each decay step (default 0.5,
            the value that used to be hard-coded).

    Returns:
        str: message reporting the learning rate that was set.
    """
    lr = baseLR * (gamma ** (epoch // dropLR))
    if lr < DECAY_LIMIT:
        lr = DECAY_LIMIT
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return 'current learning rate: {0}'.format(lr)


def get_latest_model(path, identifier):
    """Return the checkpoint in ``path`` with the highest trailing epoch number.

    Candidates are the entries of ``path`` matching ``*<identifier>*``. The
    epoch is the text of the file's base name after its last ``_`` and before
    the next ``.``, e.g. ``net_12.pth`` -> 12. Matching files without such a
    numeric suffix (logs, notes, ...) are ignored instead of aborting.

    Args:
        path: directory to search.
        identifier: substring (glob syntax allowed) that checkpoint names contain.

    Returns:
        str: path of the checkpoint with the largest epoch.

    Raises:
        IndexError: if no matching checkpoint exists (the same exception type
            the previous implementation raised on an empty match).
    """
    def _epoch_of(model_path):
        # Parse the base name only so '_' or '.' in directory names cannot
        # leak into the epoch token; None means "not a checkpoint name".
        token = os.path.basename(model_path).split('_')[-1].split('.')[0]
        try:
            return int(token)
        except ValueError:
            return None

    candidates = []
    for model in glob.glob('{0}/*{1}*'.format(path, identifier)):
        epoch = _epoch_of(model)
        if epoch is not None:
            candidates.append((epoch, model))
    if not candidates:
        raise IndexError('no checkpoint matching *{0}* with a trailing _<epoch> '
                         'found in {1}'.format(identifier, path))
    return max(candidates, key=lambda item: item[0])[1]

def parse_epoch(path):
    """Return the epoch number encoded at the end of a checkpoint file name.

    The epoch is the text of the base name after its last ``_`` and before the
    next ``.``, so any extension length works: ``params/net_12.pth`` -> 12,
    ``net_10.pt`` -> 10. (The previous ``[:-4]`` slice assumed a 4-character
    extension and silently returned 1 for ``net_10.pt``.)

    Raises:
        ValueError: if that token is not an integer.
    """
    token = os.path.basename(path).split('_')[-1].split('.')[0]
    return int(token)

def resume(keyNet, EXP_DIR_PARAMS, key_word, map_location=None):
    """Try to restore ``keyNet`` from the newest checkpoint in ``EXP_DIR_PARAMS``.

    The checkpoint path comes from ``get_latest_model(EXP_DIR_PARAMS, key_word)``
    and the loaded object is read as a dict with ``'state_dict'`` and
    ``'epoch'`` entries.

    Args:
        keyNet: module whose ``load_state_dict`` receives the saved weights.
        EXP_DIR_PARAMS: directory searched for checkpoints.
        key_word: substring identifying this network's checkpoint files.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU-saved checkpoint without a GPU); None keeps torch's default.

    Returns:
        ``(epoch, net_path, True)`` on success, ``(None, None, False)`` when no
        checkpoint could be found or loaded.
    """
    try:
        net_path = get_latest_model(EXP_DIR_PARAMS, key_word)
        state = torch.load(net_path, map_location=map_location)
        keyNet.load_state_dict(state['state_dict'])
        epoch = state['epoch']
    except Exception as err:
        # Best-effort by design (callers branch on the flag). ``Exception``
        # rather than a bare except lets Ctrl-C / SystemExit propagate, and the
        # reason is reported instead of silently starting from scratch.
        print('resume failed for {0}: {1}'.format(key_word, err))
        return None, None, False
    return epoch, net_path, True

# Module-level counter initialised to 0; it is not referenced in the code shown
# here (presumably used elsewhere). A ``global`` statement is a no-op at module
# scope, so none is needed here -- functions that rebind it must declare
# ``global counting`` themselves.
counting = 0
def variable_hook(grad):
    """Debug gradient hook: print the gradient's mean magnitude and damp it.

    Prints a ``'variable hook'`` marker, then the mean absolute value of all
    entries of ``grad``. Returns ``grad * 0.1``. Presumably registered with
    ``Tensor.register_hook``, in which case the returned tensor replaces the
    gradient -- confirm at the registration site.
    """
    magnitudes = abs(grad.detach().cpu().numpy().flatten())
    print('variable hook')
    print(np.mean(magnitudes))
    return grad * .1

def parameters_count(net, name):
    """Print the number of trainable (``requires_grad``) parameters of ``net``.

    Args:
        net: module exposing ``parameters()``.
        name: label shown in the printed message.
    """
    trainable_total = 0
    for tensor in net.parameters():
        if not tensor.requires_grad:
            continue
        trainable_total += np.prod(tensor.size())
    print('total parameters for %s: %d' % (name, trainable_total))

def initialize_parser():
    """Create the command-line parser shared by the experiment entry points.

    Options (all optional):
        --rm        delete an existing experiment folder first
        --exp       experiment identifier
        --param_id  parameter-set identifier
        --resume    continue a previous run
        --d         GPU device string
        --g         produce a more detailed training log
        --da        dataset name

    NOTE(review): no ``--repeat`` option is registered here; callers that need
    one presumably add it to the returned parser.

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    parser = argparse.ArgumentParser(description='Optional app description')
    # (option, is boolean switch, help text), in --help display order.
    option_table = [
        ('--rm', True, 'remove the experiment folder if exists'),
        ('--exp', False, 'add identifier for this experiment'),
        ('--param_id', False, 'parameter identifier'),
        ('--resume', True, 'if specified, resume certain training'),
        ('--d', False, 'specify gpu device'),
        ('--g', True, 'if specified, produce more detailed training log'),
        ('--da', False, 'specify the dataset'),
    ]
    for option, is_switch, help_text in option_table:
        extra = {'action': 'store_true'} if is_switch else {}
        parser.add_argument(option, help=help_text, **extra)
    return parser
    
def initialize_experiment_directories(args):
    """Derive experiment paths on ``args``, choose GPUs, and prepare the directories.

    Directory layout (``<idx>`` is ``args.exp``, or ``'404'`` when not given)::

        ./experiments/exp_<idx>[/repeat_<repeat>][/<param_id>]/
            samples/
            params/
            exp_<idx>.csv    (log path only; no file is created here)

    Attributes set on ``args``: ``DETAILED`` (copy of ``args.g``),
    ``EXPERIMENT_INDEX``, ``EXP_BASE_DIR``, ``EXP_BASE_REPEAT_DIR`` (only when
    ``repeat`` is not None), ``EXP_DIR``, ``EXP_DIR_SAMPLES``,
    ``EXP_DIR_PARAMS``, ``EXP_DIR_LOG``, and ``repeat`` (None when the parsed
    arguments do not define it).

    Side effects:
        * Sets ``CUDA_VISIBLE_DEVICES`` from ``args.d``, or to ``'0,1,2'`` on
          the ``graphicsai01`` host.
        * Prints every visible GPU.
        * Calls ``validate_and_execute_arguments``, which removes/creates
          directories.

    Returns:
        The same ``args`` object, mutated in place.
    """
    args.DETAILED = args.g
    args.EXPERIMENT_INDEX = args.exp if args.exp else '404'
    args.EXP_BASE_DIR = './experiments/exp_' + args.EXPERIMENT_INDEX
    # initialize_parser() registers no --repeat option, so the attribute may be
    # missing; default it instead of raising AttributeError here or downstream.
    args.repeat = getattr(args, 'repeat', None)

    # Optional repeat_<n> and <param_id> levels nest under the base directory.
    exp_dir = args.EXP_BASE_DIR
    if args.repeat is not None:
        args.EXP_BASE_REPEAT_DIR = os.path.join(
            args.EXP_BASE_DIR, 'repeat_{0}'.format(args.repeat))
        exp_dir = args.EXP_BASE_REPEAT_DIR
    if args.param_id:
        exp_dir = os.path.join(exp_dir, args.param_id)
    args.EXP_DIR = exp_dir

    args.EXP_DIR_SAMPLES = args.EXP_DIR + '/samples'
    args.EXP_DIR_PARAMS = args.EXP_DIR + '/params'
    args.EXP_DIR_LOG = os.path.join(args.EXP_DIR, 'exp_{}.csv'.format(args.EXPERIMENT_INDEX))

    if args.d:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.d
    elif env() == 'graphicsai01':
        # Default GPU set for this host when --d is not given.
        os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2'
    for i in range(torch.cuda.device_count()):
        print("detected gpu:{}\n".format(torch.cuda.get_device_name(i)))
    validate_and_execute_arguments(args)
    return args

def platform_specific_initialization(args):
    """Store the detected host on ``args.env`` and apply host-specific tweaks.

    Only the ``'qhgroup-desktopv'`` host needs a tweak: its default HTTPS
    context factory is replaced by an unverified one. Every other host
    ('eldar', 'graphicsai01', unknown) is left untouched.
    """
    args.env = env()
    if args.env != 'qhgroup-desktopv':
        return
    import ssl
    # SECURITY(review): this turns off TLS certificate verification for HTTPS
    # clients in this process that use ssl's default context (e.g. urllib).
    # Presumably a workaround for a certificate problem on that machine --
    # confirm it is still required.
    ssl._create_default_https_context = ssl._create_unverified_context

def save_config(config, path):
    """Write every setting of a config module to ``path`` as ``<name> <value>`` lines.

    Names come from ``dir(config)``, so they are written in alphabetical order.
    Module bookkeeping attributes are skipped.

    Args:
        config: module (or other object) whose attributes are dumped.
        path: output file; overwritten if it already exists.
    """
    # __loader__/__spec__/__cached__ are Python 3 module attributes; without
    # them here their opaque object reprs would be written to the file.
    skip_keys = {'__builtins__', '__doc__', '__name__', '__package__', '__file__',
                 '__loader__', '__spec__', '__cached__'}
    with open(path, 'w') as f:
        for key in dir(config):
            if key in skip_keys:
                continue
            # getattr performs the same lookup as the former eval() without
            # executing code assembled from strings.
            f.write('{0} {1}\n'.format(key, getattr(config, key)))

def validate_and_execute_arguments(args):
    """Check the --rm/--resume combination and prepare the experiment directories.

    Rules:
        * ``rm`` and ``resume`` cannot both be set.
        * With neither set, ``args.EXP_DIR`` must not exist yet.
        * With ``rm`` set, ``args.EXP_DIR`` is deleted first; a missing
          directory is fine.

    Then creates ``EXP_BASE_DIR`` (when ``repeat`` or ``param_id`` is set),
    ``EXP_BASE_REPEAT_DIR`` (when ``repeat`` is truthy and the path is set),
    ``EXP_DIR``, ``EXP_DIR_SAMPLES`` and ``EXP_DIR_PARAMS``.

    Raises:
        AssertionError: on an invalid flag combination or an already existing
            ``EXP_DIR``. It is raised explicitly so the checks also run under
            ``python -O``; the type is kept for existing callers.
        OSError: if a directory cannot be removed or created. Such errors were
            previously swallowed silently.
    """
    if args.rm and args.resume:
        raise AssertionError('--rm and --resume cannot be combined')
    # Without --rm/--resume an existing run would otherwise be silently reused.
    if not args.rm and not args.resume and os.path.exists(args.EXP_DIR):
        raise AssertionError(
            'experiment directory {0} already exists; pass --rm to delete it '
            'or --resume to continue'.format(args.EXP_DIR))

    if args.rm:
        try:
            shutil.rmtree(args.EXP_DIR)
        except FileNotFoundError:
            pass  # --rm on a fresh experiment: nothing to delete

    # initialize_parser() defines no --repeat; treat a missing attribute as unset.
    repeat = getattr(args, 'repeat', None)
    to_create = []
    if repeat or getattr(args, 'param_id', None):
        to_create.append(args.EXP_BASE_DIR)
    repeat_dir = getattr(args, 'EXP_BASE_REPEAT_DIR', None)
    if repeat and repeat_dir:
        to_create.append(repeat_dir)
    to_create += [args.EXP_DIR, args.EXP_DIR_SAMPLES, args.EXP_DIR_PARAMS]
    for directory in to_create:
        # exist_ok: re-running with --resume must not fail on existing folders.
        os.makedirs(directory, exist_ok=True)


def get_size(obj, seen=None):
    """Return the approximate deep size of ``obj`` in bytes.

    ``sys.getsizeof`` reports only an object's own footprint. This function
    also recurses into:

    * dict values, then dict keys;
    * an instance's ``__dict__``;
    * the elements of other iterables (``str``/``bytes``/``bytearray`` are
      leaves).

    Each object is counted at most once, tracked by ``id``.

    Args:
        obj: root object to measure.
        seen: set of ids already counted; shared across the recursion.

    Returns:
        Total size in bytes, or 0 if ``obj`` was already counted.
    """
    own_size = sys.getsizeof(obj)
    visited = set() if seen is None else seen
    if id(obj) in visited:
        return 0
    # Record this object before descending so self-references terminate.
    visited.add(id(obj))

    if isinstance(obj, dict):
        children = list(obj.values()) + list(obj.keys())
    elif hasattr(obj, '__dict__'):
        children = [obj.__dict__]
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        children = obj
    else:
        children = ()
    return own_size + sum(get_size(child, visited) for child in children)


def decay_learning_rate(LEARNING_RATE, decay, DECAY_LIMIT):
    """Divide the learning rate by ``decay``, never going below ``DECAY_LIMIT``.

    Returns:
        ``LEARNING_RATE / decay``, or ``DECAY_LIMIT`` when that quotient is
        smaller than it.
    """
    return max(LEARNING_RATE / decay, DECAY_LIMIT)