from os import mkdir
from os.path import exists, join
from collections import defaultdict

import pylab
from sklearn.datasets import fetch_lfw_people
# IterativeImputer is still experimental in scikit-learn: this import must run
# before `from sklearn.impute import IterativeImputer`, otherwise that fails.
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer
from sklearn.linear_model import Ridge
import numpy as np

from fancyimpute import (
    SimpleFill,
    IterativeSVD,
    SoftImpute,
    BiScaler,
    KNN
)

from fancyimpute.common import masked_mae, masked_mse


def remove_pixels(
        full_images,
        missing_square_size=32,
        random_seed=0):
    """
    Return float32 copies of `full_images`, each with one randomly placed
    square of pixels set to NaN.

    Parameters
    ----------
    full_images : sequence of arrays
        Images of identical shape. The first two image axes are treated as
        rows and columns; any trailing channel axis is masked in full.
    missing_square_size : int
        Side length of the square patch to blank out.
    random_seed : int
        Passed to `np.random.seed`, so this reseeds NumPy's global RNG.

    Returns
    -------
    np.ndarray (float32)
        Shape ``(n_images,) + image_shape``; an empty input gives an empty array.
    """
    np.random.seed(random_seed)
    # Convert before masking: assigning NaN into an integer array (e.g. uint8
    # pixels) raises, and float32 is the output dtype anyway.
    images = np.array(full_images, dtype=np.float32)
    if len(images) == 0:
        return images
    n_rows, n_cols = images.shape[1:3]
    for image in images:  # each `image` is a view, so edits land in `images`
        start_row = np.random.randint(
            low=0,
            high=n_rows - missing_square_size + 1)
        start_col = np.random.randint(
            low=0,
            high=n_cols - missing_square_size + 1)
        image[
            start_row: start_row + missing_square_size,
            start_col: start_col + missing_square_size] = np.nan
    return images


def rescale_pixel_values(images, order="C"):
    """
    Rescale the range of values in images to be between [0, 1]

    Parameters
    ----------
    images : array-like
        Pixel values; converted to a float32 copy with memory layout `order`.
    order : {"C", "F"}
        Memory layout passed to `np.asarray`.

    Returns
    -------
    np.ndarray (float32)
        ``(images - min) / (max - min)``. A constant input has no range to
        stretch and is returned as all zeros (rather than 0/0 = NaN); an
        empty input is returned unchanged.
    """
    images = np.asarray(images, order=order).astype("float32")
    if images.size == 0:
        # min()/max() are undefined on an empty array.
        return images
    images -= images.min()
    value_range = images.max()
    if value_range > 0:
        images /= value_range
    return images


def color_balance(images):
    """
    Rescale each image's R/G/B channels so its colour proportions match the
    proportions of the whole collection.

    Parameters
    ----------
    images : np.ndarray
        Indexed as ``images[image, row, col, channel]`` with 3 channels;
        values are expected in [0, 255].

    Returns
    -------
    np.ndarray (float32)
        A balanced copy, clipped to [0, 255]. An image (or a single channel)
        whose sum is zero is left unscaled instead of becoming NaN/inf.
    """
    images = images.astype("float32")
    # Views into `images`: in-place scaling below writes through to it.
    channels = [images[:, :, :, c] for c in range(3)]
    combined = channels[0] + channels[1] + channels[2]
    total_color = combined.sum()
    if total_color != 0:
        overall_fractions = [channel.sum() / total_color for channel in channels]
        for i in range(images.shape[0]):
            image_total = combined[i].sum()
            if image_total == 0:
                # An all-black image has no colour proportions to correct.
                continue
            for channel, overall_fraction in zip(channels, overall_fractions):
                channel_sum = channel[i].sum()
                if channel_sum != 0:
                    channel[i] *= overall_fraction / (channel_sum / image_total)
    # Scaling can push any image past the valid range, so clip all of them.
    np.clip(images, 0, 255, out=images)
    return images


