import os
import json
from math import ceil

import re
import sys
import gc
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from collections import OrderedDict

# Parent of the directory that contains this module (left un-normalised,
# i.e. the path literally ends in the parent-directory component).
root = os.path.join(os.path.dirname(__file__), os.pardir)


def get_root():
    """Return the module-level ``root`` path."""
    return root


def get_global_opts():
    """Return the global path options, optionally overridden from disk.

    Every option starts out as the placeholder ``'/example/path'``.  If a
    file named ``global_opts.json`` exists in the directory returned by
    ``get_root()``, its contents are parsed as JSON and merged on top of
    the defaults (keys from the file win, unknown keys are added).

    Returns:
        dict: option name -> value.
    """
    placeholder = '/example/path'
    option_names = (
        'result_path',
        'cityscapes_path',
        'vistas_path',
        'wildash_root_path',
        'robotcar_root_path',
        'robotcar_corr_path',
        'robotcar_im_path',
        'cmu_root_path',
        'cmu_corr_path',
        'cmu_im_path',
    )
    opts = dict.fromkeys(option_names, placeholder)

    override_file = os.path.join(get_root(), 'global_opts.json')
    if not os.path.exists(override_file):
        return opts

    with open(override_file, 'r') as opts_file:
        overrides = json.load(opts_file)
    opts.update(overrides)
    return opts


def rename_key_of_ordered_dict(ordered_dict, old_name, new_name):
    """Return a new ``OrderedDict`` with ``old_name`` renamed to ``new_name``.

    Entry order is preserved and the input mapping is left untouched.  If
    ``old_name`` is not present, the result is simply an ordered copy.
    """
    renamed = OrderedDict()
    for key, value in ordered_dict.items():
        target = new_name if key == old_name else key
        renamed[target] = value
    return renamed


def rename_keys_to_match(state_dict):
    """Strip the ``final.`` / ``aux.`` prefixes from the conv6 entries.

    The four renames are applied one after another through
    ``rename_key_of_ordered_dict``, so entry order is preserved and keys
    that are absent are silently skipped.

    Returns:
        The renamed mapping (a new ``OrderedDict``).
    """
    renames = (
        ('final.conv6.bias', 'conv6.bias'),
        ('final.conv6.weight', 'conv6.weight'),
        ('aux.conv6_1.bias', 'conv6_1.bias'),
        ('aux.conv6_1.weight', 'conv6_1.weight'),
    )
    for old_key, new_key in renames:
        state_dict = rename_key_of_ordered_dict(state_dict, old_key, new_key)
    return state_dict


def replace_root(old_path, old_root, new_root):
    """Swap the leading ``old_root`` of ``old_path`` for ``new_root``.

    Raises:
        AssertionError: if ``old_path`` does not start with ``old_root``,
            or if only one of the two roots ends with ``os.sep``.
    """
    root_len = len(old_root)
    assert old_path[:root_len] == old_root
    # Both roots must agree on having (or lacking) a trailing separator,
    # otherwise the result would gain or lose a separator at the seam.
    assert old_root.endswith(os.sep) == new_root.endswith(os.sep)
    return new_root + old_path[root_len:]


def replace_suffix(old_str, old_suffix, new_suffix):
    """Return ``old_str`` with its trailing ``old_suffix`` replaced.

    An empty ``old_suffix`` is treated as "append ``new_suffix``".  (Negative
    slicing cannot express this: ``s[-0:]`` is the whole string and
    ``s[:-0]`` is empty, so the cut point is computed from the front.)

    Raises:
        AssertionError: if ``old_str`` does not end with ``old_suffix``.
    """
    assert old_str.endswith(old_suffix)
    stem_len = len(old_str) - len(old_suffix)
    return old_str[:stem_len] + new_suffix


