import os
import sys
import argparse
import csv
import numpy as np

from braniac.utils import (DataStatisticsContext, BodyDataStatisticsPass1,
                           BodyDataStatisticsPass2, save_statistics_context)
from braniac.format.human36m.body import BodyFileReader, get_labels

class DataSplitter(object):
    '''
    A helper class to split training data and generate CSV files listing the training items.
    '''
    def __init__(self, input_folder):
        '''
        Initialize DataSplitter object.

        Args:
            input_folder(str): path of the input folder that contains all the clips.
        '''
        self._input_folder = input_folder
        self._items = []
        self._stats_context = DataStatisticsContext()
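
    # Typical usage (illustrative sketch mirroring main() below; the path is a
    # placeholder, not part of the original script):
    #
    #   splitter = DataSplitter('/data/human36m/clips')
    #   items = splitter.load_data_paths()
    #   train, test = splitter.split_data(items)
    #   splitter.write_to_csv(train, 'train_map.csv')
    #   splitter.write_to_csv(test, 'test_map.csv')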

    def load_data_paths(self):
        '''
        Load the list of clip files from each subject sub-folder into a Python
        list, together with their corresponding label and subject id.
        '''
        labels = get_labels()
        sub_folders = os.listdir(self._input_folder)
        index = 0
        self._items.clear()
        for folder in sub_folders:
            folder_path = os.path.join(self._input_folder, folder)
            if not os.path.isdir(folder_path):
                continue
            
            # Subject folders are expected to be named like 'S1', 'S5', etc.;
            # the numeric part is the subject id.
            subject_id = int(folder[1:])
            files = os.listdir(folder_path)
            for item in files:
                item_path = os.path.join(folder_path, item)
                if not os.path.isfile(item_path):
                    continue
                
                # Get the filename without extension and drop anything after the
                # first space, so `Directions 1` becomes `Directions`.
                filename = os.path.splitext(item)[0].split()[0]
                if self._filter_data(item_path):
                    self._items.append([os.path.abspath(item_path), labels[filename.lower()], subject_id])

                if (index % 100) == 0:
                    print("Processed {} items.".format(index + 1))
                index += 1
        
        return self._items
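
    # Illustrative note (assumed Human3.6M-style layout; folder and file names
    # below are examples only, not taken from the original script): the input
    # folder is expected to look like
    #
    #   <input_folder>/S1/Directions 1.txt
    #   <input_folder>/S1/Walking 2.txt
    #   <input_folder>/S5/Eating.txt
    #
    # and each accepted clip is stored as [absolute_path, label_index, subject_id],
    # e.g. [os.path.abspath('S1/Walking 2.txt'), labels['walking'], 1].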

    def _filter_data(self, item_path):
        '''
        Return True if the item should be included, False otherwise. A clip is
        kept only if it contains at least 60 frames and every frame contains at
        least one tracked body.

        Args:
            item_path(str): path of the item.

        Todo: Refactor filter.
        '''
        frames = BodyFileReader(item_path)
        if len(frames) >= 60:
            for frame in frames:
                if len(frame) == 0:
                    return False
            return True
        return False

    def split_data(self, items):
        '''
        Split the data at random into a training set and a test set (80/20).

        Args:
            items: list of clips and their corresponding label if available.
        '''
        item_count = len(items)
        indices = np.arange(item_count)
        np.random.shuffle(indices)

        train_count = int(0.8 * item_count)

        train = [items[i] for i in indices[:train_count]]
        test  = [items[i] for i in indices[train_count:]]

        return train, test
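
    # A minimal sketch (assumption, not part of the original script): the split
    # above is not reproducible across runs because the shuffle is unseeded.
    # Seeding NumPy once before calling split_data would pin it down, e.g.
    #
    #   np.random.seed(42)
    #   train, test = splitter.split_data(items)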

    def write_to_csv(self, items, file_path):
        '''
        Write each item (file path, label, subject id) as a row in a CSV file.

        Args:
            items: list of paths and their corresponding label if provided.
            file_path(str): target file path.
        '''
        # Python 2 needs binary mode, while Python 3 needs newline='' to avoid
        # blank rows on Windows.
        if sys.version_info[0] < 3:
            with open(file_path, 'wb') as csv_file:
                writer = csv.writer(csv_file, delimiter=',')
                for item in items:
                    writer.writerow(item)
        else:
            with open(file_path, 'w', newline='') as csv_file:
                writer = csv.writer(csv_file, delimiter=',')
                for item in items:
                    writer.writerow(item)
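
    # Example of the resulting CSV rows (paths and indices are placeholders):
    # each row is "<absolute clip path>,<label index>,<subject id>", e.g.
    #
    #   /data/human36m/S1/Walking 1.txt,12,1
    #   /data/human36m/S5/Eating.txt,3,5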

    def compute_statistics(self):
        '''
        Compute statistics across the whole dataset in two passes.
        '''
        with BodyDataStatisticsPass1(self._stats_context) as stats:
            for item in self._items:
                frames = BodyFileReader(item[0])
                for frame in frames:
                    stats.add(frame[0].as_numpy())

        with BodyDataStatisticsPass2(self._stats_context) as stats:
            for item in self._items:
                frames = BodyFileReader(item[0])
                for frame in frames:
                    stats.add(frame[0].as_numpy())

        return self._stats_context
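
    # Note (assumption about braniac.utils, hedged): the two passes appear to
    # follow the usual two-pass normalization pattern, e.g. pass 1 accumulating
    # per-joint means and pass 2 the deviations needed for variance. Only the
    # first body in each frame (frame[0]) contributes to the statistics.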

def main(input_folder, output_folder):
    '''
    Main entry point: it collects all the clip files from the subject sub-folders
    together with their corresponding target label, splits the data into a
    training set and a test set, and writes the CSV map files and the dataset
    statistics to the output folder.

    Args:
        input_folder: input folder that contains all the data files.
        output_folder: where to store the results.
    '''
    data_splitter = DataSplitter(input_folder)
    items = data_splitter.load_data_paths()
    print("{} items loaded, start splitting.".format(len(items)))

    train, test = data_splitter.split_data(items)
    print("Train: {} and test: {}.".format(len(train), len(test)))

    context = data_splitter.compute_statistics()
    print("Complete computing statistics.")

    save_statistics_context(context, os.path.join(output_folder, 'data_statistics.h5'))

    if len(train) > 0:
        data_splitter.write_to_csv(train, os.path.join(output_folder, 'train_map.csv'))
    if len(test) > 0:
        data_splitter.write_to_csv(test, os.path.join(output_folder, 'test_map.csv'))

    print("Done.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i",
                        "--input_folder",
                        type = str,
                        help = "Input folder containing the raw data.",
                        required = True)

    parser.add_argument("-o",
                        "--output_folder",
                        type = str,
                        help = "Output folder for the generated training, validation and test text files.",
                        required = True)

    args = parser.parse_args()
    main(args.input_folder, args.output_folder)
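
# Example invocation (the script name is a placeholder for this file):
#
#   python split_data.py -i /data/human36m/clips -o /data/human36m/splits
#
# which writes train_map.csv, test_map.csv and data_statistics.h5 into the
# output folder.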