class ResultsTable(object):
    """
    Benchmark harness for image-completion solvers.

    On construction it stacks the images of `images_dict` (in sorted label
    order), colour-balances them when they are 3-channel, blanks out one
    random square per image with NaN, normalizes the masked data with
    `BiScaler`, and saves the original and incomplete sample images. Each
    `add_entry` call then runs a solver, records its masked MSE/MAE and saves
    sample completions.

    Parameters
    ----------
    images_dict : dict
        Maps a label (str) to an image array; all images must share a shape.
    percent_missing : float
        Fraction of each image's pixels to remove, as one square patch.
    saved_image_stride : int
        Every `saved_image_stride`-th image (in label order) is written to disk.
    dirname : str
        Directory receiving per-label PNG subdirectories and the HTML table.
    scale_rows, center_rows : bool
        Forwarded to `BiScaler`.
    """

    def __init__(
            self,
            images_dict,
            percent_missing=0.25,
            saved_image_stride=25,
            dirname="face_images",
            scale_rows=False,
            center_rows=False):
        self.images_dict = images_dict
        self.labels = list(sorted(images_dict.keys()))
        self.images_array = np.array(
            [images_dict[k] for k in self.labels]).astype("float32")
        self.image_shape = self.images_array[0].shape
        # Array shapes are (rows, columns[, channels]), i.e. (height, width).
        self.height, self.width = self.image_shape[:2]
        self.color = (len(self.image_shape) == 3) and (self.image_shape[2] == 3)
        if self.color:
            self.images_array = color_balance(self.images_array)
        self.n_pixels = self.width * self.height
        self.n_features = self.n_pixels * (3 if self.color else 1)
        self.n_images = len(self.images_array)
        print("[ResultsTable] # images = %d, color=%s # features = %d, shape = %s" % (
            self.n_images, self.color, self.n_features, self.image_shape))

        self.flattened_array_shape = (self.n_images, self.n_features)

        self.flattened_images = self.images_array.reshape(self.flattened_array_shape)

        n_missing_pixels = int(self.n_pixels * percent_missing)

        # One square patch whose area approximates n_missing_pixels.
        missing_square_size = int(np.sqrt(n_missing_pixels))
        print("[ResultsTable] n_missing_pixels = %d, missing_square_size = %d" % (
            n_missing_pixels, missing_square_size))
        self.incomplete_images = remove_pixels(
            self.images_array,
            missing_square_size=missing_square_size)
        print("[ResultsTable] Incomplete images shape = %s" % (
            self.incomplete_images.shape,))
        self.flattened_incomplete_images = self.incomplete_images.reshape(
            self.flattened_array_shape)
        self.missing_mask = np.isnan(self.flattened_incomplete_images)
        self.normalizer = BiScaler(
            scale_rows=scale_rows,
            center_rows=center_rows,
            min_value=self.images_array.min(),
            max_value=self.images_array.max())
        self.incomplete_normalized = self.normalizer.fit_transform(
            self.flattened_incomplete_images)

        self.saved_image_indices = list(
            range(0, self.n_images, saved_image_stride))
        self.saved_images = defaultdict(dict)
        self.dirname = dirname
        self.mse_dict = {}
        self.mae_dict = {}

        self.save_images(self.images_array, "original", flattened=False)
        self.save_images(self.incomplete_images, "incomplete", flattened=False)

    def ensure_dir(self, dirname):
        """Create `dirname` (a single level) if it does not already exist."""
        if not exists(dirname):
            print("Creating directory: %s" % dirname)
            mkdir(dirname)

    def save_images(self, images, base_filename, flattened=True):
        """
        Save the sampled images as ``<dirname>/<label>/<base_filename>.png``.

        `images` is indexed by image first; when `flattened` is True each row
        is reshaped back to `self.image_shape`. NaNs are drawn as 0, and
        images whose maximum is <= 1 are treated as [0, 1] data and scaled up.
        Saved paths are recorded in `self.saved_images[i][base_filename]`.
        """
        self.ensure_dir(self.dirname)
        for i in self.saved_image_indices:
            label = self.labels[i].lower().replace(" ", "_")
            image = images[i, :].copy()
            if flattened:
                image = image.reshape(self.image_shape)
            image[np.isnan(image)] = 0
            assert image.min() >= 0, "Image can't contain negative numbers"
            if image.max() <= 1:
                image *= 256
            image[image > 255] = 255
            extra_kwargs = {}
            if not self.color:
                # imshow ignores cmap for RGB data; single-channel images need
                # a grayscale map, with a fixed range so they aren't stretched.
                extra_kwargs.update(cmap="gray", vmin=0, vmax=255)
            subdir = join(self.dirname, label)
            self.ensure_dir(subdir)
            path = join(subdir, base_filename + ".png")
            # A fresh figure per image: reusing gcf()/gca() would stack every
            # earlier imshow on the same axes, growing memory and render time.
            figure, axes = pylab.subplots()
            try:
                axes.imshow(image.astype("uint8"), **extra_kwargs)
                axes.get_xaxis().set_visible(False)
                axes.get_yaxis().set_visible(False)
                figure.savefig(
                    path,
                    bbox_inches='tight')
            finally:
                pylab.close(figure)
            self.saved_images[i][base_filename] = path

    def add_entry(self, solver, name):
        """
        Complete the normalized incomplete data with `solver.fit_transform`,
        undo the normalization, record MSE/MAE over the masked entries under
        `name`, and save sample completions.
        """
        print("Running %s" % name)
        completed_normalized = solver.fit_transform(self.incomplete_normalized)
        completed = self.normalizer.inverse_transform(completed_normalized)

        mae = masked_mae(
            X_true=self.flattened_images,
            X_pred=completed,
            mask=self.missing_mask)
        mse = masked_mse(
            X_true=self.flattened_images,
            X_pred=completed,
            mask=self.missing_mask)
        print("==> %s: MSE=%0.4f MAE=%0.4f" % (name, mse, mae))
        self.mse_dict[name] = mse
        self.mae_dict[name] = mae
        self.save_images(completed, base_filename=name)

    def sorted_errors(self):
        """
        Generator for (rank, name, MSE, MAE) sorted by increasing MAE
        """
        ranked = sorted(self.mae_dict.items(), key=lambda item: item[1])
        for i, (name, mae) in enumerate(ranked):
            yield (i + 1, name, self.mse_dict[name], mae)

    def print_sorted_errors(self):
        """Print one ranked line per recorded solver, best MAE first."""
        for (rank, name, mse, mae) in self.sorted_errors():
            print("%d) %s: MSE=%0.4f MAE=%0.4f" % (
                rank,
                name,
                mse,
                mae))

    def save_html_table(self, filename="results_table.html"):
        """
        Write the ranked errors as an HTML table to ``<dirname>/<filename>``
        and return the markup. Solver names are HTML-escaped.
        """
        from html import escape

        parts = ["""
            <table>
            <tr>
                <th>Rank</th>
                <th>Name</th>
                <th>Mean Squared Error</th>
                <th>Mean Absolute Error</th>
            </tr>
        """]
        for (rank, name, mse, mae) in self.sorted_errors():
            parts.append("""
            <tr>
                <td>%d</td>
                <td>%s</td>
                <td>%0.4f</td>
                <td>%0.4f</td>
            </tr>
            """ % (rank, escape(name), mse, mae))
        parts.append("</table>")
        markup = "".join(parts)
        self.ensure_dir(self.dirname)
        path = join(self.dirname, filename)
        with open(path, "w", encoding="utf-8") as f:
            f.write(markup)
        return markup