def absorb_bn(module, bn_module):
    """Fold the running-statistics transform of ``bn_module`` into ``module``.

    After the call, ``module(x)`` yields what ``bn_module(module(x))``
    yielded with ``bn_module`` using its running mean/variance (and affine
    weight/bias, if any).  Both modules are modified in place.

    Works for any layer whose weight has the output dimension first
    (e.g. ``nn.Conv2d`` and ``nn.Linear``): the per-output scale is
    broadcast over however many trailing weight dimensions there are.

    Args:
        module: layer with ``weight`` (and possibly ``bias``) to absorb into.
            A zero bias is created when the layer has none.
        bn_module: batch-norm layer whose statistics are absorbed.  Its
            ``running_mean``, ``running_var``, ``weight`` and ``bias`` are
            set to ``None`` and ``affine`` to ``False`` afterwards.

    NOTE(review): what the stripped ``bn_module`` computes afterwards
    depends on the PyTorch version; confirm callers replace or bypass it.
    """
    w = module.weight.data
    out_dim = w.size(0)
    if module.bias is None:
        # Size the bias from the weight itself so this works for layers
        # that have no ``out_channels`` attribute (e.g. nn.Linear).
        zeros = torch.zeros(out_dim, dtype=w.dtype, device=w.device)
        module.bias = nn.Parameter(zeros)
    b = module.bias.data

    # Shape that broadcasts a per-output vector over the whole weight.
    per_output = (out_dim,) + (1,) * (w.dim() - 1)

    invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
    w.mul_(invstd.view(per_output))
    b.add_(-bn_module.running_mean).mul_(invstd)

    if bn_module.affine:
        gamma = bn_module.weight.data
        w.mul_(gamma.view(per_output))
        b.mul_(gamma).add_(bn_module.bias.data)

    # Drop everything that has been absorbed.
    bn_module.register_buffer('running_mean', None)
    bn_module.register_buffer('running_var', None)
    bn_module.register_parameter('weight', None)
    bn_module.register_parameter('bias', None)
    bn_module.affine = False


def is_bn(m):
    """Return True if ``m`` is a ``BatchNorm1d`` or ``BatchNorm2d`` module."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    return isinstance(m, bn_types)


def is_absorbing(m):
    """Return True if ``m`` is a ``Conv2d`` or ``Linear`` module."""
    absorbing_types = (nn.Conv2d, nn.Linear)
    return isinstance(m, absorbing_types)


def search_absorbe_bn(model):
    """Recursively absorb batch-norm layers into the layer just before them.

    Walks the direct children of ``model`` in order.  Whenever a child is
    a batch-norm layer (``is_bn``) and the previous sibling can absorb it
    (``is_absorbing``), ``absorb_bn`` is applied to the pair.  Every child
    is then searched in the same way.
    """
    previous = None
    for child in model.children():
        if is_bn(child):
            if is_absorbing(previous):
                absorb_bn(previous, child)
        search_absorbe_bn(child)
        previous = child


def add_bias(net):
    """Give every bias-less ``Conv2d`` / ``Linear`` in ``net`` a zero bias.

    The bias is created with the weight's dtype and device.  Its length is
    taken from the first weight dimension, because ``nn.Linear`` has no
    ``out_channels`` attribute.  Layers that already have a bias are left
    untouched.
    """
    for module in net.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        if module.bias is not None:
            continue
        w = module.weight.data
        zeros = torch.zeros(w.size(0), dtype=w.dtype, device=w.device)
        module.bias = nn.Parameter(zeros)


def freeze_bn(net):
    """Put every 1d/2d/3d batch-norm module of ``net`` into eval mode.

    Useful when training with small batches.  Only the batch-norm modules
    are switched; all other modules keep their current mode.
    """
    bn_types = (
        torch.nn.modules.BatchNorm1d,
        torch.nn.modules.BatchNorm2d,
        torch.nn.modules.BatchNorm3d,
    )
    for module in net.modules():
        if isinstance(module, bn_types):
            module.eval()


def check_mkdir(dir_name):
    """Create ``dir_name`` (and missing parents); no error if it exists."""
    os.makedirs(dir_name, exist_ok=True)


def initialize_weights(*models):
    """Re-initialise the parameters of the given models in place.

    * ``Conv2d`` / ``Linear``: Kaiming-normal weights; bias (if any) zeroed.
    * ``BatchNorm2d``: weight set to 1 and bias to 0.  Layers built with
      ``affine=False`` have neither and are skipped.

    Args:
        *models: any number of ``nn.Module`` instances.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # In-place initialiser; the un-suffixed name is deprecated.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                if module.weight is not None:
                    module.weight.data.fill_(1)
                if module.bias is not None:
                    module.bias.data.zero_()


