# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import print_function
from builtins import input
from builtins import map
from builtins import str
from builtins import zip
from builtins import range
from past.builtins import basestring
from past.utils import old_div
from builtins import object

import os, sys, pdb, random, collections, pickle, stat, codecs, itertools, shutil, datetime, importlib, requests
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from math import *
from functools import reduce
from itertools import cycle
from scipy import interp
from sklearn.metrics import *
from os.path import join as pathJoin
from os.path import exists as pathExists

###############################################################################
# Description:
#    This is a collection of general utility / helper functions.
#
# Typical meaning of variable names:
#    lines,strings = list of strings
#    line,string   = single string
#    table         = 2D row/column matrix implemented using a list of lists
#    row,list1D    = single row in a table, i.e. single 1D-list
#    rowItem       = single item in a row
#    list1D        = list of items, not necessarily strings
#    item          = single item of a list1D
###############################################################################


#################################################
# File access
#################################################
def readFile(inputFile):
    """Read a text file and return its lines without line-ending characters.

    The file is opened in binary mode (a carry-over from Python 2, to avoid
    problems with end-of-text characters). Each raw line is decoded as UTF-8
    and its trailing line ending is stripped with removeLineEndCharacters,
    because readlines() keeps the line ending characters.
    """
    with open(inputFile, 'rb') as f:
        rawLines = f.readlines()
    decodedLines = []
    for rawLine in rawLines:
        decodedLines.append(removeLineEndCharacters(rawLine.decode('utf8')))
    return decodedLines

def readBinaryFile(inputFile):
    """Return the complete, undecoded content of inputFile as bytes."""
    with open(inputFile, 'rb') as f:
        return f.read()

