import numpy as np
import pandas as pd
import pandas.compat as compat


class TablePlotter(object):
    """
    Layout some DataFrames in vertical/horizontal layout for explanation.
    Used in merging.rst

    Parameters
    ----------
    cell_width : float, default 0.37
        Figure width (in inches) allotted to each table cell.
    cell_height : float, default 0.25
        Figure height (in inches) allotted to each table cell.
    font_size : float, default 7.5
        Font size used for cell text and table titles.
    """

    def __init__(self, cell_width=0.37, cell_height=0.25, font_size=7.5):
        self.cell_width = cell_width
        self.cell_height = cell_height
        self.font_size = font_size

    def _shape(self, df):
        """
        Calculate table shape considering index levels.

        Returns
        -------
        tuple of int
            ``(rows, cols)`` including one header row per column level and
            one extra column per index level.
        """
        row, col = df.shape
        return row + df.columns.nlevels, col + df.index.nlevels

    def _get_cells(self, left, right, vertical):
        """
        Calculate appropriate figure size based on left and right data.

        Returns
        -------
        tuple of int
            ``(hcells, vcells)``: number of cells needed horizontally and
            vertically to lay out every table.
        """
        left_shapes = [self._shape(df) for df in left]
        right_rows, right_cols = self._shape(right)
        if vertical:
            # left tables are stacked on top of each other; the result sits
            # to their right
            vcells = max(sum(r for r, _ in left_shapes), right_rows)
            hcells = max(c for _, c in left_shapes) + right_cols
        else:
            # every table sits side by side in a single row
            vcells = max([r for r, _ in left_shapes] + [right_rows])
            hcells = sum(c for _, c in left_shapes) + right_cols
        return hcells, vcells

    @staticmethod
    def _normalize_labels(labels, n):
        """
        Return a list with at least ``n`` titles built from ``labels``.

        ``None`` becomes ``n`` empty titles, a single string becomes a
        one-element list, and a short list is padded with empty titles so
        that ``zip`` in :meth:`plot` never silently drops a left table.
        """
        if labels is None:
            return [''] * n
        if isinstance(labels, str):
            labels = [labels]
        labels = list(labels)
        return labels + [''] * (n - len(labels))

    def plot(self, left, right, labels=None, vertical=True):
        """
        Plot left / right DataFrames in specified layout.

        Parameters
        ----------
        left : DataFrame, Series, or list of them
            Data before the operation is applied.
        right : DataFrame or Series
            Result of the operation.
        labels : list of str, optional
            Titles drawn above the left tables. Missing titles (including
            ``labels=None``) are drawn as empty strings.
        vertical : bool, default True
            If True, use vertical layout. If False, use horizontal layout.

        Returns
        -------
        matplotlib.figure.Figure
        """
        import matplotlib.pyplot as plt
        import matplotlib.gridspec as gridspec

        if not isinstance(left, list):
            left = [left]
        left = [self._conv(df) for df in left]
        right = self._conv(right)
        labels = self._normalize_labels(labels, len(left))

        hcells, vcells = self._get_cells(left, right, vertical)
        fig = plt.figure(figsize=(self.cell_width * hcells,
                                  self.cell_height * vcells))

        if vertical:
            gs = gridspec.GridSpec(len(left), hcells)
            # left tables each take one grid row in the first columns
            max_left_cols = max(self._shape(df)[1] for df in left)
            max_left_rows = max(self._shape(df)[0] for df in left)
            for i, (df, label) in enumerate(zip(left, labels)):
                ax = fig.add_subplot(gs[i, 0:max_left_cols])
                self._make_table(ax, df, title=label,
                                 height=1.0 / max_left_rows)
            # the result spans every grid row of the remaining columns
            ax = fig.add_subplot(gs[:, max_left_cols:])
            self._make_table(ax, right, title='Result', height=1.05 / vcells)
            fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
        else:
            max_rows = max(self._shape(df)[0] for df in left + [right])
            height = 1.0 / max_rows
            gs = gridspec.GridSpec(1, hcells)
            # left tables are placed side by side, each as wide as its
            # own column count
            start = 0
            for df, label in zip(left, labels):
                ncols = self._shape(df)[1]
                ax = fig.add_subplot(gs[0, start:start + ncols])
                self._make_table(ax, df, title=label, height=height)
                start += ncols
            # the result takes whatever columns remain
            ax = fig.add_subplot(gs[0, start:])
            self._make_table(ax, right, title='Result', height=height)
            fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95)

        return fig

    def _conv(self, data):
        """
        Convert each input to a DataFrame suitable for table output.

        A Series becomes a one-column DataFrame (an unnamed Series gets an
        empty column label), and missing values are replaced by the string
        ``'NaN'`` so they are printed in the table.
        """
        if isinstance(data, pd.Series):
            if data.name is None:
                data = data.to_frame(name='')
            else:
                data = data.to_frame()
        data = data.fillna('NaN')
        return data

    def _insert_index(self, data):
        """
        Return a copy of ``data`` with index and column levels as cells.

        Index levels are inserted as leading columns (``'Index'`` or
        ``'Index0'``, ``'Index1'``, ...). For MultiIndex columns, levels
        1..n-1 are prepended as extra rows and level 0 becomes the columns.
        """
        # insert is destructive
        data = data.copy()
        idx_nlevels = data.index.nlevels
        if idx_nlevels == 1:
            data.insert(0, 'Index', data.index)
        else:
            for i in range(idx_nlevels):
                data.insert(i, 'Index{0}'.format(i),
                            data.index.get_level_values(i))

        col_nlevels = data.columns.nlevels
        if col_nlevels > 1:
            col = data.columns.get_level_values(0)
            values = [data.columns.get_level_values(i).values
                      for i in range(1, col_nlevels)]
            col_df = pd.DataFrame(values)
            # align column labels so concat stacks the header rows on top
            data.columns = col_df.columns
            data = pd.concat([col_df, data])
            data.columns = col
        return data

    def _make_table(self, ax, df, title, height=None):
        """
        Draw ``df`` as a table on ``ax`` and shade its header/index cells.

        Parameters
        ----------
        ax : matplotlib Axes
        df : DataFrame or None
            When None, ``ax`` is hidden and nothing is drawn.
        title : str
            Title shown above the table.
        height : float, optional
            Height of every cell in axes coordinates; defaults to
            ``1 / (number of rows + 1)``.
        """
        if df is None:
            ax.set_visible(False)
            return

        import pandas.plotting as plotting

        # level counts must be taken before the index is turned into regular
        # columns, because they decide which cells get colorized below
        idx_nlevels = df.index.nlevels
        col_nlevels = df.columns.nlevels
        df = self._insert_index(df)
        tb = plotting.table(ax, df, loc=9)
        tb.set_fontsize(self.font_size)

        if height is None:
            height = 1.0 / (len(df) + 1)

        for (r, c), cell in tb.get_celld().items():
            if c == -1:
                # matplotlib's row-label column; the index is already drawn
                # as regular columns by _insert_index
                cell.set_visible(False)
            elif r < col_nlevels and c < idx_nlevels:
                # corner where header rows and index columns overlap
                cell.set_visible(False)
            elif r < col_nlevels or c < idx_nlevels:
                cell.set_facecolor('#AAAAAA')
            cell.set_height(height)

        ax.set_title(title, size=self.font_size)
        ax.axis('off')


