'''Provides an API for generating Event protocol buffers.'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import re
import six
import time

from caffe2.python import cnn, core
from caffe2.proto import caffe2_pb2

from c2board.src import event_pb2
from c2board.src import summary_pb2
from c2board.src import graph_pb2
from c2board.event_file_writer import EventFileWriter
from c2board.graph import model_to_graph, nets_to_graph, protos_to_graph
from c2board.x2num import make_nps
import c2board.summary as summary


class FileWriter(object):
    '''Write `Summary` protocol buffers to event files.'''
    def __init__(self, 
                logdir,
                max_queue=10,
                flush_secs=120):
        '''Create a `SummaryWriter` and an event file.'''
        self._event_writer = EventFileWriter(logdir, max_queue, flush_secs)
        self._closed = False

    def get_logdir(self):
        '''Return the directory.'''
        return self._event_writer.get_logdir()

    def add_summary(self, summary, global_step=None):
        '''Add a `Summary` protocol buffer to the event file.'''
        if isinstance(summary, bytes):
            summ = summary_pb2.Summary()
            summ.ParseFromString(summary)
            summary = summ
        event = event_pb2.Event(summary=summary)
        self._add_event(event, global_step)

    def add_graph(self, graph):
        '''Add a `Graph` protocol buffer to the event file.'''
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self._add_event(event, None)

    def _add_event(self, event, step):
        '''General function to add the event.'''
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def flush(self):
        '''Flush the event writer.'''
        self._event_writer.flush()

    def close(self):
        '''Close the file writer.'''
        self._event_writer.close()
        self._closed = True


class SummaryWriter(object):
    '''Write `Summary` directly to event files.'''
    def __init__(self, log_dir=None, tag='default',bins=100):
        '''Initialize the summary writer.'''
        if not log_dir:
            # Default: log to runs/
            log_dir = os.path.join('runs', tag)
        self._file_writer = FileWriter(logdir=log_dir)
        self.histogram_dict = {}
        self.histogram_keys = []
        self.histogram_values = []
        self.default_bins = bins
        self.image_dict = {}
        self.rois_dict = {}
        self.mem_dict = {}
        self.text_dir = None
        self.text_tags = []
        self._track_blob_names = {}
        self._reversed_block_names = {}

    def append_histogram(self, name):
        '''Append the name of the blobs to a list for histograms.'''
        self.histogram_dict[name] = name

    def append_image(self, name):
        '''Append the name of the blobs to a list for images.'''
        self.image_dict[name] = name

    def append_image_boxes(self, im_name, box_name):
        self.histogram_dict[box_name] = box_name
        self.image_dict[im_name] = im_name
        self.rois_dict[im_name] = box_name

    def append_mem(self, name):
        self.mem_dict[name] = name

    def reverse_map(self):
        '''Reverse the map from the graph.'''
        for key, value in six.iteritems(self._track_blob_names):
            if value in self._reversed_block_names:
                self._reversed_block_names[value].append(key)
            else:
                self._reversed_block_names[value] = [key]

    def check_names(self):
        '''Make sure we do not double dump the blobs.'''
        assert len(self.histogram_dict) == len(set(self.histogram_dict)), \
                "ERROR: duplicate name to account histograms"
        assert len(self.image_dict) == len(set(self.image_dict)), \
                "ERROR: duplicate name to account images"
        assert len(self.mem_dict) == len(set(self.mem_dict)), \
                "ERROR: duplicate name to account memories"

    def replace_names(self, dictionary):
        '''Replace the names according to the graph.'''
        GPU = re.compile('gpu_[0-9]+/')

        for key in dictionary.keys():
            # Remove GPU information, assume it is data parallelism
            # TODO(xinleic): make it applicable to more general cases
            match = GPU.match(key).group()
            key0 = key.replace(match, 'gpu_0/')
            assert key0 in self._reversed_block_names, \
                             "ERROR: {} not found in blob names!".format(key0)
            values = self._reversed_block_names[key0]
            # Hack, just get the common ones
            value = summary.clean_tag(match + os.path.commonprefix(values))
            dictionary[key] = value

    def sort_out_names(self):
        '''Wrapper function to replace names.'''
        if self._track_blob_names:
            if not self._reversed_block_names:
                self.reverse_map()

            self.replace_names(self.histogram_dict)
            for key, value in six.iteritems(self.histogram_dict):
                self.histogram_keys.append(key)
                self.histogram_values.append(value)
            self.replace_names(self.image_dict)
            self.replace_names(self.mem_dict)

    def _add_scalar(self, tag, scalar_value, global_step):
        '''Add scalar data to summary.'''
        self._file_writer.add_summary(summary.scalar(tag, scalar_value), 
                                    global_step)

    def _add_histogram(self, tag, values, global_step):
        '''Add histogram to summary.'''
        self._file_writer.add_summary(summary.histogram(tag, values, self.default_bins), 
                                    global_step)

    def _add_histograms(self, global_step):
        '''Add multiple histograms to summary.'''
        values = make_nps(self.histogram_keys)
        for name, value in zip(self.histogram_values, values):
            self._file_writer.add_summary(summary.histogram_with_values(name, 
                                                                        value, 
                                                                        self.default_bins),
                                                                        global_step)

    def _add_image(self, tag, img_tensor, global_step, **kwargs):
        '''Add image data to summary.'''
        res = summary.image(tag, img_tensor, **kwargs)
        if isinstance(res, list):
            for r in res:
                self._file_writer.add_summary(r, global_step)
        else:
            self._file_writer.add_summary(res, global_step)

    def _add_mem(self, tag, mem_tensor, global_step, **kwargs):
        '''Add memory data to summary.'''
        res = summary.memory(tag, mem_tensor, **kwargs)
        if isinstance(res, list):
            for r in res:
                self._file_writer.add_summary(r, global_step)
        else:
            self._file_writer.add_summary(res, global_step)

    def _add_image_boxes(self, tag, img_tensor, box_tensor, global_step, **kwargs):
        '''Add image data to summary.'''
        res = summary.image_boxes(tag, img_tensor, box_tensor, **kwargs)
        if isinstance(res, list):
            for r in res:
                self._file_writer.add_summary(r, global_step)
        else:
            self._file_writer.add_summary(res, global_step)

    def _add_text(self, tag, text_string, global_step):
        '''Add text data to summary.'''
        self._file_writer.add_summary(summary.text(tag, text_string), global_step)
        if tag not in self.text_tags:
            self.text_tags.append(tag)
            if not self.text_dir:
                text_dir =os.path.join(self._file_writer.get_logdir(),
                                        'plugins',
                                        'tensorboard_text')
                os.makedirs(text_dir)
            with open(os.path.join(text_dir, 'tensors.json'), 'w') as fp:
                json.dump(self.text_tags, fp)

    def add_audio(self, 
                tag, 
                snd_tensor, 
                global_step, 
                sample_rate=44100):
        raise NotImplementedError

    def add_pr_curve(self, 
                    tag, 
                    labels, 
                    predictions, 
                    global_step, 
                    num_thresholds=127, 
                    weights=None):
        raise NotImplementedError

    def write_graph(self, model_or_nets_or_protos=None, **kwargs):
        '''Write graph to the summary.'''
        if isinstance(model_or_nets_or_protos, cnn.CNNModelHelper):
            current_graph, track_blob_names = model_to_graph(model_or_nets_or_protos, **kwargs)
        elif isinstance(model_or_nets_or_protos, list):
            if isinstance(model_or_nets_or_protos[0], core.Net):
                current_graph, track_blob_names = nets_to_graph(model_or_nets_or_protos, **kwargs)
            elif isinstance(model_or_nets_or_protos[0], caffe2_pb2.NetDef):
                current_graph, track_blob_names = protos_to_graph(model_or_nets_or_protos, **kwargs)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError
        self._file_writer.add_graph(current_graph)
        self._track_blob_names = track_blob_names
        # Once the graph is built, one can just map the blobs
        self.check_names()
        self.sort_out_names()

    def write_scalars(self, dictionary, global_step):
        '''Write multiple scalars to summary.'''
        for key, value in six.iteritems(dictionary):
            self._add_scalar(key, value, global_step)

    def write_summaries(self, global_step):
        '''Write histogram and image summaries.'''
        # for key, value in six.iteritems(self.histogram_dict):
        #     self._add_histogram(value, key, global_step)
        self._add_histograms(global_step)
        if self.rois_dict:
            for im_name, box_name in six.iteritems(self.rois_dict):
                self._add_image_boxes(box_name, im_name, box_name, global_step)
        else:
            for key, value in six.iteritems(self.image_dict):
                self._add_image(value, key, global_step)
        for key, value in six.iteritems(self.mem_dict):
                self._add_mem(value, key, global_step)

    def close(self):
        '''Close the writers.'''
        if not self._file_writer._closed:
            self._file_writer.flush()
            self._file_writer.close()

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()