#!/usr/bin/env python
# coding=utf-8

from __future__ import (print_function, division, absolute_import, unicode_literals)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TensorFlow INFO/WARNING logs

import time
import utils
import threading
import numpy as np
import tensorflow as tf
from tensorflow.contrib import distributions as dist

from sacred import Experiment
from sacred.utils import get_by_dotted_path
from datasets import ds
from datasets import InputPipeLine
from nem_model import nem, static_nem_iterations, dynamic_nem_iteration, get_loss_step_weights
from network import net

ex = Experiment("R-NEM", ingredients=[ds, nem, net])


# noinspection PyUnusedLocal
@ex.config
def cfg():
    noise = {
        'noise_type': 'bitflip',                        # type of input noise {bitflip, None}
        'prob': 0.2,                                    # probability of flipping each pixel
    }
    training = {
        'optimizer': 'adam',                            # {adam, sgd, momentum, adadelta, adagrad, rmsprop}
        'params': {
            'learning_rate': 0.001,                     # float
        },
        'max_patience': 10,                             # number of epochs to wait before early stopping
        'batch_size': 64,
        'max_epoch': 500,
        'clip_gradients': None,                         # maximum norm of gradients
        'debug_samples': [3, 37, 54],                   # sample ids to generate plots for (None, int, list)
        'save_epochs': [1, 5, 10, 20, 50, 100]          # at what epochs to save the model independent of valid loss
    }
    validation = {
        'batch_size': training['batch_size'],
        'debug_samples': [0, 1, 2]                      # sample ids to generate plots for (None, int, list)
    }

    feed_actions = False                                # whether to feed the actions (RL) via the recurrent state
    record_grouping_score = True                        # whether to use grouping to compute ARI/AMI scores
    record_relational_loss = 'collisions'               # use {events, collisions} to compute rel. losses or None

    dt = 10                                             # number of final steps to average over for the "last step" losses
    log_dir = 'debug_out'                               # directory to dump logs and debug plots
    net_path = None                                     # path to a network checkpoint to initialize the weights from

    # config to control run_from_file
    run_config = {
        'usage': 'test',                                # what dataset to use {training, validation, test}
        'batch_size': 100,
        'rollout_steps': 10,
        'debug_samples': [0, 1, 2],                      # sample ids to generate plots for (None, int, list)
    }


ex.add_named_config('no_score', {'record_grouping_score': False})
ex.add_named_config('no_collisions', {'record_relational_loss': None})
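# Example invocations via the Sacred command-line interface (the script filename
# `nem_run.py` below is an assumption; substitute the actual name of this module):
#   python nem_run.py with training.batch_size=32 noise.prob=0.1      # override config entries
#   python nem_run.py with no_score no_collisions                     # apply the named configs above
#   python nem_run.py run_from_file with run_config.usage=test        # evaluate a trained model
#   python nem_run.py rollout_from_file with run_config.rollout_steps=20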


@ex.capture
def add_noise(data, noise):
    noise_type = noise['noise_type']
    if noise_type in ['None', 'none', None]:
        return data

    with tf.name_scope('input_noise'):
        shape = tf.stack([s.value if s.value is not None else tf.shape(data)[i]
                         for i, s in enumerate(data.get_shape())])

        if noise_type == 'bitflip':
            noise_dist = dist.Bernoulli(probs=noise['prob'], dtype=data.dtype)
            n = noise_dist.sample(shape)
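            # For binary d and n, d + n - 2*d*n equals (d XOR n):
            #   d=0,n=0 -> 0;  d=0,n=1 -> 1;  d=1,n=0 -> 1;  d=1,n=1 -> 0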
            corrupted = data + n - 2 * data * n  # hacky way of implementing (data XOR n)
        else:
            raise KeyError('Unknown noise_type "{}"'.format(noise_type))

        corrupted.set_shape(data.get_shape())
        return corrupted


@ex.capture(prefix='training')
def set_up_optimizer(loss, optimizer, params, clip_gradients):
    opt = {
        'adam': tf.train.AdamOptimizer,
        'sgd': tf.train.GradientDescentOptimizer,
        'momentum': tf.train.MomentumOptimizer,
        'adadelta': tf.train.AdadeltaOptimizer,
        'adagrad': tf.train.AdagradOptimizer,
        'rmsprop': tf.train.RMSPropOptimizer
    }[optimizer](**params)
    grads_and_vars = opt.compute_gradients(loss)
    if clip_gradients is not None:
        # clip each gradient tensor individually; leave variables without gradients untouched
        grads_and_vars = [(tf.clip_by_norm(grad, clip_gradients) if grad is not None else grad, var)
                          for grad, var in grads_and_vars]

    return opt, opt.apply_gradients(grads_and_vars)
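# NOTE: tf.clip_by_norm above clips every gradient tensor to `clip_gradients` individually
# (per-variable norm). A minimal sketch of clipping by the global norm instead, in case that
# behaviour were preferred (not what this script does):
#
#     grads, variables = zip(*opt.compute_gradients(loss))
#     grads, _ = tf.clip_by_global_norm(grads, clip_gradients)
#     train_op = opt.apply_gradients(zip(grads, variables))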


@ex.capture
def build_dynamic_graph(features, targets, gammas_old, thetas_old, preds_old, network, groups=None, collisions=None, actions=None):
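    """Build a single NEM step whose previous state (gammas_old, thetas_old, preds_old) is
    supplied explicitly, so the model can be unrolled one step at a time at run time
    (used by build_rollout_graph / rollout_from_file)."""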
    # single NEM iteration on the noise-corrupted input (previous state is fed via placeholders)
    features_corrupted = add_noise(features)
    loss, ub_loss, r_loss, r_ub_loss, thetas, preds, gammas, other_losses, other_ub_losses, r_other_losses, r_other_ub_losses = dynamic_nem_iteration(
        input_data=features_corrupted, target_data=targets, gamma_old=gammas_old, h_old=thetas_old, preds_old=preds_old, collisions=collisions, actions=actions)

    graph = {
        'inputs': features,
        'corrupted': features_corrupted,
        'targets': targets,
        'loss': loss,
        'ub_loss': ub_loss,
        'r_loss': r_loss,
        'r_ub_loss': r_ub_loss,
        'gammas_old': gammas_old,
        'thetas_old': thetas_old,
        'preds_old': preds_old,
        'gammas': gammas,
        'thetas': thetas,
        'preds': preds,
        'other_losses': other_losses,
        'other_ub_losses': other_ub_losses,
        'r_other_losses': r_other_losses,
        'r_other_ub_losses': r_other_ub_losses,
    }

    # pass grouping info through for evaluation
    if groups is not None:
        graph['groups'] = groups

    if collisions is not None:
        graph['collisions'] = collisions

    # add actions to the graph
    if actions is not None:
        graph['actions'] = actions

    # expose attention weights when the NPE recurrent module has a non-empty attention block
    if network['recurrent'][0]['name'] == 'npe' and len(network['recurrent'][0]['attention']) > 0:
        k = gammas.shape[1].value

        ns = tf.contrib.framework.get_name_scope()
        g = tf.get_default_graph()
        attentions = [g.get_tensor_by_name("{}/R-RNNEM/step_0/NPE/Sigmoid:0".format(ns))]
        attentions = tf.stack(attentions, axis=0)
        attentions = tf.reshape(attentions, [1, -1, k, k - 1])
        graph['attentions'] = attentions  # in order

    return graph


@net.capture
def build_rollout_graph(inputs, batch_size, k, recurrent):
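    """Create placeholders for one input/target frame and the previous NEM state (and, if
    present, groups/collisions/actions) and wire them into a single-step dynamic graph."""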
    feature_shape = [s.value for s in inputs['features'].shape[2:]]
    groups_shape = [s.value for s in inputs['groups'].shape[2:]] if inputs.get('groups', None) is not None else None
    actions_shape = [s.value for s in inputs['actions'].shape[2:]] if inputs.get('actions', None) is not None else None

    with tf.name_scope('rollout'):
        X_rollout_shape = [batch_size] + feature_shape
        X_rollout = tf.placeholder(tf.float32, shape=X_rollout_shape)

        Y_rollout = tf.placeholder(tf.float32, shape=X_rollout_shape)

        gamma_rollout_shape = [batch_size, k] + feature_shape[1:]
        gamma_rollout = tf.placeholder(tf.float32, shape=gamma_rollout_shape)

        theta_rollout_shape = [batch_size*k, recurrent[0]['size']]
        theta_rollout = tf.placeholder(tf.float32, shape=theta_rollout_shape)

        pred_rollout = tf.placeholder(tf.float32, shape=gamma_rollout_shape)

        if inputs.get('groups', None) is not None:
            G_rollout_shape = [batch_size] + groups_shape
            G_rollout = tf.placeholder(tf.float32, shape=G_rollout_shape)
        else:
            G_rollout = None

        if inputs.get('collisions', None) is not None:
            collisions_rollout_shape = [batch_size] + feature_shape
            collision_rollout = tf.placeholder(tf.float32, shape=collisions_rollout_shape)
        elif inputs.get('events', None) is not None:
            collisions_rollout_shape = [batch_size, 1, 1, 1, 1]
            collision_rollout = tf.placeholder(tf.float32, shape=collisions_rollout_shape)
        else:
            collision_rollout = None

        if inputs.get('actions', None) is not None:
            A_rollout_shape = [batch_size] + actions_shape
            A_rollout = tf.placeholder(tf.float32, shape=A_rollout_shape)
        else:
            A_rollout = None

        graph = build_dynamic_graph(X_rollout, Y_rollout, gamma_rollout, theta_rollout, pred_rollout,
                                    groups=G_rollout, collisions=collision_rollout, actions=A_rollout)
        return graph


@ex.capture
def build_graph(features, network, groups=None, collisions=None, actions=None):
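    """Build the statically unrolled NEM graph over a full sequence of noise-corrupted inputs;
    optionally adds the ARI score (if groups are given) and the NPE attention tensors
    (if the recurrent module uses attention)."""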
    
    # NEM iterations on the noise-corrupted input
    features_corrupted = add_noise(features)
    loss, ub_loss, r_loss, r_ub_loss, thetas, preds, gammas, other_losses, other_ub_losses, r_other_losses, \
    r_other_ub_losses = static_nem_iterations(features_corrupted, features,
                                              collisions=collisions, actions=actions)

    graph = {
        'inputs': features,
        'corrupted': features_corrupted,
        'loss': loss,
        'ub_loss': ub_loss,
        'r_loss': r_loss,
        'r_ub_loss': r_ub_loss,
        'gammas': gammas,
        'thetas': thetas,
        'preds': preds,
        'other_losses': other_losses,
        'other_ub_losses': other_ub_losses,
        'r_other_losses': r_other_losses,
        'r_other_ub_losses': r_other_ub_losses,
    }

    # compute grouping info
    if groups is not None:
        graph['groups'] = groups
        graph['ARI'] = utils.tf_adjusted_rand_index(groups, gammas, get_loss_step_weights())

    # add actions to the graph
    if actions is not None:
        graph['actions'] = actions

    # expose attention weights when the NPE recurrent module has a non-empty attention block
    if network['recurrent'][0]['name'] == 'npe' and len(network['recurrent'][0]['attention']) > 0:
        nr_iters = gammas.shape[0].value
        k = gammas.shape[2].value

        attentions = []
        ns = tf.contrib.framework.get_name_scope()
        g = tf.get_default_graph()
        for i in range(nr_iters-1):
            attention = g.get_tensor_by_name("{}/R-RNNEM/step_{}/NPE/Sigmoid:0".format(ns, i))
            attentions.append(attention)

        attentions = tf.stack(attentions, axis=0)
        attentions = tf.reshape(attentions, [nr_iters-1, -1, k, k-1])

        graph['attentions'] = attentions  # in order

    return graph


def build_debug_graph(inputs):
    nr_iters = inputs['features'].shape[0]
    feature_shape = [s.value for s in inputs['features'].shape[2:]]
    groups_shape = [s.value for s in inputs['groups'].shape[2:]] if inputs.get('groups', None) is not None else None
    actions_shape = [s.value for s in inputs['actions'].shape[2:]] if inputs.get('actions', None) is not None else None

    with tf.name_scope('debug'):
        X_debug_shape = [nr_iters, None] + feature_shape
        X_debug = tf.placeholder(tf.float32, shape=X_debug_shape)

        if inputs.get('groups', None) is not None:
            G_debug_shape = [nr_iters, None] + groups_shape
            G_debug = tf.placeholder(tf.float32, shape=G_debug_shape)
        else:
            G_debug = None

        if inputs.get('actions', None) is not None:
            A_debug_shape = [nr_iters, None] + actions_shape
            A_debug = tf.placeholder(tf.float32, shape=A_debug_shape)
        else:
            A_debug = None

        graph = build_graph(X_debug, groups=G_debug, actions=A_debug)
        return graph


@ex.capture
def build_graphs(train_inputs, valid_inputs, record_relational_loss):
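    """Build the training, validation and debug graphs. All three share the same weights:
    `varscope.reuse_variables()` makes the validation and debug graphs reuse the variables
    created by the training graph."""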
    
    # Build Graph
    varscope = tf.get_variable_scope()
    with tf.name_scope("train"):
        train_graph = build_graph(train_inputs['features'],
                                  groups=train_inputs.get('groups', None),
                                  collisions=train_inputs.get(record_relational_loss, None),
                                  actions=train_inputs.get('actions', None)
                                  )
        opt, train_op = set_up_optimizer(train_graph['loss'])

    varscope.reuse_variables()
    with tf.name_scope("valid"):
        valid_graph = build_graph(valid_inputs['features'],
                                  groups=valid_inputs.get('groups', None),
                                  collisions=valid_inputs.get(record_relational_loss, None),
                                  actions=valid_inputs.get('actions', None))

    debug_graph = build_debug_graph(valid_inputs)

    return train_op, train_graph, valid_graph, debug_graph


@ex.capture
def create_curve_plots(name, plot_dict, coarse_range, fine_range, log_dir):
    import matplotlib.pyplot as plt
    fig = utils.curve_plot(plot_dict, coarse_range, fine_range)
    fig.suptitle(name)
    fig.savefig(os.path.join(log_dir, name + '_curve.png'), bbox_inches='tight', pad_inches=0)
    plt.close(fig)


@ex.capture
def create_debug_plots(name, debug_out, sample_indices, log_dir, debug_groups=None):
    import matplotlib.pyplot as plt

    if debug_groups is not None:
        scores, confidences = utils.evaluate_groups_seq(debug_groups[1:], debug_out['gammas'][1:], get_loss_step_weights())
    else:
        scores, confidences = len(sample_indices) * [0.0], len(sample_indices) * [0.0]

    # produce overview plot
    for i, nr in enumerate(sample_indices):
        fig = utils.overview_plot(i, **debug_out)
        fig.suptitle(name + ', sample {}, AMI Score: {:.3f} ({:.3f})'.format(nr, scores[i], confidences[i]))
        fig.savefig(os.path.join(log_dir, name + '_{}.png'.format(nr)), bbox_inches='tight', pad_inches=0)
        plt.close(fig)


def populate_debug_out(session, debug_graph, pipe_line, debug_samples, name):
    idxs = debug_samples if isinstance(debug_samples, list) else [debug_samples]

    out_list = ['features']
    out_list.append('groups') if debug_graph.get('groups', None) is not None else None
    out_list.append('actions') if debug_graph.get('actions', None) is not None else None

    debug_data = pipe_line.get_debug_samples(idxs, out_list=out_list)

    feed_dict = {debug_graph['inputs']: debug_data['features']}

    if debug_data.get('groups', None) is not None:
        feed_dict[debug_graph['groups']] = debug_data['groups']

    if debug_data.get('actions', None) is not None:
        feed_dict[debug_graph['actions']] = debug_data['actions']

    debug_out = session.run(debug_graph, feed_dict=feed_dict)
    create_debug_plots(name, debug_out, idxs, debug_groups=debug_data.get('groups', None))


def run_epoch(session, pipe_line, graph, debug_graph, debug_samples, debug_name, train_op=None):
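    """Run one full pass over the pipeline, optionally applying `train_op` per batch, and return
    the mean losses and (if available) ARI scores as a log dict. Also produces debug plots for
    `debug_samples` when given."""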
    fetches = [graph['loss'], graph['ub_loss'], graph['r_loss'], graph['r_ub_loss'], graph['other_losses'],
               graph['other_ub_losses'], graph['r_other_losses'], graph['r_other_ub_losses']]

    fetches.append(graph['ARI']) if graph.get('ARI', None) is not None else None
    fetches.append(train_op) if train_op is not None else None

    losses, ub_losses, r_losses, r_ub_losses, others, others_ub, r_others, r_others_ub, ari_scores = [], [], [], [], [], [], [], [], []
    # run through the epoch
    for b in range(pipe_line.get_n_batches()):
        # run batch
        out = session.run(fetches=fetches)

        # total losses (and upperbound)
        losses.append(out[0])
        ub_losses.append(out[1])

        # total relational losses (and upperbound)
        r_losses.append(out[2])
        r_ub_losses.append(out[3])

        # other losses (and upperbound)
        others.append(out[4])
        others_ub.append(out[5])

        # other relational losses (and upperbound)
        r_others.append(out[6])
        r_others_ub.append(out[7])

        # ARI
        ari_scores.append(out[8] if graph.get('ARI', None) is not None else (0., 0., 0., 0.))

    if debug_samples is not None:
        populate_debug_out(session, debug_graph, pipe_line, debug_samples, debug_name)

    # build log dict
    log_dict = {
        'loss': float(np.mean(losses)),
        'ub_loss': float(np.mean(ub_losses)),
        'r_loss': float(np.mean(r_losses)),
        'r_ub_loss': float(np.mean(r_ub_losses)),
        'others': np.mean(others, axis=0),
        'others_ub': np.mean(others_ub, axis=0),
        'r_others': np.mean(r_others, axis=0),
        'r_others_ub': np.mean(r_others_ub, axis=0),
        'score': np.mean(ari_scores, axis=0)[0],
        'score_last': np.mean(ari_scores, axis=0)[1],
        'score_conf': np.mean(ari_scores, axis=0)[2],
        'score_last_conf': np.mean(ari_scores, axis=0)[3]
    }

    return log_dict


@ex.capture
def add_log(key, value, _run):
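    """Append `value` to the list stored under the dotted path `key` in _run.info['logs'],
    e.g. add_log('training.loss', 1.23) yields _run.info['logs'] == {'training': {'loss': [1.23]}}."""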
    if 'logs' not in _run.info:
        _run.info['logs'] = {}
    logs = _run.info['logs']
    split_path = key.split('.')
    current = logs
    for p in split_path[:-1]:
        if p not in current:
            current[p] = {}
        current = current[p]

    final_key = split_path[-1]
    if final_key not in current:
        current[final_key] = []
    entries = current[final_key]
    entries.append(value)


@ex.capture
def get_logs(key, _run):
    logs = _run.info.get('logs', {})
    return get_by_dotted_path(logs, key)


def log_log_dict(usage, log_dict):
    for log_key, value in log_dict.items():
        add_log('{}.{}'.format(usage, log_key), value)


def print_log_dict(log_dict, usage, t, dt, s_loss_weights, dt_s_loss_weights):
    print("%s Loss: %.3f (UB: %.3f), Relational Loss: %.3f (UB: %.3f), Score: %.3f (conf: %0.3f), Last Score:"
          " %.3f (conf: %.3f) took %.3fs" % (usage, log_dict['loss'], log_dict['ub_loss'], log_dict['r_loss'],
                                             log_dict['r_ub_loss'], log_dict['score'], log_dict['score_conf'],
                                             log_dict['score_last'], log_dict['score_last_conf'], time.time() - t))

    print("    other losses: {}".format(", ".join(["%.2f (UB: %.2f)" %
          (log_dict['others'][:, i].sum(0) / s_loss_weights, log_dict['others_ub'][:, i].sum(0) / s_loss_weights)
           for i in range(len(log_dict['others'][0]))])))
    print("        last {} steps avg: {}".format(dt, ", ".join(["%.2f (UB: %.2f)" %
          (log_dict['others'][-dt:, i].sum(0) / dt_s_loss_weights,
           log_dict['others_ub'][-dt:, i].sum(0) / dt_s_loss_weights) for i in range(len(log_dict['others'][0]))])))

    print("    other relational losses: {}".format(", ".join(["%.2f (UB: %.2f)" %
          (log_dict['r_others'][:, i].sum(0) / s_loss_weights, log_dict['r_others_ub'][:, i].sum(0) / s_loss_weights)
           for i in range(len(log_dict['r_others'][0]))])))
    print("        last {} steps avg: {}".format(dt, ", ".join(["%.2f (UB: %.2f)" %
          (log_dict['r_others'][-dt:, i].sum(0) / dt_s_loss_weights,
           log_dict['r_others_ub'][-dt:, i].sum(0) / dt_s_loss_weights) for i in range(len(log_dict['r_others'][0]))])))


@ex.command
def rollout_from_file(record_grouping_score, record_relational_loss, feed_actions, run_config, nem, dt, log_dir, seed, net_path=None):
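    """Restore a trained model (from `net_path`, default <log_dir>/best) and unroll it step by
    step: the first nem['nr_steps'] steps are fed ground-truth frames, after which the model is
    rolled out for run_config['rollout_steps'] further steps on its own prediction
    (sum_k gamma_k * pred_k), with gamma recomputed from the binarized prediction."""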
    tf.set_random_seed(seed)

    # load network weights (default is log_dir/best if net_path is not set)
    net_path = os.path.abspath(os.path.join(log_dir, 'best')) if net_path is None else net_path
    usage = run_config['usage']

    # prep weights for print out
    loss_step_weights = get_loss_step_weights()
    s_loss_weights = np.sum(loss_step_weights)
    dt_s_loss_weights = np.sum(loss_step_weights[-dt:])

    with tf.Graph().as_default() as g:
        # Set up Data
        batch_size = run_config['batch_size']
        nr_steps = nem['nr_steps'] + run_config['rollout_steps'] + 1
        out_list = ['features']
        out_list.append('groups') if record_grouping_score else None
        out_list.append(record_relational_loss) if record_relational_loss else None
        out_list.append('actions') if feed_actions else None

        inputs = InputPipeLine(usage, shuffle=False, sequence_length=nr_steps, out_list=out_list, batch_size=batch_size)

        # Build Graph
        graph = build_rollout_graph(inputs.output, batch_size, nem['k'])

        start_time = time.time()
        with tf.Session(graph=g) as session:
            saver = tf.train.Saver()
            saver.restore(session, net_path)

            # fetches for one forward pass
            fetches = [graph['loss'], graph['ub_loss'], graph['r_loss'], graph['r_ub_loss'], graph['other_losses'],
                       graph['other_ub_losses'], graph['r_other_losses'], graph['r_other_ub_losses'],
                       graph['corrupted'], graph['gammas'], graph['thetas'], graph['preds']]

            # create loss dict
            loss_dict = {'loss': [], 'ub_loss': [], 'r_loss': [], 'r_ub_loss': [], 'others': [], 'others_ub': [],
                         'r_others': [], 'r_others_ub': []}

            # iterate over all batches of the chosen split
            for b in range(inputs.get_n_batches()):
                idxs = list(range(b*batch_size, (b+1) * batch_size))
                input_data = inputs.get_debug_samples(idxs, out_list=out_list)

                # start a new per-batch list for every statistic
                loss_dict['loss'].append([])
                loss_dict['ub_loss'].append([])
                loss_dict['r_loss'].append([])
                loss_dict['r_ub_loss'].append([])
                loss_dict['others'].append([])
                loss_dict['others_ub'].append([])
                loss_dict['r_others'].append([])
                loss_dict['r_others_ub'].append([])

                # init
                with tf.name_scope('initial_state'):
                    # inner RNN hidden state init
                    with tf.name_scope('inner_RNN_init'):
                        theta = np.zeros((batch_size * nem['k'], 250), dtype=np.float32)  # NOTE: hidden size hard-coded

                    # initial prediction (B, K, W, H, C)
                    with tf.name_scope('pred_init'):
                        pred = np.ones((batch_size, nem['k'], 64, 64, 1), dtype=np.float32) * nem['pred_init']

                    # initial gamma (B, K, W, H, 1)
                    with tf.name_scope('gamma_init_gaussian'):
                        # init with Gaussian distribution
                        gamma = np.abs(np.random.randn(batch_size, nem['k'], 64, 64, 1))
                        gamma /= np.sum(gamma, axis=1, keepdims=True)

                        # init with all 1 if K = 1
                        if nem['k'] == 1:
                            gamma = np.ones_like(gamma)

                corrupted, scores, gammas, thetas, preds = [], [], [gamma], [theta], [pred]

                # run rollout steps
                for t in range(nem['nr_steps'] + run_config['rollout_steps']):

                    # build feed dict
                    feed_dict = {graph['targets']: input_data['features'][t + 1],
                                 graph['gammas_old']: gamma,
                                 graph['thetas_old']: theta,
                                 graph['preds_old']: pred}

                    # decide whether to feed real data or roll out from the model's own prediction
                    if t < nem['nr_steps']:
                        feed_dict[graph['inputs']] = input_data['features'][t]
                    else:
                        feed_dict[graph['inputs']] = np.sum(gamma * pred, 1, keepdims=True)

                    if input_data.get('groups', None) is not None:
                        feed_dict[graph['groups']] = input_data['groups'][t+1]

                    if input_data.get('collisions', None) is not None:
                        feed_dict[graph['collisions']] = input_data['collisions'][t]
                    elif input_data.get('events', None) is not None:
                        feed_dict[graph['collisions']] = input_data['events'][t]

                    if input_data.get('actions', None) is not None:
                        feed_dict[graph['actions']] = input_data['actions'][t]

                    # run forward pass
                    out = session.run(fetches, feed_dict=feed_dict)

                    # log results for iteration
                    corr, gamma, theta, pred = out[-4:]

                    # re-compute gamma if rollout
                    if t >= nem['nr_steps']:
                        truth = np.max(pred, axis=1, keepdims=True)

                        # binarize so that objects do not gradually fade out during rollout
                        truth[truth > 0.1] = 1.0
                        truth[truth <= 0.1] = 0.0

                        # compute probs
                        probs = truth * pred + (1 - truth) * (1 - pred)

                        # add epsilon to probs in order to prevent 0 gamma
                        probs += 1e-6

                        # compute the new gamma (E-step) or set to one for k=1
                        gamma = probs / np.sum(probs, 1, keepdims=True) if nem['k'] > 1 else np.ones_like(gamma)

                    corrupted.append(corr)
                    gammas.append(gamma)
                    thetas.append(theta)
                    preds.append(pred)

                    # log losses
                    loss_dict['loss'][-1].append(out[0])
                    loss_dict['ub_loss'][-1].append(out[1])
                    loss_dict['r_loss'][-1].append(out[2])
                    loss_dict['r_ub_loss'][-1].append(out[3])
                    loss_dict['others'][-1].append(out[4])
                    loss_dict['others_ub'][-1].append(out[5])
                    loss_dict['r_others'][-1].append(out[6])
                    loss_dict['r_others_ub'][-1].append(out[7])

                # collect outputs for plotting
                out_dict = {
                    'inputs': input_data['features'],
                    'corrupted': np.array(corrupted),
                    'gammas': np.array(gammas),
                    'preds': np.array(preds),
                }

                # create debug plots for entries in first batch
                if b == 0 and run_config.get('debug_samples', None):
                    create_debug_plots('rollout_{}'.format(usage), out_dict, run_config['debug_samples'])

            # build log dict (NOTE: assumes every batch ran the full number of steps)
            log_dict = {
                'loss': np.mean(loss_dict['loss']),
                'ub_loss': np.mean(loss_dict['ub_loss']),
                'r_loss': np.mean(loss_dict['r_loss']),
                'r_ub_loss': np.mean(loss_dict['r_ub_loss']),
                'others': np.mean(loss_dict['others'], axis=0),
                'others_ub': np.mean(loss_dict['others_ub'], axis=0),
                'r_others': np.mean(loss_dict['r_others'], axis=0),
                'r_others_ub': np.mean(loss_dict['r_others_ub'], axis=0),
                'score': -1,
                'score_last': -1,
                'score_conf': -1,
                'score_last_conf': -1
            }

            # store in the Sacred run info
            log_log_dict(usage, log_dict)

            # print
            print_log_dict(log_dict, usage, start_time, dt, s_loss_weights, dt_s_loss_weights)


@ex.command
def run_from_file(record_grouping_score, record_relational_loss, feed_actions, run_config, nem, dt, log_dir, seed, net_path=None):
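    """Restore a trained model (from `net_path`, default <log_dir>/best) and evaluate it for one
    epoch on the split given by run_config['usage'], logging the same statistics as training."""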
    tf.set_random_seed(seed)

    # load network weights (default is log_dir/best if net_path is not set)
    net_path = os.path.abspath(os.path.join(log_dir, 'best')) if net_path is None else net_path
    usage = run_config['usage']

    # prep weights for print out
    loss_step_weights = get_loss_step_weights()
    s_loss_weights = np.sum(loss_step_weights)
    dt_s_loss_weights = np.sum(loss_step_weights[-dt:])

    with tf.Graph().as_default() as g:
        # Set up Data
        nr_steps = nem['nr_steps'] + 1
        out_list = ['features']
        out_list.append('groups') if record_grouping_score else None
        out_list.append(record_relational_loss) if record_relational_loss else None
        out_list.append('actions') if feed_actions else None

        inputs = InputPipeLine(usage, shuffle=False, sequence_length=nr_steps, out_list=out_list, batch_size=run_config['batch_size'])

        # Build Graph
        _, _, graph, debug_graph = build_graphs(inputs.output, inputs.output)

        t = time.time()
        with tf.Session(graph=g) as session:
            coord = tf.train.Coordinator()
            saver = tf.train.Saver()
            saver.restore(session, net_path)

            # launch pipeline
            enqueue_thread = threading.Thread(target=inputs.enqueue, args=[session, coord])
            enqueue_thread.start()

            log_dict = run_epoch(session, inputs, graph, debug_graph, run_config['debug_samples'],
                                 "run_{}".format(usage))

            # store the log dict
            log_log_dict(usage, log_dict)

            # shutdown pipeline
            coord.request_stop()
            session.run(inputs.queue.close(cancel_pending_enqueues=True))
            coord.join()

        print_log_dict(log_dict, usage, t, dt, s_loss_weights, dt_s_loss_weights)


@ex.automain
def run(record_grouping_score, record_relational_loss, feed_actions, net_path, training, validation, nem, dt, seed, log_dir, _run):
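    """Main training loop: build the input pipelines and graphs, train for up to
    training['max_epoch'] epochs, save the best model by validation loss (plus the epochs in
    training['save_epochs']), stop early after training['max_patience'] epochs without
    improvement, and return the best validation statistics."""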
    save_epochs = training['save_epochs']

    # clear debug dir
    if log_dir and net_path is None:
        utils.create_directory(log_dir)
        utils.delete_files(log_dir, recursive=True)

    # prep weights for print out
    loss_step_weights = get_loss_step_weights()
    s_loss_weights = np.sum(loss_step_weights)
    dt_s_loss_weights = np.sum(loss_step_weights[-dt:])

    # Set up data pipelines
    nr_iters = nem['nr_steps'] + 1
    out_list = ['features']
    out_list.append('groups') if record_grouping_score else None
    out_list.append(record_relational_loss) if record_relational_loss else None
    out_list.append('actions') if feed_actions else None

    train_inputs = InputPipeLine('training', shuffle=True, out_list=out_list, sequence_length=nr_iters,
                                 batch_size=training['batch_size'])
    valid_inputs = InputPipeLine('validation', shuffle=False, out_list=out_list, sequence_length=nr_iters,
                                 batch_size=validation['batch_size'])

    # Build Graph
    train_op, train_graph, valid_graph, debug_graph = build_graphs(train_inputs.output, valid_inputs.output)
    init = tf.global_variables_initializer()

    # print vars
    utils.print_vars(tf.trainable_variables())

    with tf.Session() as session:
        tf.set_random_seed(seed)

        # continue training from net_path if specified
        saver = tf.train.Saver()
        if net_path is not None:
            saver.restore(session, net_path)
        else:
            session.run(init)

        # start the training and validation input pipelines
        writer = tf.summary.FileWriter(log_dir, graph=session.graph,)
        coord = tf.train.Coordinator()
        train_enqueue_thread = threading.Thread(target=train_inputs.enqueue, args=[session, coord])
        coord.register_thread(train_enqueue_thread)
        train_enqueue_thread.start()
        valid_enqueue_thread = threading.Thread(target=valid_inputs.enqueue, args=[session, coord])
        coord.register_thread(valid_enqueue_thread)
        valid_enqueue_thread.start()

        best_valid_loss = np.inf
        best_valid_epoch = 0
        for epoch in range(1, training['max_epoch'] + 1):
            # run train epoch
            t = time.time()
            log_dict = run_epoch(session, train_inputs, train_graph, debug_graph, training['debug_samples'], "train_e{}".format(epoch), train_op=train_op)

            # log all items in dict
            log_log_dict('training', log_dict)

            # produce print-out
            print("\n" + 80 * "%" + "    EPOCH {}   ".format(epoch) + 80 * "%")
            print_log_dict(log_dict, 'Train', t, dt, s_loss_weights, dt_s_loss_weights)

            # run valid epoch
            t = time.time()
            log_dict = run_epoch(session, valid_inputs, valid_graph, debug_graph, validation['debug_samples'], "valid_e{}".format(epoch))

            # add logs
            log_log_dict('validation', log_dict)

            # produce plots
            create_curve_plots('loss', {'training': get_logs('training.loss'),
                                        'validation': get_logs('validation.loss')}, [0, 1000], [0, 200])
            create_curve_plots('r_loss', {'training': get_logs('training.r_loss'),
                                          'validation': get_logs('validation.r_loss')}, [0, 100], [0, 20])

            create_curve_plots('score', {'score': get_logs('validation.score'),
                                         'score_last': get_logs('validation.score_last')}, [0, 1], None)

            # produce print-out
            print("\n")
            print_log_dict(log_dict, 'Validation', t, dt, s_loss_weights, dt_s_loss_weights)

            if log_dict['loss'] < best_valid_loss:
                best_valid_loss = log_dict['loss']
                best_valid_epoch = epoch
                _run.result = float(log_dict['score']), float(log_dict['score_last']), \
                              float(log_dict['loss']), float(log_dict['ub_loss']), \
                              float(np.sum(log_dict['others'][-dt:, 1])/dt_s_loss_weights), \
                              float(np.sum(log_dict['others_ub'][-dt:, 1]) / dt_s_loss_weights), \
                              float(np.sum(log_dict['others'][-dt:, 2]) / dt_s_loss_weights), \
                              float(np.sum(log_dict['others_ub'][-dt:, 2]) / dt_s_loss_weights), \
                              float(log_dict['r_loss']), float(log_dict['r_ub_loss']), \
                              float(np.sum(log_dict['r_others'][-dt:, 1]) / dt_s_loss_weights), \
                              float(np.sum(log_dict['r_others_ub'][-dt:, 1]) / dt_s_loss_weights), \
                              float(np.sum(log_dict['r_others'][-dt:, 2]) / dt_s_loss_weights), \
                              float(np.sum(log_dict['r_others_ub'][-dt:, 2]) / dt_s_loss_weights)

                print("    Best validation loss improved to %.03f" % best_valid_loss)
                save_destination = saver.save(session, os.path.abspath(os.path.join(log_dir, 'best')))
                print("    Saved to:", save_destination)
            if epoch in save_epochs:
                save_destination = saver.save(session, os.path.abspath(os.path.join(log_dir, 'epoch_{}'.format(epoch))))
                print("    Saved to:", save_destination)

            best_valid_loss = min(best_valid_loss, log_dict['loss'])

            if best_valid_loss < np.min(get_logs('validation.loss')[-training['max_patience']:]):
                print('Early Stopping because validation loss did not improve for {} epochs'.format(training['max_patience']))
                break

            if np.isnan(log_dict['loss']):
                print('Early Stopping because validation loss is nan')
                break

        # shutdown everything to avoid zombies
        coord.request_stop()
        session.run(train_inputs.queue.close(cancel_pending_enqueues=True))
        session.run(valid_inputs.queue.close(cancel_pending_enqueues=True))
        coord.join()

    # reset the graph
    tf.reset_default_graph()

    # gather best results
    best_valid_score = float(get_logs('validation.score')[best_valid_epoch - 1])
    best_valid_score_last = float(get_logs('validation.score_last')[best_valid_epoch - 1])

    best_valid_loss = float(get_logs('validation.loss')[best_valid_epoch - 1])
    best_valid_ub_loss = float(get_logs('validation.ub_loss')[best_valid_epoch - 1])

    best_valid_intra_loss = float(np.sum(get_logs('validation.others')[best_valid_epoch - 1][-dt:, 1])/dt_s_loss_weights)
    best_valid_intra_ub_loss = float(np.sum(get_logs('validation.others_ub')[best_valid_epoch - 1][-dt:, 1])/dt_s_loss_weights)

    best_valid_inter_loss = float(np.sum(get_logs('validation.others')[best_valid_epoch - 1][-dt:, 2])/dt_s_loss_weights)
    best_valid_inter_ub_loss = float(np.sum(get_logs('validation.others_ub')[best_valid_epoch - 1][-dt:, 2])/dt_s_loss_weights)

    best_valid_r_loss = float(get_logs('validation.r_loss')[best_valid_epoch - 1])
    best_valid_r_ub_loss = float(get_logs('validation.r_ub_loss')[best_valid_epoch - 1])

    best_valid_r_intra_loss = float(np.sum(get_logs('validation.r_others')[best_valid_epoch - 1][-dt:, 1])/dt_s_loss_weights)
    best_valid_r_intra_ub_loss = float(np.sum(get_logs('validation.r_others_ub')[best_valid_epoch - 1][-dt:, 1])/dt_s_loss_weights)

    best_valid_r_inter_loss = float(np.sum(get_logs('validation.r_others')[best_valid_epoch - 1][-dt:, 2])/dt_s_loss_weights)
    best_valid_r_inter_ub_loss = float(np.sum(get_logs('validation.r_others_ub')[best_valid_epoch - 1][-dt:, 2])/dt_s_loss_weights)

    return best_valid_score, best_valid_score_last, best_valid_loss, best_valid_ub_loss, best_valid_intra_loss, \
           best_valid_intra_ub_loss, best_valid_inter_loss, best_valid_inter_ub_loss, best_valid_r_loss, \
           best_valid_r_ub_loss, best_valid_r_intra_loss, best_valid_r_intra_ub_loss, best_valid_r_inter_loss, \
           best_valid_r_inter_ub_loss