if __name__ == "__main__":
    # Manual demo: draws three example layouts, one window at a time.
    import matplotlib.pyplot as plt

    p = TablePlotter()

    df1 = pd.DataFrame({'A': [10, 11, 12],
                        'B': [20, 21, 22],
                        'C': [30, 31, 32]})
    df2 = pd.DataFrame({'A': [10, 12],
                        'C': [30, 32]})

    # vertical layout: row-wise concatenation
    p.plot([df1, df2], pd.concat([df1, df2]),
           labels=['df1', 'df2'], vertical=True)
    plt.show()

    df3 = pd.DataFrame({'X': [10, 12],
                        'Z': [30, 32]})

    # horizontal layout: column-wise concatenation
    p.plot([df1, df3], pd.concat([df1, df3], axis=1),
           labels=['df1', 'df3'], vertical=False)
    plt.show()

    # MultiIndex on both the rows and the columns
    idx = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B'), (1, 'C'),
                                     (2, 'A'), (2, 'B'), (2, 'C')])
    col = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B')])
    df4 = pd.DataFrame({'v1': [1, 2, 3, 4, 5, 6],
                        'v2': [5, 6, 7, 8, 9, 10]},
                       index=idx)
    df4.columns = col
    p.plot(df4, df4, labels=['df4'])
    plt.show()