"""
Copyright 2019, ETH Zurich

This file is part of L3C-PyTorch.

L3C-PyTorch is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
any later version.

L3C-PyTorch is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with L3C-PyTorch.  If not, see <https://www.gnu.org/licenses/>.
"""
import argparse
import multiprocessing
import os
import random
import shutil
import time
import warnings
from os.path import join

import PIL
import numpy as np
import skimage.color
from PIL import Image

from helpers.paths import IMG_EXTENSIONS

# TO SPEED THINGS UP: run on a CPU cluster! We use task_array for this.
# task_array is not released; it's what we use to batch-process on our servers. Feel free to replace it with whatever
# you use. Make sure to set NUM_TASKS (number of concurrent tasks) and to set job_enumerate to a function that takes
# an iterable and yields (index, element) pairs for only those elements to be processed by the current task (see the
# sketch below).
try:
    from task_array import NUM_TASKS, job_enumerate
except ImportError:
    NUM_TASKS = 1
    job_enumerate = enumerate
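
# A minimal sketch of a replacement for task_array, assuming a SLURM-style job array (the environment variable
# names below are SLURM's; adapt them to your scheduler):
#
#   NUM_TASKS = int(os.environ.get('SLURM_ARRAY_TASK_COUNT', 1))
#   _TASK_ID = int(os.environ.get('SLURM_ARRAY_TASK_ID', 0))
#
#   def job_enumerate(iterable):
#       """Yield (i, elem) for the subset of `iterable` assigned to the current task."""
#       for i, elem in enumerate(iterable):
#           if i % NUM_TASKS == _TASK_ID:
#               yield i, elem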

warnings.filterwarnings("ignore")  # Silence PIL warnings, e.g., about large or slightly corrupted images.


random.seed(123)  # Fix the seed so the random scales are reproducible.


# Number of worker processes per task; can be overridden with the NUM_PROCESS environment variable.
_NUM_PROCESSES = int(os.environ.get('NUM_PROCESS', 16))
_DEFAULT_MAX_SCALE = 0.8


def get_fn(p_):
    """Return the filename of `p_`, without directory or extension."""
    return os.path.splitext(os.path.basename(p_))[0]


def iter_images(root_dir, num_folder_levels=0):
    """Yield paths to all images in `root_dir`.

    If `num_folder_levels` > 0, recurse that many levels into subdirectories instead of looking at files directly
    in `root_dir`.
    """
    fns = sorted(os.listdir(root_dir))
    for fn in fns:
        if num_folder_levels > 0:
            dir_p = os.path.join(root_dir, fn)
            if os.path.isdir(dir_p):
                print('Recursing into', fn)
                yield from iter_images(dir_p, num_folder_levels - 1)
            continue
        _, ext = os.path.splitext(fn)
        if ext.lower() in IMG_EXTENSIONS:
            yield os.path.join(root_dir, fn)


class Helper(object):
    """Resize or discard images into `out_dir_clean` / `out_dir_discard`, skipping already-processed images."""

    def __init__(self, out_dir_clean, out_dir_discard, min_res: int):
        print(f'Creating {out_dir_clean}, {out_dir_discard}...')
        os.makedirs(out_dir_clean, exist_ok=True)
        os.makedirs(out_dir_discard, exist_ok=True)
        self.out_dir_clean = out_dir_clean
        self.out_dir_discard = out_dir_discard

        print('Getting images already processed...', end=" ", flush=True)
        self.images_cleaned = set(map(get_fn, os.listdir(out_dir_clean)))
        self.images_discarded = set(map(get_fn, os.listdir(out_dir_discard)))
        print(f'-> Found {len(self.images_cleaned) + len(self.images_discarded)} images.')

        self.min_res = min_res

    def process_all_in(self, input_dir):
        images_dl = iter_images(input_dir)  # generator of paths

        # Files this task should process.
        files_of_job = [p for _, p in job_enumerate(images_dl)]
        # Files that were already processed (e.g., by an earlier run of this job).
        processed_already = self.images_cleaned | self.images_discarded
        # Remaining files to process.
        files_of_job = [p for p in files_of_job if get_fn(p) not in processed_already]

        N = len(files_of_job)
        if N == 0:
            print('Everything processed / nothing to process.')
            return

        # When running as multiple cluster tasks, use only 2 processes per task; otherwise, use _NUM_PROCESSES.
        num_process = 2 if NUM_TASKS > 1 else _NUM_PROCESSES
        print(f'Processing {N} images using {num_process} processes in {NUM_TASKS} tasks...')

        start = time.time()
        with multiprocessing.Pool(processes=num_process) as pool:
            for i, _ in enumerate(pool.imap_unordered(self.process, files_of_job)):
                if i > 0 and i % 100 == 0:
                    time_per_img = (time.time() - start) / (i + 1)
                    time_remaining = time_per_img * (N - i)
                    print(f'\r{time_per_img:.2e} s/img | '
                          f'{i / N * 100:.1f}% | '
                          f'{time_remaining / 60:.1f} min remaining', end='', flush=True)

    def process(self, p_in):
        """Process the image at `p_in`; return 1 if it was kept (clean), 0 if it was discarded."""
        fn, ext = os.path.splitext(os.path.basename(p_in))
        if fn in self.images_cleaned:
            return 1
        if fn in self.images_discarded:
            return 0
        try:
            im = Image.open(p_in)
        except OSError as e:
            print(f'\n*** Error while opening {p_in}: {e}')
            return 0
        im_out = random_resize_or_discard(im, self.min_res)
        if im_out is not None:
            p_out = join(self.out_dir_clean, fn + '.png')  # Make sure to use .png!
            im_out.save(p_out)
            return 1
        else:
            p_out = join(self.out_dir_discard, os.path.basename(p_in))
            shutil.copy(p_in, p_out)
            return 0


def random_resize_or_discard(im, min_res: int):
    """Randomly resize image with `random_resize` and check if it should be discarded."""
    im_resized = random_resize(im, min_res)
    if im_resized is None:
        return None
    if should_discard(im_resized):
        return None
    return im_resized


def random_resize(im, min_res: int, max_scale=_DEFAULT_MAX_SCALE):
    """Randomly downscale `im` so that its shorter side is at least `min_res` pixels.

    Only scale factors <= `max_scale` are allowed; return None if no such factor exists or if resizing fails.
    """
    W, H = im.size
    D = min(W, H)  # Length of the shorter side.
    scale_min = min_res / D
    # No valid scale exists: keeping the shorter side >= min_res would require scaling by more than max_scale,
    # i.e., the image is too small to be meaningfully downscaled. Reject it.
    if scale_min > max_scale:
        return None

    # Get a random scale for new size.
    scale = random.uniform(scale_min, max_scale)
    new_size = round(W * scale), round(H * scale)
    try:
        # Use LANCZOS resampling for high-quality downscaling.
        return im.resize(new_size, resample=PIL.Image.LANCZOS)
    except OSError as e:  # Happens for corrupted images
        print('*** Caught im.resize error', e)
        return None


def should_discard(im):
    """Return True iff `im` is not RGB, or its mean saturation or value is too high."""
    # Modes found in train_0:
    # Counter({'RGB': 152326, 'L': 4149, 'CMYK': 66})
    if im.mode != 'RGB':
        return True
    im_rgb = np.array(im)
    im_hsv = skimage.color.rgb2hsv(im_rgb)
    mean_hsv = np.mean(im_hsv, axis=(0, 1))
    _, s, v = mean_hsv
    # Discard heavily saturated or very bright (e.g., mostly white) images.
    return s > 0.9 or v > 0.8


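# Example invocation (the script filename is assumed here; adapt it to your checkout):
#
#   python import_train_images.py /data/openimages train_0 train_1 \
#       --out_dir_clean /data/clean_train --out_dir_discard /data/discard_train --resolution 512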
def main():
    p = argparse.ArgumentParser()
    p.add_argument('base_dir',
                   help='Directory of images, or directory of DIRS.')
    p.add_argument('dirs', nargs='*',
                   help='If given, must be subdirectories in BASE_DIR. Will be processed. '
                        'If not given, assume BASE_DIR is already a directory of images.')
    p.add_argument('--out_dir_clean', required=True)
    p.add_argument('--out_dir_discard', required=True)
    p.add_argument('--resolution', type=int, default=512,
                   help='Randomly rescale each image such that its shorter '
                        'side is at least RESOLUTION pixels long.')

    flags = p.parse_args()

    # If DIRS is not given, assume `base_dir` itself is a directory of images.
    if not flags.dirs:
        flags.base_dir = flags.base_dir.rstrip(os.sep)  # Guard against a trailing slash.
        flags.dirs = [os.path.basename(flags.base_dir)]
        flags.base_dir = os.path.dirname(flags.base_dir)

    h = Helper(flags.out_dir_clean, flags.out_dir_discard, flags.resolution)
    for i, d in enumerate(flags.dirs):
        print(f'*** {d}: {i + 1}/{len(flags.dirs)}')
        h.process_all_in(join(flags.base_dir, d))

    print('\n\nDONE')  # For cluster logs.


if __name__ == '__main__':
    main()