import numpy as np
import random
import os
from glob import glob
from skimage import io
from skimage.filters.rank import entropy
from skimage.morphology import disk
import progressbar
from sklearn.feature_extraction.image import extract_patches_2d

progress = progressbar.ProgressBar(widgets=[progressbar.Bar('*', '[', ']'), progressbar.Percentage(), ' '])
np.random.seed(5)


class PatchLibrary(object):
    def __init__(self, patch_size, train_data, num_samples):
        '''
        Class for sampling patches and subpatches from training data for use as input to segmentation models.
        INPUT   (1) tuple 'patch_size': size (in pixels) of patches to extract. Use (33,33) for the sequential model
                (2) list 'train_data': list of filepaths to all training data saved as pngs. Images should have shape (5*240,240)
                (3) int 'num_samples': the total number of patches to collect from the training data
        '''
        self.patch_size = patch_size
        self.num_samples = num_samples
        self.train_data = train_data
        self.h = self.patch_size[0]
        self.w = self.patch_size[1]

    def find_patches(self, class_num, num_patches):
        '''
        Helper function for sampling patches with evenly distributed classes.
        INPUT:  (1) int 'class_num': class to sample from, one of {0, 1, 2, 3, 4}
                (2) int 'num_patches': number of patches of this class to collect
        OUTPUT: (1) np.array of 'num_patches' randomly selected patches whose center pixel belongs to class 'class_num'
                (2) np.array of the corresponding labels
        '''
        h,w = self.patch_size[0], self.patch_size[1]
        patches, labels = [], np.full(num_patches, class_num, 'float')
        print 'Finding patches of class {}...'.format(class_num)

        ct = 0
        while ct < num_patches:
            im_path = random.choice(self.train_data)
            fn = os.path.basename(im_path)
            label = io.imread('Labels/' + fn[:-4] + 'L.png')

            # resample if class_num is (nearly) absent from the selected slice
            if len(np.argwhere(label == class_num)) < 10:
                continue

            # select centerpix (p) and patch (p_ix)
            img = io.imread(im_path).reshape(5, 240, 240)[:-1].astype('float')  # keep the 4 modality channels
            p = random.choice(np.argwhere(label == class_num))
            p_ix = (p[0]-(h/2), p[0]+((h+1)/2), p[1]-(w/2), p[1]+((w+1)/2))
            patch = np.array([i[p_ix[0]:p_ix[1], p_ix[2]:p_ix[3]] for i in img])

            # resample if patch is clipped at the image edge or contains too much background
            if patch.shape != (4, h, w) or len(np.argwhere(patch == 0)) > (h * w):
                continue

            patches.append(patch)
            ct += 1
        return np.array(patches), labels
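
    # Usage sketch (the glob pattern and counts are illustrative assumptions):
    #   lib = PatchLibrary((33, 33), glob('train_data/*.png'), 1000)
    #   X, y = lib.find_patches(class_num=4, num_patches=200)
    # X then has shape (200, 4, 33, 33); y holds 200 labels, all equal to 4.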

    def center_n(self, n, patches):
        '''
        Takes a list of patches and returns the center nxn subpatch of each. Use as input for cascaded architectures.
        INPUT   (1) int 'n': side length of the square center subpatch to take
                (2) list 'patches': list of patches to take subpatches of
        OUTPUT: list of center nxn patches
        '''
        # bounds of the central n x n region (identical for every patch)
        lo_h, hi_h = (self.h / 2) - (n / 2), (self.h / 2) + ((n + 1) / 2)
        lo_w, hi_w = (self.w / 2) - (n / 2), (self.w / 2) + ((n + 1) / 2)
        sub_patches = []
        for mode in patches:
            subs = np.array([patch[lo_h:hi_h, lo_w:hi_w] for patch in mode])
            sub_patches.append(subs)
        return np.array(sub_patches)
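
    # Usage sketch (shapes assume the defaults above): for X of shape
    # (n, 4, 33, 33) from find_patches, center_n(5, X) slices the central
    # 5x5 out of every patch, returning shape (n, 4, 5, 5) as input for
    # the local pathway of a cascaded model.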

    def slice_to_patches(self, filename):
        '''
        Converts an image to a list of patches with a stride length of 1. Use as input for image prediction.
        INPUT: str 'filename': path to the image to be converted to patches
        OUTPUT: list of patches of the input image
        '''
        slices = io.imread(filename).astype('float').reshape(5, 240, 240)[:-1]  # keep the 4 modality channels
        plist = []
        for img in slices:
            # normalize each modality to [0, 1] before patching
            if np.max(img) != 0:
                img /= np.max(img)
            p = extract_patches_2d(img, (self.h, self.w))
            plist.append(p)
        return np.array(zip(plist[0], plist[1], plist[2], plist[3]))
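
    # Shape note: with patch_size (33, 33) and a 240x240 slice,
    # extract_patches_2d yields (240-33+1)**2 = 43264 patches per modality,
    # so the zipped output above has shape (43264, 4, 33, 33).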

    def patches_by_entropy(self, num_patches):
        '''
        Finds high-entropy patches based on the label map, which lets the net learn borders more effectively.
        INPUT: int 'num_patches': number of patches to collect; use a fraction of num_samples when combining with randomly sampled patches
        OUTPUT: list of patches (num_patches, 4, h, w) selected from the highest-entropy regions
        '''
        patches, labels = [], []
        ct = 0
        while ct < num_patches:
            im_path = random.choice(self.train_data)
            fn = os.path.basename(im_path)
            label = io.imread('Labels/' + fn[:-4] + 'L.png')

            # pick again if slice is only background
            if len(np.unique(label)) == 1:
                continue

            img = io.imread(im_path).reshape(5, 240, 240)[:-1].astype('float')
            l_ent = entropy(label, disk(self.h))
            top_ent = np.percentile(l_ent, 90)

            # restart if the 90th entropy percentile is 0
            if top_ent == 0:
                continue

            highest = np.argwhere(l_ent >= top_ent)
            p_s = random.sample(highest, 3)
            for p in p_s:
                p_ix = (p[0]-(self.h/2), p[0]+((self.h+1)/2), p[1]-(self.w/2), p[1]+((self.w+1)/2))
                patch = np.array([i[p_ix[0]:p_ix[1], p_ix[2]:p_ix[3]] for i in img])
                # exclude any patches that are too small
                if np.shape(patch) != (4, self.h, self.w):
                    continue
                patches.append(patch)
                labels.append(label[p[0],p[1]])
            ct += 1
        return np.array(patches[:num_patches]), np.array(labels[:num_patches])
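
    # Usage sketch (the half/half split is an illustrative assumption): mix
    # entropy-based and random patches so class borders are over-represented,
    #   hi_X, hi_y = lib.patches_by_entropy(lib.num_samples / 2)
    # then concatenate with an equal number of randomly sampled patches.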

    def make_training_patches(self, entropy=False, balanced_classes=True, classes=[0,1,2,3,4]):
        '''
        Creates X and y for training a CNN
        INPUT   (1) bool 'entropy': if True, half of the patches are chosen from the highest-entropy areas. Defaults to False. (Currently unused: only balanced random sampling is implemented below.)
                (2) bool 'balanced_classes': if True, will produce an equal number of each class from the randomly chosen samples
                (3) list 'classes': list of classes to sample from. Only change the default if entropy is False and balanced_classes is True
        OUTPUT  (1) X: patches (num_samples, 4, h, w)
                (2) y: labels (num_samples,)
        '''
        if balanced_classes:
            per_class = self.num_samples / len(classes)
            patches, labels = [], []
            progress.currval = 0
            for i in progress(xrange(len(classes))):
                p, l = self.find_patches(classes[i], per_class)
                # set 0 <= pix intensity <= 1
                for img_ix in xrange(len(p)):
                    for ch in xrange(len(p[img_ix])):  # one channel per modality
                        if np.max(p[img_ix][ch]) != 0:
                            p[img_ix][ch] /= np.max(p[img_ix][ch])
                patches.append(p)
                labels.append(l)
            # per_class * len(classes) may fall short of num_samples when not evenly divisible
            return np.array(patches).reshape(-1, 4, self.h, self.w), np.array(labels).reshape(-1)

        else:
            print "Use balanced classes, random won't work."



if __name__ == '__main__':
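    # Minimal usage sketch (not part of the original pipeline); the glob
    # pattern, the Labels/ directory layout, and the counts are assumptions
    # to adjust to the local file layout.
    train_files = glob('train_data/*.png')
    if train_files:
        library = PatchLibrary(patch_size=(33, 33), train_data=train_files, num_samples=100)
        X, y = library.make_training_patches()
        print 'X shape: {}, y shape: {}'.format(X.shape, y.shape)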