import os
import errno
import numpy as np
import scipy
import scipy.misc

def mkdir_p(path):
    """Create ``path`` and any missing parent directories, like ``mkdir -p``.

    An already-existing directory is accepted silently; any other OSError
    (including ``path`` existing as a non-directory) is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as err:  # Python >2.5
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise

def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    """Load the image at ``image_path`` and preprocess it with ``transform``.

    ``image_size`` is forwarded to ``transform`` as ``npx``; ``is_crop`` and
    ``resize_w`` are forwarded unchanged, ``is_grayscale`` goes to ``imread``.
    """
    raw_pixels = imread(image_path, is_grayscale)
    return transform(raw_pixels, image_size, is_crop, resize_w)

def transform(image, npx=64, is_crop=False, resize_w=64):
    """Resize (optionally via ``center_crop``) and scale pixels to [-1, 1].

    Args:
        image: input image array.
        npx: forwarded to ``center_crop`` as ``crop_h`` when ``is_crop`` is set.
        is_crop: use ``center_crop``; otherwise resize the whole image.
        resize_w: side length of the square output.

    Returns:
        Float array equal to ``pixels / 127.5 - 1``.
    """
    if not is_crop:
        resized = scipy.misc.imresize(image, [resize_w, resize_w])
    else:
        resized = center_crop(image, npx, resize_w=resize_w)
    # Map the [0, 255] pixel range onto [-1, 1].
    return np.array(resized) / 127.5 - 1

def center_crop(x, crop_h, crop_w=None, resize_w=64):
    """Randomly mirror an image, cut a fixed 178x178 window and resize it.

    Despite its name this does not crop around the image centre: the window
    is hard-coded to rows 20..197 and columns 0..177, which presumably matches
    CelebA's aligned 218x178 (rows x cols) images -- TODO confirm for other
    inputs. ``crop_h`` and ``crop_w`` are accepted for interface compatibility
    but do not influence the result.

    Args:
        x: image array indexed as ``x[row, col, ...]``.
        crop_h: unused (see above).
        crop_w: unused (see above).
        resize_w: side length of the square output.

    Returns:
        The window resized by ``scipy.misc.imresize`` to
        ``resize_w`` x ``resize_w``.
    """
    # Mirror left-right with probability 0.5. The size-1 draw is kept so the
    # global NumPy RNG stream is consumed exactly as before.
    rate = np.random.uniform(0, 1, size=1)
    if rate < 0.5:
        x = np.fliplr(x)

    # 178 rows (20 .. 218-20) by 178 columns, then resize to resize_w square.
    return scipy.misc.imresize(x[20:218 - 20, 0:178], [resize_w, resize_w])

def save_images(images, size, image_path):
    """Write a batch of [-1, 1] images to ``image_path`` as a ``size`` grid.

    Pixels are mapped back with ``inverse_transform`` before ``imsave`` tiles
    and writes them; ``imsave``'s return value is passed through.
    """
    pixel_batch = inverse_transform(images)
    return imsave(pixel_batch, size, image_path)

def imread(path, is_grayscale=False):
    """Read an image file into a float64 NumPy array.

    Args:
        path: image file path.
        is_grayscale: if true, read with ``flatten=True`` so the result is a
            single-channel array.

    Returns:
        The image converted to ``float64``.
    """
    # ``np.float`` was merely an alias for the builtin ``float`` and was
    # removed in NumPy 1.24; ``np.float64`` is the identical dtype.
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True).astype(np.float64)
    return scipy.misc.imread(path).astype(np.float64)

def imsave(images, size, path):
    """Tile ``images`` into a ``size`` grid with ``merge`` and write it to ``path``.

    Returns whatever ``scipy.misc.imsave`` returns.
    """
    grid = merge(images, size)
    return scipy.misc.imsave(path, grid)

def merge(images, size):
    """Tile a batch of images into a single grid image.

    Image ``k`` is placed at grid row ``k // size[1]``, column ``k % size[1]``.

    Args:
        images: array of shape (N, h, w) or (N, h, w, c).
        size: (rows, cols) of the grid; must provide at least N cells.

    Returns:
        For 4-D input, an array of shape (rows*h, cols*w, 3); a single-channel
        batch is broadcast across the 3 channels. For 3-D input, an array of
        shape (rows*h, cols*w). Unused cells are zero. The dtype matches
        ``images``.

    Raises:
        ValueError: if the grid has fewer cells than there are images.
    """
    images = np.asarray(images)
    h, w = images.shape[1], images.shape[2]
    rows, cols = size[0], size[1]
    if len(images) > rows * cols:
        raise ValueError("merge: %d images do not fit in a %dx%d grid"
                         % (len(images), rows, cols))

    # Grayscale batches get a 2-D canvas instead of failing to broadcast.
    if images.ndim == 3:
        canvas_shape = (h * rows, w * cols)
    else:
        canvas_shape = (h * rows, w * cols, 3)
    # Keep the input dtype. A float canvas holding uint8 pixel values would be
    # min/max-rescaled by legacy scipy.misc.imsave instead of written as-is.
    img = np.zeros(canvas_shape, dtype=images.dtype)
    for idx, image in enumerate(images):
        i = idx % cols
        j = idx // cols
        img[j * h:(j + 1) * h, i * w:(i + 1) * w, ...] = image

    return img

def inverse_transform(image):
    """Map values from [-1, 1] back to uint8 pixels in [0, 255].

    Values outside [-1, 1] are clipped first. Casting out-of-range floats
    straight to uint8 wraps around (e.g. 256 -> 0) instead of saturating.
    The cast truncates, so 0.0 maps to 127.
    """
    return np.clip((image + 1) * 127.5, 0, 255).astype(np.uint8)

def read_image_list(category):
    """Return the image files directly inside ``category`` in random order.

    An entry is kept when its name contains ``'jpg'`` or ``'png'``. Paths are
    built as ``category + "/" + name`` and returned as a NumPy array that is
    shuffled with the global NumPy RNG.
    """
    print("list file")
    # Each pattern must be tested separately: ``'jpg' or 'png' in name`` is
    # always true because the non-empty literal 'jpg' is truthy.
    filenames = [category + "/" + name
                 for name in sorted(os.listdir(category))
                 if 'jpg' in name or 'png' in name]
    print("list file ending!")

    perm = np.arange(len(filenames))
    np.random.shuffle(perm)
    filenames = np.array(filenames)
    return filenames[perm]

class CelebA(object):
    """Two-domain CelebA image-path dataset split on one binary attribute.

    On construction, the training path lists for both domains are loaded via
    ``read_image_list_file`` with ``is_test=False``.

    Attributes set in ``__init__``:
        dataname: the constant "CelebA".
        dims: pixels per channel (``image_size * image_size``).
        shape: ``[image_size, image_size, 3]``.
        image_size: side length images are resized to.
        channel: number of colour channels (3).
        images_path: image-directory prefix passed to ``read_image_list_file``.
        attri_id: attribute column used to split the two domains.
        dom_1_train_data_list, dom_2_train_data_list: training paths per domain.
        train_len: number of paths in the first domain.
    """

    def __init__(self, images_path, image_size, attri_id):

        self.dataname = "CelebA"
        self.dims = image_size*image_size
        self.shape = [image_size, image_size, 3]
        self.image_size = image_size
        self.channel = 3
        self.images_path = images_path
        self.attri_id = attri_id
        self.dom_1_train_data_list, self.dom_2_train_data_list = self.load_celebA()
        self.train_len = len(self.dom_1_train_data_list)

    def load_celebA(self):
        """Return the (domain-1, domain-2) training path lists."""
        return read_image_list_file(self.images_path, is_test=False, attri_id=self.attri_id)

    def load_test_celebA(self):
        """Return the (domain-1, domain-2) path lists read with ``is_test=True``."""
        return read_image_list_file(self.images_path, is_test=True, attri_id=self.attri_id)

    def getShapeForData(self, filenames):
        """Load and preprocess every path in ``filenames`` into one array.

        Each file goes through ``get_image`` with cropping enabled, 128 passed
        as ``image_size``, and resizing to ``self.image_size``.
        """
        array = [get_image(batch_file, 128, is_crop=True, resize_w=self.image_size,
                           is_grayscale=False) for batch_file in filenames]
        return np.array(array)

    def getTestNextBatch(self, batch_num=0, batch_size=64):
        """Return the ``batch_num``-th ``(paths, labels)`` test batch.

        The test set is cycled through in chunks of ``batch_size``. Whenever a
        new cycle starts, paths and labels are shuffled together with a single
        permutation. If the set is smaller than ``batch_size``, the whole set is
        returned.

        NOTE(review): ``self.test_data_list`` and ``self.test_lab_list`` are
        not assigned anywhere in this class; the caller must set them before
        the first call.
        """
        # Floor division keeps the slice bounds integral. True division makes
        # them floats, which cannot index a sequence on Python 3. Clamping to
        # at least one batch avoids a modulo by zero on a short test set.
        ro_num = max(len(self.test_data_list) // batch_size, 1)
        step = batch_num % ro_num
        if step == 0:
            length = len(self.test_data_list)
            perm = np.arange(length)
            np.random.shuffle(perm)
            # The same permutation keeps each path paired with its label.
            self.test_data_list = np.array(self.test_data_list)[perm]
            self.test_lab_list = np.array(self.test_lab_list)[perm]

        start = step * batch_size
        end = start + batch_size
        return self.test_data_list[start:end], self.test_lab_list[start:end]

def read_image_list_file(category, is_test, attri_id, start_num=5001, end_num=0):
    """Split CelebA image paths into two domains by one binary attribute.

    Reads ``category + "../list_attr_celeba.txt"`` line by line. Each data
    line is whitespace-separated: the image file name followed by one token
    per attribute, and ``attri_id`` selects the zero-based attribute column.
    Images whose token is ``'1'`` go to the first domain; all others go to
    the second.

    Args:
        category: image directory prefix. It is joined to file names and to
            ``"../list_attr_celeba.txt"`` by plain string concatenation, so it
            should end with a path separator.
        is_test: if true, stop reading at line index ``end_num``.
        attri_id: zero-based attribute column.
        start_num: lines with an index below this are skipped. The default
            5001 presumably covers the two header lines plus images reserved
            for testing -- TODO confirm.
        end_num: line index at which reading stops when ``is_test`` is true.

    Returns:
        ``(dom_1_list_image, dom_2_list_image)``, both truncated to the
        length of the shorter one so the two domains stay balanced.

    NOTE(review): with the default ``end_num=0`` (not above ``start_num``),
    ``is_test=True`` stops at the first unskipped line and returns two empty
    lists. Pass an explicit ``start_num``/``end_num`` range to read a test split.
    """
    dom_1_list_image = []
    dom_2_list_image = []

    # ``with`` closes the file even if a malformed line raises mid-loop.
    with open(category + "../" + "list_attr_celeba.txt") as lines:
        for li_num, line in enumerate(lines):
            if li_num < start_num:
                continue
            if is_test and li_num >= end_num:
                break

            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. at end of file

            file_name = fields[0]
            # fields[0] is the file name, so attribute k lives in fields[k + 1].
            # Whitespace splitting works for every column, including the first
            # and the last one.
            if fields[attri_id + 1] == '1':
                dom_1_list_image.append(category + file_name)
            else:
                dom_2_list_image.append(category + file_name)

    # Keep the dataset balanced by truncating the larger domain.
    n = min(len(dom_1_list_image), len(dom_2_list_image))
    return dom_1_list_image[:n], dom_2_list_image[:n]