import copy
import importlib
import os

import chainer
from chainer import reporter as reporter_module
from chainer.dataset import convert, concat_examples
from chainer.training import extension
from chainer.training import extensions
import chainer.training.trigger as trigger_module
from chainer.training.triggers import IntervalTrigger
from chainer.training.extensions import Evaluator

from .logger import Logger


class AttributeUpdater(extension.Extension):
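    """Extension that multiplies an optimizer attribute (``lr`` by default)
    by ``shift`` every time ``trigger`` fires (once per epoch by default)."""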

    def __init__(self, shift, attr='lr', trigger=(1, 'epoch')):
        self.shift = shift
        self.attr = attr
        self.trigger = trigger_module.get_trigger(trigger)

    def __call__(self, trainer):
        if self.trigger(trainer):
            optimizer = trainer.updater.get_optimizer('main')
            current_value = getattr(optimizer, self.attr)
            shifted_value = current_value * self.shift
            setattr(optimizer, self.attr, shifted_value)


class TwoStateLearningRateShifter(extension.Extension):
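    """Learning rate shifter that walks through a list of phases (``states``).

    Each phase is a dict with the keys ``state`` (one of the class constants
    below), ``target_lr``, ``update_trigger`` and ``stop_trigger`` (trigger
    specs such as ``(1, 'iteration')``).  In a continuous phase the learning
    rate is linearly interpolated from the phase's start lr to ``target_lr``
    over the length of the phase; in an interval-based phase the learning rate
    is set to ``target_lr`` whenever ``update_trigger`` fires.  When
    ``stop_trigger`` fires, the shifter advances to the next phase.
    """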

    CONTINUOUS_SHIFT_STATE = 0
    INTERVAL_BASED_SHIFT_STATE = 1

    def __init__(self, start_lr, states):
        self.start_lr = start_lr
        self.lr = start_lr
        self.states = states
        self.current_state = self.states.pop(0)
        self.start_epoch = 0
        self.start_iteration = 0
        self.set_triggers()

    def set_triggers(self):
        self.target_lr = self.current_state['target_lr']
        self.update_trigger = trigger_module.get_trigger(self.current_state['update_trigger'])
        self.stop_trigger = trigger_module.get_trigger(self.current_state['stop_trigger'])
        self.phase_length, self.unit = self.current_state['stop_trigger']

    def switch_state_if_necessary(self, trainer):
        if self.stop_trigger(trainer):
            if self.states:
                # the lr reached at the end of the finished phase becomes the
                # starting lr of the next phase, so remember it before the
                # triggers (and with them target_lr) are replaced
                self.start_lr = self.target_lr
                self.current_state = self.states.pop(0)
                self.set_triggers()
                self.start_epoch = trainer.updater.epoch
                self.start_iteration = trainer.updater.iteration

    def __call__(self, trainer):
        updater = trainer.updater
        optimizer = trainer.updater.get_optimizer('main')

        if self.update_trigger(trainer):
            if self.current_state['state'] == self.CONTINUOUS_SHIFT_STATE:
                epoch = updater.epoch_detail

                if self.unit == 'iteration':
                    interpolation_factor = (updater.iteration - self.start_iteration) / self.phase_length
                else:
                    interpolation_factor = (epoch - self.start_epoch) / self.phase_length

                new_lr = (1 - interpolation_factor) * self.start_lr + interpolation_factor * self.target_lr
                self.lr = new_lr
                optimizer.lr = new_lr

            else:
                optimizer.lr = self.target_lr
                self.lr = optimizer.lr

        self.switch_state_if_necessary(trainer)


class FastEvaluatorBase(Evaluator):
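    """Evaluator that runs on at most ``num_iterations`` batches instead of the
    full validation set, keeping intermediate evaluations fast."""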

    def __init__(self, iterator, target, converter=convert.concat_examples,
                 device=None, eval_hook=None, eval_func=None, num_iterations=200):
        super(FastEvaluatorBase, self).__init__(
            iterator,
            target,
            converter=converter,
            device=device,
            eval_hook=eval_hook,
            eval_func=eval_func
        )
        self.num_iterations = num_iterations

    def evaluate(self):
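        """Run ``eval_func`` on at most ``num_iterations`` batches from the
        main iterator and return the mean of the reported observations."""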
        iterator = self._iterators['main']
        target = self._targets['main']
        eval_func = self.eval_func or target

        if self.eval_hook:
            self.eval_hook(self)
        it = copy.copy(iterator)
        summary = reporter_module.DictSummary()

        for _ in range(min(len(iterator.dataset) // iterator.batch_size, self.num_iterations)):
            batch = next(it, None)
            if batch is None:
                break

            observation = {}
            with reporter_module.report_scope(observation), chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
                in_arrays = self.converter(batch, self.device)
                if isinstance(in_arrays, tuple):
                    eval_func(*in_arrays)
                elif isinstance(in_arrays, dict):
                    eval_func(**in_arrays)
                else:
                    eval_func(in_arrays)

            summary.add(observation)

        return summary.compute_mean()


def get_fast_evaluator(trigger_interval):
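    """Create a ``FastEvaluator`` subclass whose default trigger is
    ``trigger_interval`` and that reports under the name ``fast_validation``."""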
    return type('FastEvaluator', (FastEvaluatorBase,), dict(trigger=trigger_interval, name='fast_validation'))


class EarlyStopIntervalTrigger(IntervalTrigger):
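    """IntervalTrigger that additionally fires once the given curriculum
    reports that training is finished, allowing the trainer to stop early."""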

    def __init__(self, period, unit, curriculum):
        super().__init__(period, unit)
        self.curriculum = curriculum

    def __call__(self, trainer):
        # fire on the regular interval, or as soon as the curriculum reports
        # that training is finished
        return super().__call__(trainer) or self.curriculum.training_finished


def get_trainer(net, updater, log_dir, print_fields, curriculum=None, extra_extensions=(),
                epochs=10, snapshot_interval=20000, print_interval=100, postprocess=None,
                do_logging=True, model_files=()):
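    """Build a ``chainer.training.Trainer`` for ``updater`` with the standard
    set of extensions: graph dumping, lr observation, snapshots, optional
    logging and printing via ``Logger`` and ``PrintReport``, a progress bar,
    and any ``extra_extensions``.

    If ``curriculum`` is given, training also stops as soon as
    ``curriculum.training_finished`` becomes true (see
    ``EarlyStopIntervalTrigger``).
    """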
    if curriculum is None:
        trainer = chainer.training.Trainer(
            updater,
            (epochs, 'epoch'),
            out=log_dir,
        )
    else:
        trainer = chainer.training.Trainer(
            updater,
            EarlyStopIntervalTrigger(epochs, 'epoch', curriculum),
            out=log_dir,
        )

    # dump computational graph
    trainer.extend(extensions.dump_graph('main/loss'))

    # also observe learning rate
    observe_lr_extension = chainer.training.extensions.observe_lr()
    observe_lr_extension.trigger = (print_interval, 'iteration')
    trainer.extend(observe_lr_extension)

    # Take snapshots
    trainer.extend(
        extensions.snapshot(filename="trainer_snapshot"),
        trigger=lambda trainer:
        trainer.updater.is_new_epoch or
        (trainer.updater.iteration > 0 and trainer.updater.iteration % snapshot_interval == 0)
    )

    if do_logging:
        # write all statistics to a file
        trainer.extend(Logger(model_files, log_dir, keys=print_fields, trigger=(print_interval, 'iteration'), postprocess=postprocess))

        # print some interesting statistics
        trainer.extend(extensions.PrintReport(
            print_fields,
            log_report='Logger',
        ))

    # Progressbar!!
    trainer.extend(extensions.ProgressBar(update_interval=1))

    for extra_extension, trigger in extra_extensions:
        trainer.extend(extra_extension, trigger=trigger)

    return trainer


def add_default_arguments(parser):
    parser.add_argument("log_dir", help='directory where generated models and logs shall be stored')
    parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, required=True,
                        help="Number of images per training batch")
    parser.add_argument('-g', '--gpus', type=int, nargs="*", default=[], help="Ids of GPU to use [default: (use cpu)]")
    parser.add_argument('-e', '--epochs', type=int, default=20, help="Number of epochs to train [default: 20]")
    parser.add_argument('-r', '--resume', help="path to previously saved state of trained model from which training shall resume")
    parser.add_argument('-si', '--snapshot-interval', dest='snapshot_interval', type=int, default=20000,
                        help="number of iterations after which a snapshot shall be taken [default: 20000]")
    parser.add_argument('-ln', '--log-name', dest='log_name', default='training', help="name of the log folder")
    parser.add_argument('-lr', '--learning-rate', dest='learning_rate', type=float, default=0.01,
                        help="initial learning rate [default: 0.01]")
    parser.add_argument('-li', '--log-interval', dest='log_interval', type=int, default=100,
                        help="number of iterations after which an update shall be logged [default: 100]")
    parser.add_argument('--lr-step', dest='learning_rate_step_size', type=float, default=0.1,
                        help="Step size for decreasing learning rate [default: 0.1]")
    parser.add_argument('-t', '--test-interval', dest='test_interval', type=int, default=1000,
                        help="number of iterations after which testing should be performed [default: 1000]")
    parser.add_argument('--test-iterations', dest='test_iterations', type=int, default=200,
                        help="number of test iterations [default: 200]")
    parser.add_argument("-dr", "--dropout-ratio", dest='dropout_ratio', default=0.5, type=float,
                        help="ratio for dropout layers")

    return parser


def get_concat_and_pad_examples(padding=-10000):
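    """Return a converter that concatenates a batch with ``concat_examples``,
    padding variable-length examples with ``padding``."""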
    def concat_and_pad_examples(batch, device=None):
        return concat_examples(batch, device=device, padding=padding)

    return concat_and_pad_examples


def concat_and_pad_examples(batch, device=None, padding=-10000):
    return concat_examples(batch, device=device, padding=padding)


def get_definition_filepath(obj):
    """Return the path of the source file that defines the module of ``obj``."""
    return importlib.import_module(obj.__module__).__file__


def get_definition_filename(obj):
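    """Return only the file name (without directories) of the source file that defines ``obj``."""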
    return os.path.basename(get_definition_filepath(obj))