def image_per_label(images, label_indices, label_names, max_size=2000):
    groups = defaultdict(list)
    for i, label_idx in enumerate(label_indices):
        label = label_names[label_idx].lower().strip().replace(" ", "_")
        groups[label].append(images[i])

    # as a pretty arbitrary heuristic, let's try taking the min variance
    # image for each person
    singe_images = {}
    for label, images in sorted(groups.items()):
        singe_images[label] = min(images, key=lambda image: image.std())
        if max_size and len(singe_images) >= max_size:
            break
    return singe_images


def get_lfw(max_size=None):
    """
    Fetch the colour LFW faces through scikit-learn and return one image per
    person, as selected by `image_per_label`.

    `max_size` caps the number of people returned; None means no cap.
    """
    lfw = fetch_lfw_people(color=True)
    images, targets, names = lfw.images, lfw.target, lfw.target_names
    return image_per_label(images, targets, names, max_size=max_size)

if __name__ == "__main__":
    # Benchmark several imputation methods on one LFW face per person.
    images_dict = get_lfw(max_size=2000)
    table = ResultsTable(
        images_dict=images_dict,
        scale_rows=False,
        center_rows=False)

    for negative_log_regularization_weight in [2, 3, 4]:
        regularization_weight = 10.0 ** -negative_log_regularization_weight
        # The weight is the ridge penalty of each per-feature regression;
        # previously it was computed but unused, so all three entries ran the
        # identical solver. A fixed random_state makes the random
        # nearest-feature selection reproducible, like the seeded pixel masks.
        table.add_entry(
            solver=IterativeImputer(
                estimator=Ridge(alpha=regularization_weight),
                n_nearest_features=80,
                max_iter=50,
                random_state=0,
            ),
            name="IterativeImputer_%d" % negative_log_regularization_weight)

    for fill_method in ["mean", "median"]:
        table.add_entry(
            solver=SimpleFill(fill_method=fill_method),
            name="SimpleFill_%s" % fill_method)

    for k in [1, 3, 7]:
        table.add_entry(
            solver=KNN(
                k=k,
                orientation="rows"),
            name="KNN_k%d" % (k,))

    for shrinkage_value in [25, 50, 100]:
        # SoftImpute without rank constraints
        table.add_entry(
            solver=SoftImpute(
                shrinkage_value=shrinkage_value),
            name="SoftImpute_lambda%d" % (shrinkage_value,))

    for rank in [10, 20, 40]:
        table.add_entry(
            solver=IterativeSVD(
                rank=rank,
                init_fill_method="zero"),
            name="IterativeSVD_rank%d" % (rank,))

    table.save_html_table()
    table.print_sorted_errors()