import threading

from scipy.misc.pilutil import imread, imsave

from params import args
import numpy as np
import os

from utils import ThreadsafeIter


def average_strategy(images):
    """Combine a stack of images by taking their pixel-wise mean.

    Args:
        images: array-like whose first axis indexes the images to combine.

    Returns:
        A float array with the first axis reduced away.
    """
    stacked = np.asanyarray(images)
    return stacked.mean(axis=0)


def hard_voting(images):
    """Combine a stack of 0..255 images by per-pixel majority vote.

    Each pixel is first binarised (value / 255, rounded). The votes are then
    averaged over the first axis and rounded again. ``np.round`` rounds half
    to even, so an exact tie yields 0. The result is scaled back to
    {0., 255.}.

    Args:
        images: numpy array whose first axis indexes the images to combine.

    Returns:
        A float array with the first axis reduced away.
    """
    n_models = images.shape[0]
    votes = np.round(images / 255.)
    majority = np.round(votes.sum(axis=0) / n_models)
    return majority * 255.

def ensemble_image(files, dirs, ensembling_dir, strategy):
    """Ensemble each file name drawn from ``files`` and save the result.

    For every file name, the same-named image is read as greyscale
    (``mode='L'``) from each directory in ``dirs`` in which it exists. The
    stack is combined with ``strategy`` and written to ``ensembling_dir``
    under the same name.

    Args:
        files: iterable of file names. ``ensemble`` passes one shared
            ``ThreadsafeIter`` to several threads running this function.
        dirs: directories holding the images to combine.
        ensembling_dir: directory the combined images are written to.
        strategy: ``'average'`` or ``'hard_voting'``.

    Raises:
        ValueError: if ``strategy`` is unknown. This is checked before any
            file name is pulled from ``files``.
    """
    # Validate once, up front, so a bad strategy doesn't consume a shared
    # file name (or one per worker thread) before failing.
    if strategy not in ('average', 'hard_voting'):
        raise ValueError('Unknown ensembling strategy')

    for file_name in files:
        images = []
        for model_dir in dirs:
            file_path = os.path.join(model_dir, file_name)
            # Directories lacking this file are skipped, so fewer images may
            # be combined for it.
            if os.path.exists(file_path):
                images.append(imread(file_path, mode='L'))
        images = np.array(images)

        if strategy == 'average':
            ensembled = average_strategy(images)
        else:
            ensembled = hard_voting(images)

        # Both strategies return floats. imsave byte-scales non-uint8 arrays to
        # their own min..max range, which would stretch averaged values and
        # save an all-255 mask as all zeros. Cast to uint8 so the pixel values
        # are written as computed.
        ensembled = np.clip(np.round(ensembled), 0, 255).astype(np.uint8)
        imsave(os.path.join(ensembling_dir, file_name), ensembled)


def ensemble(dirs, strategy, ensembling_dir, n_threads):
    """Ensemble every file listed in ``dirs[0]`` using ``n_threads`` threads.

    File names come from ``os.listdir(dirs[0])``. They are wrapped in a
    ``ThreadsafeIter`` shared by all workers; presumably this serialises
    ``next()`` across threads, so each name is processed once — TODO confirm
    in ``utils``. Each worker runs ``ensemble_image``.

    Args:
        dirs: directories holding the images to combine. The first one
            defines which file names are processed.
        strategy: ensembling strategy name forwarded to ``ensemble_image``.
        ensembling_dir: output directory. It is created if missing.
        n_threads: number of worker threads to start.

    Raises:
        Exception: the first exception raised by any worker. It is re-raised
            after all workers have finished.
    """
    files = ThreadsafeIter(os.listdir(dirs[0]))

    # Make sure the output directory exists; otherwise every save would fail.
    if ensembling_dir and not os.path.isdir(ensembling_dir):
        os.makedirs(ensembling_dir)

    errors = []

    def _worker():
        # An exception escaping a thread target is only printed and the thread
        # dies silently. Record it so the caller can see the run failed.
        try:
            ensemble_image(files, dirs, ensembling_dir, strategy)
        except Exception as exc:
            errors.append(exc)

    threads = [threading.Thread(target=_worker) for _ in range(n_threads)]

    for t in threads:
        t.start()

    for t in threads:
        t.join()

    if errors:
        raise errors[0]


if __name__ == '__main__':
    # Pull the ensembling configuration from the shared command-line args.
    n_threads = args.ensembling_cpu_threads
    ensembling_dir = args.ensembling_dir
    strategy = args.ensembling_strategy
    relative_dirs = args.dirs_to_ensemble
    folds_dir = args.folds_dir

    # Resolve every input directory against the folds directory, then abort on
    # the first one that is missing.
    dirs = [os.path.join(folds_dir, name) for name in relative_dirs]
    missing = next((path for path in dirs if not os.path.exists(path)), None)
    if missing is not None:
        raise ValueError(missing + " doesn't exist")

    ensemble(dirs, strategy, ensembling_dir, n_threads)