import os.path as osp
import numpy as np

from mmcv.runner import LoggerHook, master_only


class TensorboardLoggerHook(LoggerHook):
    """
    Hook that writes the runner's log buffer to TensorBoard.

    Records are written through a ``tensorboardX.SummaryWriter``. The kind of
    summary is chosen per tag (see :meth:`single_log`): a tag of the form
    ``<type>/<name>``, with ``<type>`` one of ``image``, ``figure``,
    ``histogram`` or ``scalar``, is written with the matching ``add_<type>``
    call under ``<name>``; any other tag is written as text, image or scalar
    depending on the record itself.

    Every method below is wrapped in mmcv's ``master_only`` decorator
    (NOTE(review): presumably this restricts them to the master process --
    confirm against mmcv).

    Args:
        log_dir (str or Path): dir to save logger file. When ``None``, it is
            set to ``<runner.work_dir>/tf_logs`` in :meth:`before_run`.
        interval (int): logging interval, default is 10. Forwarded to
            ``LoggerHook.__init__``.
        ignore_last (bool): forwarded unchanged to ``LoggerHook.__init__``;
            its semantics are defined by the base class.
        reset_flag (bool): forwarded unchanged to ``LoggerHook.__init__``;
            its semantics are defined by the base class.
        register_logWithIter_keyword (list[str] or tuple[str]): substrings of
            variable names. A variable whose name contains any of them is
            logged against ``runner.iter + 1``; all other variables are
            logged against ``runner.epoch + 1``. Any value that is not a
            list or tuple (including the default ``None``) disables the
            iteration-based step.
    """

    # Maps the ``<type>`` prefix of a tag to the SummaryWriter method that
    # handles it.
    _TYPED_WRITERS = {
        'image': 'add_image',
        'figure': 'add_figure',
        'histogram': 'add_histogram',
        'scalar': 'add_scalar',
    }

    def __init__(
            self,
            log_dir=None,
            interval=10,
            ignore_last=True,
            reset_flag=True,
            register_logWithIter_keyword=None
    ):
        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
                                                    reset_flag)
        self.log_dir = log_dir
        self.register_logWithIter_keyword = register_logWithIter_keyword

    @master_only
    def before_run(self, runner):
        """Create the ``SummaryWriter`` used by the logging methods.

        Args:
            runner: the runner; only ``runner.work_dir`` is read, and only
                when no ``log_dir`` was given.

        Raises:
            ImportError: if ``tensorboardX`` cannot be imported.
        """
        # Imported lazily so the module can be loaded without tensorboardX
        # as long as this hook is never run.
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            raise ImportError('Please install tensorflow and tensorboardX '
                              'to use TensorboardLoggerHook.')
        else:
            if self.log_dir is None:
                self.log_dir = osp.join(runner.work_dir, 'tf_logs')
            self.writer = SummaryWriter(self.log_dir)

    @master_only
    def single_log(self, tag, record, global_step):
        """Write one record to TensorBoard.

        Args:
            tag (str): name of the record, optionally in the self-defined
                ``prefix/suffix_tag`` format where ``prefix`` selects the
                summary type.
            record: the value to write.
            global_step (int): step value passed on to the writer.
        """
        # self-defined, in format: prefix/suffix_tag
        prefix, _, suffix_tag = tag.partition('/')
        writer_name = self._TYPED_WRITERS.get(prefix)
        if writer_name is not None:
            # A tag that is only the prefix ('scalar', 'image/') has an empty
            # suffix; fall back to the full tag so the summary is never
            # written under an empty name.
            write = getattr(self.writer, writer_name)
            write(suffix_tag or tag, record, global_step)
            return

        if isinstance(record, str):
            self.writer.add_text(tag, record, global_step)
            return

        # np.size returns ``record.size`` when that attribute exists and
        # otherwise counts the elements of ``asarray(record)``, so plain
        # Python numbers (which have no ``.size``) are logged as scalars
        # instead of raising AttributeError.
        if np.size(record) > 1:
            self.writer.add_image(tag, record, global_step)
        else:
            self.writer.add_scalar(tag, record, global_step)

    @master_only
    def log(self, runner):
        """Write every variable of ``runner.log_buffer.output``.

        ``time`` and ``data_time`` are skipped. List/tuple records are
        written element by element under ``<var>/<index>``.

        Args:
            runner: the runner; ``log_buffer.output``, ``epoch`` and ``iter``
                are read from it.
        """
        keywords = self.register_logWithIter_keyword
        if not isinstance(keywords, (tuple, list)):
            keywords = ()

        for var in runner.log_buffer.output:
            if var in ('time', 'data_time'):
                continue
            record = runner.log_buffer.output[var]

            # for example, loss will be log as iteration
            if any(keyword in var for keyword in keywords):
                global_step = runner.iter + 1
            else:
                global_step = runner.epoch + 1

            if isinstance(record, (list, tuple)):
                for idx, rec in enumerate(record):
                    tag = var + '/' + '{}'.format(idx)
                    self.single_log(tag, rec, global_step)
            else:
                self.single_log(var, record, global_step)

    @master_only
    def after_run(self, runner):
        """Close the summary writer created in :meth:`before_run`."""
        self.writer.close()