from __future__ import division

import sys
sys.path.append('../')
sys.path.append('../../')

import functools
import itertools
import logging
from multiprocessing import Pool
#from multiprocessing.pool import ThreadPool as Pool

import numpy as np
import matplotlib
matplotlib.use('Agg') # non-interactive backend; must be selected before pyplot is imported
import theano
import theano.tensor as T
from tqdm import tqdm

import augment
import dataset_3D
import fr3dnet
import normalize
import trainer
import util
from params import params as P
from parallel import ParallelBatchIterator

def load_data(tup):
    """Load normalized sub-images for a single (filename, coordinates, labels) tuple."""
    size = P.INPUT_SIZE

    images = dataset_3D.giveSubImage(tup[0], tup[1], size)
    labels = map(int, tup[2])
    data = images[:]

    data = normalize.normalize(np.array(data, dtype=np.float32))

    if P.ZERO_CENTER:
        data -= P.MEAN_PIXEL

    result = zip([tup[0]]*len(labels), np.array(data, dtype=np.float32), np.array(labels, dtype=np.int32))

    if P.AUGMENT and P.AUGMENTATION_PARAMS['flip']:
        augmentation_extra = []

        for filename, image, label in result:
            if label == 1:
                flipped_images = augment.get_all_flips_3d(image)
                np.random.shuffle(flipped_images)
                flipped_images = flipped_images[:1] # keep 1 random image out of the 7 possible flips
                n_new = len(flipped_images)

                augmentation_extra += zip([filename]*n_new, flipped_images, [label]*n_new)
            else: # For false candidates, apply one flip combination chosen at random
                flip_option = augment.OPTS[np.random.randint(8)]
                # NOTE: this assumes augment.flip_given_axes flips `image` in
                # place; if it returns a new array instead, the result is
                # discarded here and the candidate stays unflipped.
                augment.flip_given_axes(image, flip_option)

        result += augmentation_extra

    return result
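
# Illustrative call (filename, coordinates and labels are made up):
#   result = load_data(('scan_0.mhd', [(32, 40, 18), (10, 9, 77)], [1, 0]))
# 'result' is a list of (filename, image, label) triples in which each image
# is a normalized float32 cube of side P.INPUT_SIZE, with one extra flipped
# copy appended for every positive candidate when flip augmentation is on.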

def make_epoch(n, train_true, train_false, val_true, val_false):
    n = n[0] # ParallelBatchIterator hands us the epoch number wrapped in a batch of size 1
    train_false = list(train_false)
    val_false = list(val_false)
    np.random.shuffle(train_false)
    np.random.shuffle(val_false)

    n_train_true = len(train_true)
    n_val_true = len(val_true)

    train_epoch = train_true + train_false[:n_train_true*2] # 2x negatives, since every positive also gets one flipped copy
    val_epoch = val_true + val_false[:n_val_true*2]

    train_epoch = combine_tups(train_epoch)
    val_epoch = combine_tups(val_epoch)

    print "Epoch {0} n files {1}&{2}".format(n, len(train_epoch), len(val_epoch))
    pool = Pool(processes=12)
    train_epoch_data = list(itertools.chain.from_iterable(pool.imap_unordered(load_data, train_epoch)))
    print "Epoch {0} done loading train".format(n)

    val_epoch_data = list(itertools.chain.from_iterable(pool.imap_unordered(load_data, val_epoch)))
    logging.info("Epoch {0} done loading validation".format(n))
    pool.close()
    pool.join()

    np.random.shuffle(train_epoch_data)
    return train_epoch_data, val_epoch_data
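
# Each epoch pairs all true candidates with a fresh random sample of twice as
# many false ones; after load_data adds one flipped copy per positive, the
# classes come out roughly 1:1.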

def combine_tups(tup):
    """Group flat (filename, coordinate, label) tuples into one
    (filename, coordinates, labels) tuple per file."""
    d = {}
    for name, coord, label in tup:
        d.setdefault(name, []).append((coord, label))
    data = []
    for name, values in d.iteritems():
        c, l = zip(*values)
        data.append((name, c, l))
    return data
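
# Worked example:
#   combine_tups([('a.mhd', (1, 2, 3), 0), ('a.mhd', (4, 5, 6), 1)])
#   -> [('a.mhd', ((1, 2, 3), (4, 5, 6)), (0, 1))]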

class Fr3dNetTrainer(trainer.Trainer):
    def __init__(self):
        metric_names = ['Loss','L2','Accuracy']
        super(Fr3dNetTrainer, self).__init__(metric_names)

        # 5D input tensor, presumably laid out as (batch, channel, depth, height, width)
        tensor5 = T.TensorType(theano.config.floatX, (False,) * 5)
        input_var = tensor5('inputs')
        target_var = T.ivector('targets')

        logging.info("Defining network")
        net = fr3dnet.define_network(input_var)
        self.network = net
        train_fn, val_fn, l_r = fr3dnet.define_updates(net, input_var, target_var)

        self.train_fn = train_fn
        self.val_fn = val_fn
        self.l_r = l_r
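
        # Note: fr3dnet.define_updates is assumed to return compiled theano
        # functions that yield (loss, l2_loss, accuracy, predictions), which is
        # how do_batches below unpacks their results.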

    def do_batches(self, fn, batches, metrics):
        batches = list(batches)
        for i, batch in enumerate(tqdm(batches)):
            filenames, inputs, targets = zip(*batch)
            inputs = np.array(inputs, dtype=np.float32) # stack the batch into one contiguous 5D array
            targets = np.array(targets, dtype=np.int32)
            err, l2_loss, acc, predictions = fn(inputs, targets)

            metrics.append([err, l2_loss, acc])
            metrics.append_prediction(targets, predictions)

    def train(self, X_train, X_val):

        train_true = filter(lambda x: x[2]==1, X_train)
        train_false = filter(lambda x: x[2]==0, X_train)

        val_true = filter(lambda x: x[2]==1, X_val)
        val_false = filter(lambda x: x[2]==0, X_val)

        make_epoch_helper = functools.partial(make_epoch, train_true=train_true, train_false=train_false, val_true=val_true, val_false=val_false)

        logging.info("Starting training...")
        epoch_iterator = ParallelBatchIterator(make_epoch_helper, range(P.N_EPOCHS), ordered=False, batch_size=1, multiprocess=False, n_producers=1)
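
        # ParallelBatchIterator presumably builds the next epoch's data in a
        # background producer while the current epoch is being trained on,
        # overlapping data loading with computation.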

        for epoch_values in epoch_iterator:
            self.pre_epoch()
            train_epoch_data, val_epoch_data = epoch_values

            train_epoch_data = util.chunks(train_epoch_data, P.BATCH_SIZE_TRAIN)
            val_epoch_data = util.chunks(val_epoch_data, P.BATCH_SIZE_VALIDATION)

            self.do_batches(self.train_fn, train_epoch_data, self.train_metrics)
            self.do_batches(self.val_fn, val_epoch_data, self.val_metrics)

            self.post_epoch()
            logging.info("Setting learning rate to {}".format(P.LEARNING_RATE  * ((0.985)**self.epoch)))
            self.l_r.set_value(P.LEARNING_RATE  * ((0.985)**self.epoch))
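
# Hypothetical usage sketch (illustrative values; how X_train/X_val are
# produced is outside this module):
#   X_train and X_val are flat lists of (filename, coordinate, label) tuples,
#   with label 1 for true candidates and 0 for false ones, e.g.
#     X_train = [('scan_0.mhd', (32, 40, 18), 1), ('scan_0.mhd', (10, 9, 77), 0)]
#   t = Fr3dNetTrainer()
#   t.train(X_train, X_val)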