import tensorflow as tf
import numpy as np


class Logger(object):
    class Loggable(object):
        def __init__(self, summary, console_str):
            self.summary = summary
            self.console_str = console_str

        def to_tensorboard(self, filewriter, itr):
            filewriter.add_summary(self.summary, global_step=itr)
            return self

        def to_console(self, itr, append=''):
            print('{}: {} {}'.format(itr, self.console_str, append))

    class Numpy1DFormatter(object):
        def __init__(self, wrapper_str='{}', max_elements=None, precision=3, sep=','):
            self._array2string = lambda _arr: wrapper_str.format(
                    np.array2string(
                            _arr.flatten()[:max_elements],
                            precision=precision, separator=sep))

        def format(self, arr):
            return self._array2string(arr)


    def __init__(self):
        self._summaries = []
        self._summary = None  # merged
        self._sess = None
        self._console_format_strs = []
        self._console_tensors = []
        self._final = False

    def add_summaries(self, summaries):
        assert isinstance(summaries, list)
        self._summaries.extend(summaries)

    def add_console_tensor(self, formatter, tensor):
        """
        :param formatter: object responding to -format(), e.g., a string with {}, or Numpy1DFormatter
        :param tensor:
        :return:
        """
        self._console_format_strs.append(formatter)
        self._console_tensors.append(tensor)

    def finalize_with_sess(self, sess):
        assert not self._final
        self._final = True
        self._sess = sess
        assert len(self._console_format_strs) == len(self._console_tensors)
        self._summary = tf.summary.merge(self._summaries)

    def log(self):
        summary, console_tensor = self._sess.run([self._summary, self._console_tensors])
        return Logger.Loggable(
            summary=summary,
            console_str=', '.join(
                format_str.format(tensor)
                for format_str, tensor in zip(self._console_format_strs, console_tensor)))