def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Build a bilinear-interpolation kernel as a float32 tensor.

    The result has shape ``(in_channels, out_channels, k, k)``.  The 2-D
    bilinear filter is written at the paired positions ``[i, i]``; every
    other entry is zero.

    NOTE(review): the paired indexing only works when ``in_channels`` and
    ``out_channels`` are equal (or one of them is 1) — confirm callers.
    """
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5

    # 1-D triangular profile; the 2-D filter is its outer product.
    offsets = np.arange(kernel_size)
    profile = 1 - np.abs(offsets - center) / factor
    filt = np.outer(profile, profile)

    shape = (in_channels, out_channels, kernel_size, kernel_size)
    weight = np.zeros(shape, dtype=np.float64)
    in_idx = list(range(in_channels))
    out_idx = list(range(out_channels))
    weight[in_idx, out_idx, :, :] = filt
    return torch.from_numpy(weight).float()


def _fast_hist(label_pred, label_true, num_classes):
    """Return the ``num_classes x num_classes`` confusion matrix.

    Entry ``[t, p]`` counts the elements with true label ``t`` and
    predicted label ``p``.  Elements whose true label lies outside
    ``[0, num_classes)`` are ignored.
    """
    valid = (label_true >= 0) & (label_true < num_classes)
    true_ids = label_true[valid].astype(int)
    pred_ids = label_pred[valid]
    # Encode each (true, pred) pair as one flat index, then count them.
    flat_index = num_classes * true_ids + pred_ids
    counts = np.bincount(flat_index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

# predictions and gts should have size (N,W,H), where N is the number of
# images in a batch


def evaluate_incremental(hist, predictions, gts, num_classes):
    """Add a batch to a running confusion matrix and recompute the metrics.

    Args:
        hist: running ``(num_classes, num_classes)`` confusion matrix; it is
            updated in place and also returned.
        predictions: iterable of predicted label arrays.
        gts: iterable of ground-truth label arrays, paired with
            ``predictions``.
        num_classes: number of valid classes.

    Returns:
        tuple: ``(acc, acc_cls, mean_iu, fwavacc, hist)``.  ``acc_cls`` and
        ``mean_iu`` are averaged only over classes that occur in the
        ground truth.
    """
    for pred_img, gt_img in zip(predictions, gts):
        hist += _fast_hist(pred_img.flatten(), gt_img.flatten(), num_classes)

    # Rows (axis 0) index ground truth, columns (axis 1) index predictions.
    correct = np.diag(hist)
    gt_count = hist.sum(axis=1)
    pred_count = hist.sum(axis=0)
    total = hist.sum()
    present_classes = gt_count != 0

    acc = correct.sum() / total

    per_class_acc = correct / gt_count
    acc_cls = np.mean(per_class_acc[present_classes])

    iu = correct / (gt_count + pred_count - correct)
    mean_iu = np.mean(iu[present_classes])

    freq = gt_count / total
    observed = freq > 0
    fwavacc = (freq[observed] * iu[observed]).sum()
    return acc, acc_cls, mean_iu, fwavacc, hist


def evaluate(predictions, gts, num_classes):
    """Compute segmentation metrics for a set of predictions from scratch.

    This is ``evaluate_incremental`` started from an all-zero confusion
    matrix; delegating keeps the metric definitions in a single place.

    Args:
        predictions: iterable of predicted label arrays.
        gts: iterable of ground-truth label arrays, paired with
            ``predictions``.
        num_classes: number of valid classes.

    Returns:
        tuple: ``(acc, acc_cls, mean_iu, fwavacc, hist)`` where ``hist`` is
        the float ``(num_classes, num_classes)`` confusion matrix.
    """
    hist = np.zeros((num_classes, num_classes))
    return evaluate_incremental(hist, predictions, gts, num_classes)


class AverageMeter(object):
    """Track the latest value and the weighted running average of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class PolyLR(object):
    """Polynomial learning-rate decay over an optimizer's param groups.

    ``step()`` sets each group's ``'lr'`` to
    ``initial_lr * (1 - curr_iter / max_iter) ** lr_decay``.  ``step()``
    only reads ``curr_iter``; it never changes it.
    """

    def __init__(self, optimizer, curr_iter, max_iter, lr_decay):
        groups = optimizer.param_groups
        self.max_iter = float(max_iter)
        # Learning rates at construction time, one per param group.
        self.init_lr_groups = [group['lr'] for group in groups]
        self.param_groups = groups
        self.curr_iter = curr_iter
        self.lr_decay = lr_decay

    def step(self):
        """Write the decayed learning rate into every param group."""
        for idx, group in enumerate(self.param_groups):
            progress = self.curr_iter / self.max_iter
            group['lr'] = self.init_lr_groups[idx] * (1 - progress) ** self.lr_decay


