from math import log10, floor
import torch
# sklearn.utils.linear_assignment_ was deprecated and removed in
# scikit-learn 0.23; scipy.optimize.linear_sum_assignment is the
# replacement (it returns (row_ind, col_ind) rather than an (n, 2)
# array of matched pairs).
from scipy.optimize import linear_sum_assignment

def str2bool(v):
    """Parse a string such as 'true' or '1' as a boolean."""
    return v.lower() in ('true', '1')

def cudify(x, use_cuda):
    """Move x to the GPU when use_cuda is True, otherwise return it as-is.

    Args:
        x: input Tensor.
        use_cuda: boolean flag.
    """
    return x.cuda() if use_cuda else x

def loss_dt_check(losses):
    """Estimate the rate of change of the loss.

    Given a deque (or sequence) of recent loss values, returns the mean
    step-to-step difference. Requires at least two values.
    """
    dt = [losses[i] - losses[i - 1] for i in range(1, len(losses))]
    return sum(dt) / (len(losses) - 1)
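
# Illustrative arithmetic: the mean step difference telescopes, so for
# losses = [1.0, 0.8, 0.5] this returns
# ((0.8 - 1.0) + (0.5 - 0.8)) / 2 = (0.5 - 1.0) / 2 = -0.25.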

def copy_model_params(source, target):
    """Copy parameter values from the source module into the target in-place."""
    for source_p, target_p in zip(source.parameters(), target.parameters()):
        target_p.data.copy_(source_p.data)
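
# Usage sketch (online_net and target_net are hypothetical modules with
# identical architectures, as in target-network style updates):
#   copy_model_params(source=online_net, target=target_net)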

def round_to_2(x):
    """Round x to two significant figures (x must be nonzero)."""
    return round(x, -int(floor(log10(abs(x)))) + 1)
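
# Illustrative values: two significant figures are kept at any magnitude,
# e.g. round_to_2(1234.5) -> 1200.0 and round_to_2(0.012345) -> 0.012.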

def byte_tensor_to_index(x):
    """
    Convert a torch.ByteTensor of size [batch_size, 1]
    to a torch.LongTensor holding the indices of the
    nonzero elements of x.
    """
    # Vectorized equivalent of looping over x and collecting i where x[i][0] == 1.
    return (x.view(-1) == 1).nonzero().view(-1)
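
# Usage sketch (assumed input: a 0/1 mask such as one from a comparison op):
#   mask = torch.tensor([[1], [0], [1]], dtype=torch.uint8)
#   byte_tensor_to_index(mask)  # -> tensor([0, 2])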

def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable logsumexp.

    Args:
        inputs: A Tensor with any shape.
        dim: An integer reduction dimension.
        keepdim: A boolean; whether to retain dim in the output.

    Returns:
        Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
    """
    # For a 1-D array x (any array along a single dimension),
    # log sum exp(x) = s + log sum exp(x - s)
    # with s = max(x) being a common choice.
    if dim is None:
        inputs = inputs.view(-1)
        dim = 0
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs
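
# Why the max-shift matters (illustrative check): exp(1000.) overflows to inf
# in float32, but logsumexp stays finite:
#   torch.log(torch.exp(torch.tensor([1000., 1000.])).sum())  # -> inf
#   logsumexp(torch.tensor([1000., 1000.]))                   # -> 1000.6931...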

def parallel_matching(batch):
    """Solve a max-weight bipartite matching for each [n, n] score matrix
    in the batch, returning one permutation matrix per batch element."""
    perms = []
    m, n, _ = batch.shape
    for i in range(m):
        perm = torch.zeros(n, n)
        # linear_sum_assignment minimizes cost, so negate to maximize score.
        row_ind, col_ind = linear_sum_assignment(-batch[i])
        perm[row_ind, col_ind] = 1
        perms.append(perm)
    return perms
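
# Usage sketch (assumed shape [m, n, n] of matching scores): each returned
# perm is an n x n 0/1 matrix with one 1 per row and column, picking the
# assignment that maximizes the total score.
#   perms = parallel_matching(torch.rand(4, 5, 5))  # list of 4 tensors, 5 x 5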


def memory_usage():
    """Resident set size of the current process in MB.

    Reads /proc/self/statm (Linux only) and assumes 4 KiB pages.
    """
    return int(open('/proc/self/statm').read().split()[1]) * 4096. / 1000000.

if __name__ == '__main__': 
    torch.random.manual_seed(1)
    ones = torch.ones(3,1)
    x = torch.zeros(3,3).uniform_()
    x = torch.exp(x / 0.1)
    #x = torch.ones(3,3)
    print(x)

    # Row-normalize x in log space: log(x_ij / sum_j x_ij)
    # = log(x) - logsumexp(log(x), dim=1). Note logsumexp must be applied
    # to log(x) (not x), and dim=1 matches the row sums used below.
    log_x = torch.log(x)
    log_scale_res = log_x - logsumexp(log_x, dim=1, keepdim=True)
    print("Log scale res {}".format(log_scale_res))
    print("exp(log_scale_res): {}".format(torch.exp(log_scale_res)))

    # The same row normalization computed directly; x @ ones broadcasts the
    # row sums, which is numerically unstable for small temperatures.
    rn = torch.div(x, torch.matmul(torch.matmul(x, ones), torch.t(ones)))
    print("Unstable: {}".format(rn))