def readPickle(inputFile):
    """Load and return the object stored in the pickle file inputFile.

    Note: unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(inputFile, 'rb') as pickleFile:
        return pickle.load(pickleFile)

def readTable(inputFile, delimiter='\t', columnsToKeep=None):
    """Read a delimited text file into a table (list of rows, each a list of strings).

    If columnsToKeep is given, the first line is split into header names and
    listFindItems maps the requested names to column indices; only those
    columns are kept by splitStrings. Otherwise all columns are returned.

    Note: if getting memory errors then use 'readTableFileAccessor' instead.
    """
    lines = readFile(inputFile)
    columnsToKeepIndices = None
    if columnsToKeep != None:
        headerItems = lines[0].split(delimiter)
        columnsToKeepIndices = listFindItems(headerItems, columnsToKeep)
    return splitStrings(lines, delimiter, columnsToKeepIndices)

def writeFile(outputFile, lines, header=None, encoding=None):
    """Write lines to outputFile, one item per line.

    Args:
        outputFile: path of the file to (over)write.
        lines: iterable of items; each is formatted with "%s" plus a newline.
        header: optional first line, written before the items.
        encoding: if given (e.g. 'utf-8'), the file is opened through
            codecs.open with that encoding; otherwise plain open() is used.
    """
    if encoding == None:
        fileHandle = open(outputFile, 'w')
    else:
        fileHandle = codecs.open(outputFile, 'w', encoding)  # e.g. encoding=utf-8
    with fileHandle as f:
        if header != None:
            f.write("%s\n" % header)
        for line in lines:
            f.write("%s\n" % line)

def writeTable(outputFile, table, header=None):
    """Write a table (list of rows) to outputFile, one row per line.

    Rows are flattened to strings by tableToList1D (using its default
    delimiter) and written with writeFile; header, if given, is passed
    through as the first line.
    """
    writeFile(outputFile, tableToList1D(table), header)

def writeBinaryFile(outputFile, data):
    """Write data to outputFile in binary mode and return the result of write().

    In Python 3 the returned value is the number of bytes written.
    """
    with open(outputFile, 'wb') as f:
        return f.write(data)

def writePickle(outputFile, data):
    """Pickle data to outputFile.

    The file is opened in a with-block so the handle is flushed and closed
    deterministically (the previous version never closed it).

    The pickler runs in "fast" mode, as before. NOTE(review): per the pickle
    documentation fast mode disables the memo, so self-referential objects
    cannot be pickled this way -- confirm this is acceptable for all callers.
    """
    with open(outputFile, "wb") as f:
        pickler = pickle.Pickler(f)
        pickler.fast = True
        pickler.dump(data)

def getFilesInDirectory(directory, postfix = ""):
    """Return the names of the files (not sub-directories) in directory.

    Args:
        directory: path of the directory to list. If it does not exist an
            empty list is returned.
        postfix: optional file-name ending (e.g. ".jpg"). Matching is
            case-insensitive; an empty or None postfix returns all files.

    Returns:
        List of file names (without the directory part).
    """
    if not os.path.exists(directory):
        return []
    fileNames = [s for s in os.listdir(directory)
                 if not os.path.isdir(os.path.join(directory, s))]
    if not postfix:
        return fileNames
    # File names are lower-cased for the comparison, so the postfix has to be
    # lower-cased as well; otherwise e.g. postfix=".JPG" could never match.
    postfix = postfix.lower()
    return [s for s in fileNames if s.lower().endswith(postfix)]

def getFilesInSubdirectories(directory, postfix = ""):
    """Return the paths of all files located one level below directory.

    For every immediate sub-directory of directory, the files it contains
    (optionally filtered by postfix, see getFilesInDirectory) are collected.

    Returns:
        List of paths of the form directory/subdir/filename.
    """
    paths = []
    for subdir in getDirectoriesInDirectory(directory):
        # Bug fix: this previously referenced the undefined name 'imgSubdir',
        # which raised a NameError as soon as a sub-directory existed.
        subdirPath = os.path.join(directory, subdir)
        for filename in getFilesInDirectory(subdirPath, postfix):
            paths.append(os.path.join(subdirPath, filename))
    return paths

def getDirectoriesInDirectory(directory):
    """Return the names of the immediate sub-directories of directory."""
    subdirNames = []
    for entry in os.listdir(directory):
        if os.path.isdir(directory + "/" + entry):
            subdirNames.append(entry)
    return subdirNames

def makeDirectory(directory):
    """Create directory (including missing parents) unless the path already exists."""
    if os.path.exists(directory):
        return
    os.makedirs(directory)

def makeOrClearDirectory(directory):
    """Ensure directory exists and delete all files directly inside it.

    Note: removes just the files in the directory, not recursive.
    Sub-directories and their content are left untouched.
    """
    makeDirectory(directory)
    for entry in os.listdir(directory):
        entryPath = directory + "/" + entry
        if os.path.isdir(entryPath):
            # Only files are cleared. The permission change below is also
            # restricted to files: applying chmod(S_IWRITE) to a directory
            # would strip its read/execute bits on POSIX systems.
            continue
        # Remove a possible write protection so the file can be deleted.
        os.chmod(entryPath, stat.S_IWRITE)
        os.remove(entryPath)

def removeWriteProtectionInDirectory(directory):
    """Set the mode of every file directly inside directory to stat.S_IWRITE.

    Sub-directories are skipped; the operation is not recursive.
    """
    for entry in os.listdir(directory):
        entryPath = directory + "/" + entry
        if os.path.isdir(entryPath):
            continue
        os.chmod(entryPath, stat.S_IWRITE)

def deleteFile(filePath):
    """Remove filePath if it exists; do nothing otherwise."""
    if not os.path.exists(filePath):
        return
    os.remove(filePath)

def deleteAllFilesInDirectory(directory, fileEndswithString, boPromptUser = False):
    """Delete the files directly inside directory, optionally filtered by ending.

    Args:
        directory: directory whose files are deleted (not recursive).
        fileEndswithString: only files whose lower-cased name ends with this
            string are deleted; None deletes all files.
        boPromptUser: if True, ask for confirmation on the console first and
            terminate the program unless the answer is "y" or "yes".
    """
    if boPromptUser:
        # Bug fix: the answer was previously passed through eval(), which
        # executes whatever the user types as Python code (and raised a
        # NameError for a plain "y"). input() already returns a string.
        userInput = input('--> INPUT: Press "y" to delete files in directory ' + directory + ": ")
        if userInput.strip().lower() not in ('y', 'yes'):
            print("User input is %s: exiting now." % userInput)
            sys.exit()
    for filename in getFilesInDirectory(directory):
        if fileEndswithString is None or filename.lower().endswith(fileEndswithString):
            deleteFile(directory + "/" + filename)



#################################################
# 1D list
#################################################
def isList(var):
    """Return True if var is a list (or an instance of a list subclass)."""
    isListInstance = isinstance(var, list)
    return isListInstance

def toIntegers(list1D):
    """Convert every item to int by way of float, so "3.7" becomes 3."""
    integers = []
    for item in list1D:
        integers.append(int(float(item)))
    return integers

def toRounded(list1D):
    """Return a list with round() applied to every item of list1D."""
    return list(map(round, list1D))

def toFloats(list1D):
    """Return a list with every item of list1D converted to float."""
    return list(map(float, list1D))

def toStrings(list1D):
    """Return a list with every item of list1D converted with str()."""
    strings = []
    for item in list1D:
        strings.append(str(item))
    return strings

def max2(list1D):
    """Return the maximum of list1D together with all indices where it occurs.

    Returns:
        Tuple (maxVal, indices), where indices lists every position whose
        item equals maxVal.
    """
    maxVal = max(list1D)
    indices = [index for index, item in enumerate(list1D) if item == maxVal]
    return maxVal, indices

def pbMax(list1D): # deprecated
    """Deprecated: forwards to max2 and returns its (maxVal, indices) result."""
    maxVal, indices = max2(list1D)
    return maxVal, indices

def find(list1D, func):
    """Return the indices of all items in list1D for which func(item) is truthy."""
    matchingIndices = []
    for index, item in enumerate(list1D):
        if func(item):
            matchingIndices.append(index)
    return matchingIndices

def listSort(list1D, reverseSort=False, comparisonFct=lambda x: x):
    """Sort list1D and also return the permutation that sorts it.

    Args:
        list1D: list of items to sort.
        reverseSort: sort in descending order if True.
        comparisonFct: key function. Note that it is called with an
            (item, originalIndex) tuple, not with the bare item.

    Returns:
        Tuple (list1DSorted, sortOrder), where sortOrder[k] is the index in
        list1D of the k-th item of list1DSorted. Both are empty lists for
        empty input (previously this raised a ValueError while unpacking).
    """
    if len(list1D) == 0:
        return ([], [])
    indexedItems = sorted(zip(list1D, range(len(list1D))), key=comparisonFct, reverse=reverseSort)
    list1DSorted = [item for item, _ in indexedItems]
    sortOrder = [index for _, index in indexedItems]
    return (list1DSorted, sortOrder)



#################################################
# 2D list (e.g. tables)
#################################################
def getColumn(table, columnIndex):
    """Return the item at position columnIndex of every row of table."""
    column = []
    for row in table:
        column.append(row[columnIndex])
    return column

def getRows(table, rowIndices):
    """Return the rows of table selected by rowIndices, in the given order."""
    selectedRows = []
    for rowIndex in rowIndices:
        selectedRows.append(table[rowIndex])
    return selectedRows

def getColumns(table, columnIndices):
    """Return a new table keeping only the columns in columnIndices, in that order."""
    reducedTable = []
    for row in table:
        reducedTable.append([row[index] for index in columnIndices])
    return reducedTable

def sortTable(table, sortColumnIndex, reverseSort=False, comparisonFct=lambda x: float(x[0])):
    """Return the rows of table ordered by the values in one column.

    Args:
        table: list of rows.
        sortColumnIndex: index of the column whose values determine the order.
        reverseSort: sort in descending order if True.
        comparisonFct: key function handed to listSort; it is called with a
            (columnValue, rowIndex) tuple. The default compares the column
            values as floats.

    Returns:
        New list with the rows in sorted order; [] for an empty table.
    """
    if len(table) == 0:
        return []
    sortValues = getColumn(table, sortColumnIndex)
    sortOrder = listSort(sortValues, reverseSort, comparisonFct)[1]
    sortedTable = []
    for rowIndex in sortOrder:
        sortedTable.append(table[rowIndex])
    return sortedTable

def tableToList1D(table, delimiter='\t'):
    """Flatten a table into a list of strings, one delimiter-joined string per row."""
    lines = []
    for row in table:
        lines.append(delimiter.join(map(str, row)))
    return lines



#################################################
# String and chars
#################################################
def isString(var):
    """Return True if the type of var is exactly the type of a string literal.

    Instances of string subclasses are not considered strings by this check.
    """
    stringType = type("")
    return type(var) == stringType

def numToString(num, length, paddingChar = '0'):
    """Return str(num) forced to exactly length characters.

    Longer strings are cut off after length characters; shorter strings are
    padded on the right with paddingChar.
    """
    text = str(num)
    if len(text) < length:
        return text.ljust(length, paddingChar)
    return text[:length]

def splitString(string, delimiter='\t', columnsToKeepIndices=None):
    """Split string at delimiter and optionally keep only selected items.

    Args:
        string: the string to split; None is passed through as None.
        delimiter: separator handed to str.split.
        columnsToKeepIndices: optional list of item indices to keep (in the
            given order). None keeps all items.

    Returns:
        List of string items, or None if string is None.
    """
    if string is None:
        return None
    items = string.split(delimiter)
    if columnsToKeepIndices is not None:
        # Bug fix: this used getColumn(items, columnsToKeepIndices), which
        # indexes every single item (a string) with the whole index list and
        # therefore raised a TypeError. Select the requested items directly.
        items = [items[index] for index in columnsToKeepIndices]
    return items

def splitStrings(strings, delimiter, columnsToKeepIndices=None):
    """Split every string with splitString and return the results as a table."""
    table = []
    for string in strings:
        table.append(splitString(string, delimiter, columnsToKeepIndices))
    return table

def removeLineEndCharacters(line):
    """Strip one trailing line ending ('\\r\\n' or '\\n') from line, if present."""
    # Check the two-character ending first so that '\r\n' is removed as a unit.
    for lineEnding in ('\r\n', '\n'):
        if line.endswith(lineEnding):
            return line[:-len(lineEnding)]
    return line



#################################################
# Randomize
#################################################
def getRandomNumber(low, high):
    """Return a random integer N with low <= N <= high (both ends inclusive)."""
    randomNumber = random.randint(low, high)
    return randomNumber

def getRandomNumbers(low, high):
    """Return the integers low..high (inclusive) in random order."""
    shuffled = [number for number in range(low, high + 1)]
    random.shuffle(shuffled)
    return shuffled

def getRandomListElement(listND, containsHeader=False):
    """Return a randomly chosen element of listND.

    If containsHeader is True the first element is treated as a header and
    is never selected.
    """
    firstIndex = 1 if containsHeader else 0
    index = getRandomNumber(firstIndex, len(listND)-1)
    return listND[index]

def randomizeList(listND, containsHeader=False):
    """Shuffle listND and return the shuffled list.

    Without a header the list is shuffled in place and the same list object
    is returned. With containsHeader=True the first element stays in front,
    the remaining elements are shuffled in a copy, and a new list is
    returned (the input list is left unchanged).
    """
    if not containsHeader:
        random.shuffle(listND)
        return listND
    header = listND[0]
    body = listND[1:]
    random.shuffle(body)
    body.insert(0, header)
    return body



#################################################
# Dictionaries
#################################################
def getDictionary(keys, values, boConvertValueToInt = True):
    """Build a dictionary from parallel sequences of keys and values.

    Keys and values are paired with zip, so surplus entries of the longer
    sequence are ignored. If boConvertValueToInt is True every value is
    converted with int().
    """
    return {key: (int(value) if boConvertValueToInt else value)
            for key, value in zip(keys, values)}

def sortDictionary(dictionary, sortIndex=0, reverseSort=False):
    """Return the (key, value) pairs of dictionary as a sorted list.

    sortIndex selects what to sort by: 0 sorts by key, 1 sorts by value.
    """
    def sortKey(pair):
        return pair[sortIndex]

    keyValuePairs = dictionary.items()
    return sorted(keyValuePairs, key=sortKey, reverse=reverseSort)

def invertDictionary(dictionary):
    """Return a new dictionary that maps each value of dictionary to its key.

    If several keys share a value, the key encountered last wins.
    """
    inverted = {}
    for key, value in dictionary.items():
        inverted[value] = key
    return inverted

def dictionaryToTable(dictionary):
    """Return the content of dictionary as a list of (key, value) tuples."""
    return [keyValuePair for keyValuePair in dictionary.items()]

def mergeDictionaries(dict1, dict2):
    """Return a copy of dict1 updated with the entries of dict2.

    Neither input is modified; on duplicate keys the value from dict2 wins.
    """
    merged = dict1.copy()
    merged.update(dict2)
    return merged



#################################################
# Url
#################################################
def downloadFromUrl(url, boVerbose = True):
    """Download url and return the raw response body.

    This is a best-effort helper: any error during the request is reported
    (if boVerbose) and otherwise swallowed.

    Args:
        url: the address to fetch (GET request with a 5 second timeout).
        boVerbose: print a message if the download fails.

    Returns:
        The response content on success, or an empty list [] on failure.
        Note that the HTTP status code is not checked.
    """
    data = []
    try:
        r = requests.get(url, timeout=5)
        data = r.content
    except Exception:
        # 'except Exception' instead of a bare 'except', so that
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        if boVerbose:
            print('Error downloading url {0}'.format(url))
    return data



#################################################
# Confusion matrix and p/r curves
# Note: Let C be the confusion matrix. Then C_{i, j} is the number of observations known to be in group i but predicted to be in group j.
#################################################
def cmSanityCheck(confMatrix, gtLabels):
    """Assert that each confusion-matrix row sums to the count of that ground-truth label.

    Rows 0..max(gtLabels) are checked; confMatrix is indexed as
    confMatrix[i, :], i.e. rows correspond to the ground-truth class.
    """
    numClasses = max(gtLabels) + 1
    for classIndex in range(numClasses):
        rowTotal = sum(confMatrix[classIndex, :])
        numGtLabels = sum([label == classIndex for label in gtLabels])
        assert(rowTotal == numGtLabels)

def cmGetAccuracies(confMatrix, gtLabels = None):
    """Return the per-class accuracy for each class of a confusion matrix.

    Args:
        confMatrix: 2D array indexed as confMatrix[i, j], where row i is the
            ground-truth class and column j the predicted class.
        gtLabels: optional ground-truth labels. If given (and non-empty),
            cmSanityCheck verifies that the matrix is consistent with them.

    Returns:
        List with confMatrix[i, i] / sum(confMatrix[i, :]) for every class i.
    """
    # An explicit length check replaces the former 'gtLabels != []', which
    # is not a reliable emptiness test when the labels are a numpy array.
    # None and [] both mean "no labels supplied".
    if gtLabels is not None and len(gtLabels) > 0:
        cmSanityCheck(confMatrix, gtLabels)
    return [float(confMatrix[i, i]) / sum(confMatrix[i, :]) for i in range(confMatrix.shape[1])]

def cmPrintAccuracies(confMatrix, classes, gtLabels = []):
    """Print per-class, overall and class-averaged accuracy of a confusion matrix.

    Args:
        confMatrix: confusion matrix, passed on to cmGetAccuracies.
        classes: class names, in the order of the matrix rows.
        gtLabels: optional ground-truth labels, passed on to cmGetAccuracies.

    Returns:
        Tuple (globalAcc, accs): the overall accuracy in percent and the
        list of per-class accuracies (as fractions).
    """
    columnWidth = max(len(className) for className in classes)
    accs = cmGetAccuracies(confMatrix, gtLabels)

    # Left-align the class names in a column as wide as the longest name.
    lineTemplate = "Class {:<" + str(columnWidth) + "} accuracy: {:2.2f}%."
    for cls, acc in zip(classes, accs):
        print(lineTemplate.format(cls, 100 * acc))

    numCorrect = sum(np.diag(confMatrix))
    numTotal = sum(sum(confMatrix))
    globalAcc = 100.0 * numCorrect / numTotal
    print("OVERALL accuracy: {:2.2f}%.".format(globalAcc))
    print("OVERALL class-averaged accuracy: {:2.2f}%.".format(100 * np.mean(accs)))
    return globalAcc, accs

def cmPlot(confMatrix, classes, normalize=False, title='Confusion matrix', cmap=None):
    """Draw a confusion matrix into the current matplotlib figure.

    Args:
        confMatrix: 2D numpy array indexed as confMatrix[i, j], where row i
            is the true class and column j the predicted class.
        classes: class names used as tick labels on both axes.
        normalize: if True, show each row as percentages of its row sum
            (rounded to one decimal) instead of raw counts.
        title: plot title; the class-averaged accuracy is appended to it.
        cmap: matplotlib colormap; None (or the legacy value []) selects
            plt.cm.Blues.
    """
    # Class-averaged accuracy for the title. It is computed from the matrix
    # as passed in, and per row (true class), so that it agrees with
    # cmGetAccuracies. The previous code divided by column sums and did so
    # after the optional normalisation, which does not give the accuracy.
    avgAcc = np.mean([float(confMatrix[i, i]) / sum(confMatrix[i, :]) for i in range(confMatrix.shape[1])])

    if normalize:
        confMatrix = confMatrix.astype('float') / confMatrix.sum(axis=1)[:, np.newaxis]
        confMatrix = np.round(confMatrix * 100,1)
    if cmap is None or (isinstance(cmap, list) and not cmap):
        cmap = plt.cm.Blues

    # Write the cell values; use white text on dark (high-valued) cells.
    thresh = confMatrix.max() / 2.
    for i, j in itertools.product(range(confMatrix.shape[0]), range(confMatrix.shape[1])):
        plt.text(j, i, confMatrix[i, j], horizontalalignment="center",
                 color="white" if confMatrix[i, j] > thresh else "black")

    plt.imshow(confMatrix, interpolation='nearest', cmap=cmap)
    plt.title(title + " (avgAcc={:2.2f}%)".format(100*avgAcc))
    plt.colorbar()
    plt.xticks(np.arange(len(classes)), classes, rotation=45)
    plt.yticks(np.arange(len(classes)), classes)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

def rocComputePlotCurves(gtLabels, scoresMatrix, labels):
    """Compute one-vs-rest ROC curves and plot them into the current figure.

    Code taken from Microsoft AML Workbench iris tutorial.

    Args:
        gtLabels: ground-truth class index per sample (used to index the
            one-hot columns).
        scoresMatrix: 2D array indexed as scoresMatrix[:, classIndex].
        labels: class names; len(labels) is the number of classes.

    Returns:
        Tuple (fpr, tpr, thres) of dictionaries. fpr and tpr are keyed by
        class index plus "micro" and "macro"; thres is keyed by class index.
    """
    numClasses = len(labels)

    # One-hot encode the ground-truth labels.
    oneHotRows = []
    for gtLabel in gtLabels:
        oneHotRow = [0] * len(labels)
        oneHotRow[gtLabel] = 1
        oneHotRows.append(oneHotRow)
    oneHot = np.asarray(oneHotRows)

    fpr = {}
    tpr = {}
    thres = {}
    rocAuc = {}

    # Per-class ROC curve and area under the curve.
    for classIndex in range(numClasses):
        fpr[classIndex], tpr[classIndex], thres[classIndex] = roc_curve(
            oneHot[:, classIndex], scoresMatrix[:, classIndex])
        rocAuc[classIndex] = auc(fpr[classIndex], tpr[classIndex])

    # Micro-average: treat every (sample, class) pair as one binary decision.
    fpr["micro"], tpr["micro"], _ = roc_curve(oneHot.ravel(), scoresMatrix.ravel())
    rocAuc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate every class curve at the union of all
    # false positive rates, then average the curves.
    allFpr = np.unique(np.concatenate([fpr[classIndex] for classIndex in range(numClasses)]))
    meanTpr = np.zeros_like(allFpr)
    for classIndex in range(numClasses):
        meanTpr += interp(allFpr, fpr[classIndex], tpr[classIndex])
    meanTpr /= numClasses

    fpr["macro"] = allFpr
    tpr["macro"] = meanTpr
    rocAuc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Plot the two averaged curves first, then one curve per class.
    lineWidth = 2
    microLabel = 'micro-average ROC curve (area = {0:0.2f})'.format(rocAuc["micro"])
    plt.plot(fpr["micro"], tpr["micro"], label=microLabel,
             color='deeppink', linestyle=':', linewidth=4)

    macroLabel = 'macro-average ROC curve (area = {0:0.2f})'.format(rocAuc["macro"])
    plt.plot(fpr["macro"], tpr["macro"], label=macroLabel,
             color='navy', linestyle=':', linewidth=4)

    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for classIndex, color in zip(range(numClasses), colors):
        classLabel = 'ROC curve of class {0} (area = {1:0.2f})'.format(
            labels[classIndex], rocAuc[classIndex])
        plt.plot(fpr[classIndex], tpr[classIndex], color=color, lw=lineWidth,
                 label=classLabel)

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], 'k--', lw=lineWidth)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curve')
    plt.legend(loc="lower right")
    return (fpr, tpr, thres)



#################################################
# Math
#################################################
def intRound(item):
    """Convert item to float, round it, and return the result as an int."""
    asFloat = float(item)
    return int(round(asFloat))

def softmax(vec):
    """Return the softmax of a 1D vector as a numpy array.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps np.exp from overflowing, and replaces
    the former overflow work-around, which returned values proportional to
    the raw inputs rather than a softmax and required a numpy array input.

    Args:
        vec: 1D sequence or numpy array of numbers (must not be empty).

    Returns:
        Numpy array of the same length whose entries sum to 1.
    """
    vec = np.asarray(vec, dtype=float)
    expVec = np.exp(vec - np.max(vec))
    return expVec / np.sum(expVec)

def softmax2D(w):
    """Return the row-wise softmax of a 2D array.

    Each row's maximum is subtracted before exponentiating, which does not
    change the result mathematically but prevents np.exp from overflowing
    (large inputs previously produced inf/inf = nan).

    Args:
        w: 2D array-like; softmax is taken along axis 1.

    Returns:
        Numpy array of the same shape whose rows each sum to 1.
    """
    w = np.asarray(w, dtype=float)
    e = np.exp(w - np.max(w, axis=1)[:, np.newaxis])
    # e is always a float array, so plain '/' is the true division that
    # old_div performed here.
    dist = e / np.sum(e, axis=1)[:, np.newaxis]
    return dist



#################################################
# other
#################################################
def isTuple(var):
    """Return True if var is a tuple (or an instance of a tuple subclass)."""
    isTupleInstance = isinstance(var, tuple)
    return isTupleInstance