def get_latest_network_name(folder_path):
    """Return the snapshot file in ``folder_path`` with the highest iteration.

    Snapshot files are named ``iter_<N>_acc<anything>.pth``.

    Args:
        folder_path: directory to scan (not searched recursively).

    Returns:
        str: the matching file name with the largest ``<N>``, or ``''`` when
        no file matches.
    """
    # The dot before "pth" is escaped so that only a real ".pth" extension
    # matches.
    pattern = re.compile(r'^iter_(\d+)_acc.*\.pth$')
    net_to_load = ''
    # Start below every valid iteration so an "iter_0_..." file can win.
    max_it = -1
    for fname in os.listdir(folder_path):
        found = pattern.match(fname)
        if found is None:
            continue
        it = int(found.group(1))
        if it > max_it:
            net_to_load = fname
            max_it = it
    return net_to_load


def collect_gt_from_slices(gt_slices, slices_info):
    """Reassemble a full ground-truth label map from image slices.

    Each slice is added into its target window of the output and the
    result is integer-divided by the number of slices covering each pixel,
    so overlapping slices that agree reproduce the original labels.

    Args:
        gt_slices: tensor of ground-truth slices.
            NOTE(review): assumed to be ``(1, num_slices, H, W)`` (batch of
            one); confirm against the data loader.
        slices_info: integer tensor ``(1, num_slices, >=6)``; each row is
            read as ``(row_start, row_end, col_start, col_end, valid_h,
            valid_w)``.

    Returns:
        numpy.ndarray: int array of shape ``(1, height, width)`` where
        ``height`` / ``width`` are the largest ``row_end`` / ``col_end``.
        Pixels covered by no slice are 0.
    """
    height = int(slices_info[0, :, 1].max().item())
    width = int(slices_info[0, :, 3].max().item())

    gts_tmp = np.zeros((1, height, width), dtype=int)
    count = np.zeros((height, width), dtype=int)

    # Non-mutating views: iterate slice by slice without altering the inputs.
    per_slice_gt = gt_slices.transpose(0, 1)
    per_slice_info = slices_info.squeeze(0)

    for gt_slice, info in zip(per_slice_gt, per_slice_info):
        r0, r1, c0, c1, valid_h, valid_w = (int(v) for v in info[:6])
        patch = gt_slice[0, :valid_h, :valid_w].detach().cpu().numpy()
        gts_tmp[0, r0:r1, c0:c1] += patch.astype(gts_tmp.dtype)
        count[r0:r1, c0:c1] += 1

    # Avoid dividing by zero for pixels that no slice covered.
    gts_tmp //= np.maximum(count, 1)
    return gts_tmp

# removes all log entries after last validation point before continuing


def clean_log_before_continuing(log_path, last_val_iter):
    """Truncate a training log at the first entry past ``last_val_iter``.

    Lines are kept in order until one starts with ``[iter <i> / <n>]``
    where ``<i>`` is greater than ``last_val_iter``; that line and all
    lines after it are dropped.  The file is rewritten in place.
    """
    iter_tag = re.compile(r"\[iter (\d+) / (\d+)\]")

    kept = []
    with open(log_path) as log_file:
        for text in log_file:
            hit = iter_tag.match(text)
            if hit is not None and int(hit.group(1)) > last_val_iter:
                break
            kept.append(text)

    with open(log_path, 'w') as log_file:
        log_file.writelines(kept)