# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf

def _plot_item(W, name, full_name, nspaces):
  plt.figure()
  if W.shape == ():
    print(name, ": ", W)
  elif W.shape[0] == 1:
    plt.stem(W.T)
    plt.title(full_name)
  elif W.shape[1] == 1:
    plt.stem(W)
    plt.title(full_name)
  else:
    plt.imshow(np.abs(W), interpolation='nearest', cmap='jet');
    plt.colorbar()
    plt.title(full_name)


def all_plot(d, full_name="", exclude="", nspaces=0):
  """Recursively plot all the LFADS model parameters in a nested dictionary.

  Each leaf value is handed to _plot_item under its slash-joined path
  (e.g. "/gen/W"). Nested dicts are always descended into; the `exclude`
  filter is applied to leaf paths only.

  Args:
    d: dict whose values are parameter values or further dicts.
    full_name: path prefix of `d` itself ("" at the top level).
    exclude: substring; leaves whose path contains it are not plotted. An
      empty string (or None) plots everything.
    nspaces: nesting counter, grown by 4 per level and forwarded to
      _plot_item.
  """
  # .items() works on Python 2 and 3; .iteritems() exists only on Python 2.
  for key, value in d.items():
    this_name = full_name + "/" + key
    if isinstance(value, dict):
      all_plot(value, full_name=this_name, exclude=exclude,
               nspaces=nspaces + 4)
    elif not exclude or exclude not in this_name:
      _plot_item(value, name=key, full_name=this_name, nspaces=nspaces + 4)


def plot_priors(model_vals=None):
  """Compare prior and posterior statistics of the initial condition g0.

  Draws three new figures:
    1. overlaid histograms (posterior blue, prior green) of the mean values
       and of the variance values;
    2. a 2x2 grid of transposed heatmaps of prior/posterior means and
       variances;
    3. a stem plot of the sorted log of the per-column std (std over axis 0)
       of the posterior means.

  Args:
    model_vals: dict with keys 'prior_g0_mean', 'prior_g0_var',
      'posterior_g0_mean' and 'posterior_g0_var' (presumably batch x n
      arrays, per the _bxn naming -- TODO confirm). When None, falls back to
      a module-level `train_modelvals` for backward compatibility. This file
      never defines that name, so the zero-argument call only works if a
      caller has assigned it on the module.
  """
  if model_vals is None:
    # Legacy behaviour: raises NameError if the global was never injected.
    model_vals = train_modelvals
  g0s_prior_mean_bxn = model_vals['prior_g0_mean']
  g0s_prior_var_bxn = model_vals['prior_g0_var']
  g0s_post_mean_bxn = model_vals['posterior_g0_mean']
  g0s_post_var_bxn = model_vals['posterior_g0_var']

  # Figure 1: histograms. alpha keeps the first histogram visible under the
  # second one.
  plt.figure(figsize=(10, 4), tight_layout=True)
  plt.subplot(1, 2, 1)
  plt.hist(g0s_post_mean_bxn.flatten(), bins=20, color='b', alpha=0.5,
           label='posterior')
  plt.hist(g0s_prior_mean_bxn.flatten(), bins=20, color='g', alpha=0.5,
           label='prior')
  plt.legend()
  plt.title('Histogram of Prior/Posterior Mean Values')

  plt.subplot(1, 2, 2)
  # NOTE(review): the title says "Log Variance" but the '*_var' values are
  # plotted untransformed -- confirm whether they already hold log-variances.
  plt.hist(g0s_post_var_bxn.flatten(), bins=20, color='b', alpha=0.5,
           label='posterior')
  plt.hist(g0s_prior_var_bxn.flatten(), bins=20, color='g', alpha=0.5,
           label='prior')
  plt.legend()
  plt.title('Histogram of Prior/Posterior Log Variance Values')

  # Figure 2: heatmaps of each quantity, transposed before display.
  heatmaps = [
      (g0s_prior_mean_bxn, 'Prior g0 means'),
      (g0s_post_mean_bxn, 'Posterior g0 means'),
      (g0s_prior_var_bxn, 'Prior g0 variance Values'),
      (g0s_post_var_bxn, 'Posterior g0 variance Values'),
  ]
  plt.figure(figsize=(10, 10), tight_layout=True)
  for idx, (vals_bxn, panel_title) in enumerate(heatmaps):
    plt.subplot(2, 2, idx + 1)
    plt.imshow(vals_bxn.T, interpolation='nearest', cmap='jet')
    plt.colorbar(fraction=0.025, pad=0.04)
    plt.title(panel_title)

  # Figure 3: spread of the posterior means across axis 0.
  # NOTE(review): the title mentions h0 while every other label says g0.
  plt.figure(figsize=(10, 5))
  plt.stem(np.sort(np.log(g0s_post_mean_bxn.std(axis=0))))
  plt.title('Log standard deviation of h0 means')


def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
                     color='r', title=None):
  """Draw up to `n_to_plot` vertically stacked traces on the current axes.

  Args:
    vals_bxtxn: 3-D array indexed [batch, time, series].
    bidx: batch index to draw; None draws the mean over the batch axis.
    n_to_plot: maximum number of series (columns) to draw; clipped to the
      number available.
    scale: vertical spacing; series i is shifted up by ``scale * i``.
    color: matplotlib colour used for every trace.
    title: axes title, applied only when truthy.
  """
  if bidx is not None:
    vals_txn = vals_bxtxn[bidx, :, :]
  else:
    vals_txn = np.mean(vals_bxtxn, axis=0)

  # Unpacking also rejects anything that is not a 2-D (time x series) slice.
  _, num_series = vals_txn.shape
  n_to_plot = min(n_to_plot, num_series)

  offsets = scale * np.array(range(n_to_plot))
  plt.plot(vals_txn[:, :n_to_plot] + offsets, color=color, lw=1.0)
  plt.axis('tight')
  if title:
    plt.title(title)


def plot_lfads_timeseries(data_bxtxn, model_vals, ext_input_bxtxi=None,
                          truth_bxtxn=None, bidx=None, output_dist="poisson",
                          conversion_factor=1.0, subplot_cidx=0,
                          col_title=None):
  """Draw one column of the 7-row x 2-column LFADS summary figure.

  Rows (top to bottom), all in column `subplot_cidx` of the current figure:
    1. model rates (poisson) or means +/- one std in cyan (gaussian), in red,
       with optional ground truth in black;
    2. controller outputs and/or external inputs, when present;
    3. model rates/means (red) overlaid on the data (black);
    4. factors;
    5. generator states;
    6. data heatmap (series x time);
    7. model-means heatmap (series x time).

  Args:
    data_bxtxn: observed data indexed [batch, time, series].
    model_vals: dict of model outputs. Must contain 'output_dist_params',
      'factors' and 'gen_states', and may contain 'controller_outputs'.
    ext_input_bxtxi: optional external inputs indexed [batch, time, input].
    truth_bxtxn: optional ground truth, overlaid in row 1.
    bidx: batch index to show; None averages over the batch axis.
    output_dist: 'poisson' (params are rates) or 'gaussian' (params are
      means and variances concatenated along the last axis).
    conversion_factor: multiplier applied to poisson rates in rows 1 and 3.
    subplot_cidx: 0 for the left column, 1 for the right column.
    col_title: prefix for every subplot title; None is treated as "".

  Raises:
    ValueError: if output_dist is neither 'poisson' nor 'gaussian'.
  """
  if output_dist not in ('poisson', 'gaussian'):
    # Fail fast. The previous `assert 'NIY'` could never fire (a non-empty
    # string is truthy), so bad values crashed later with a NameError.
    raise ValueError("Unsupported output_dist %r; expected 'poisson' or "
                     "'gaussian'." % (output_dist,))
  if col_title is None:
    col_title = ""

  n_to_plot = 10
  scale = 1.0
  nrows = 7
  params_bxtxk = model_vals['output_dist_params']

  # Row 1: model output vs. ground truth.
  plt.subplot(nrows, 2, 1 + subplot_cidx)
  if output_dist == 'poisson':
    means = conversion_factor * params_bxtxk
    plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
                     title=col_title + " rates (LFADS - red, Truth - black)")
  else:
    means, variances = np.split(params_bxtxk, 2, axis=2)
    stds = np.sqrt(variances)
    plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
                     title=col_title + " means (LFADS - red, Truth - black)")
    plot_time_series(means + stds, bidx, n_to_plot=n_to_plot, scale=scale,
                     color='c')
    plot_time_series(means - stds, bidx, n_to_plot=n_to_plot, scale=scale,
                     color='c')
  if truth_bxtxn is not None:
    plot_time_series(truth_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
                     scale=scale)

  # Row 2: controller outputs and/or external inputs. Both share one axes,
  # so external inputs never land on row 1 when there are no controller
  # outputs.
  has_controller = "controller_outputs" in model_vals
  if has_controller or ext_input_bxtxi is not None:
    plt.subplot(nrows, 2, 3 + subplot_cidx)
  input_title = ""
  if has_controller:
    input_title += " Controller Output"
    # NOTE(review): [0:-1] slices the FIRST axis. If that is the batch axis,
    # this drops the last trial (and bidx == last index would fail). Confirm
    # the layout of 'controller_outputs'.
    u_t = model_vals['controller_outputs'][0:-1]
    plot_time_series(u_t, bidx, n_to_plot=n_to_plot, color='c', scale=1.0,
                     title=col_title + input_title)
  if ext_input_bxtxi is not None:
    input_title += " External Input"
    # Show the same trial as every other row (previously always averaged).
    plot_time_series(ext_input_bxtxi, bidx, n_to_plot=n_to_plot, color='b',
                     scale=scale, title=col_title + input_title)

  # Row 3: model rates/means overlaid on the observed data.
  plt.subplot(nrows, 2, 5 + subplot_cidx)
  plot_time_series(means, bidx,
                   n_to_plot=n_to_plot, scale=1.0,
                   title=col_title + " Spikes (LFADS - red, Spikes - black)")
  plot_time_series(data_bxtxn, bidx, n_to_plot=n_to_plot, color='k', scale=1.0)

  # Row 4: factors.
  plt.subplot(nrows, 2, 7 + subplot_cidx)
  plot_time_series(model_vals['factors'], bidx, n_to_plot=n_to_plot, color='b',
                   scale=2.0, title=col_title + " Factors")

  # Row 5: generator states.
  plt.subplot(nrows, 2, 9 + subplot_cidx)
  plot_time_series(model_vals['gen_states'], bidx, n_to_plot=n_to_plot,
                   color='g', scale=1.0, title=col_title + " Generator State")

  # Rows 6-7: heatmaps with series on the y axis and time on the x axis.
  if bidx is not None:
    data_nxt = data_bxtxn[bidx, :, :].T
    params_nxt = params_bxtxk[bidx, :, :].T
  else:
    data_nxt = np.mean(data_bxtxn, axis=0).T
    params_nxt = np.mean(params_bxtxk, axis=0).T
  if output_dist == 'poisson':
    # NOTE(review): unlike rows 1 and 3, conversion_factor is not applied
    # here; confirm whether that is intended.
    means_nxt = params_nxt
  else:
    # Rows of params_nxt are [means; variances]; keep the means half.
    means_nxt = np.vsplit(params_nxt, 2)[0]

  plt.subplot(nrows, 2, 11 + subplot_cidx)
  plt.imshow(data_nxt, aspect='auto', interpolation='nearest')
  plt.title(col_title + ' Data')

  plt.subplot(nrows, 2, 13 + subplot_cidx)
  plt.imshow(means_nxt, aspect='auto', interpolation='nearest')
  plt.title(col_title + ' Means')


def plot_lfads(train_bxtxd, train_model_vals,
               train_ext_input_bxtxi=None, train_truth_bxtxd=None,
               valid_bxtxd=None, valid_model_vals=None,
               valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
               bidx=None, cf=1.0, output_dist='poisson'):
  """Render the LFADS train/valid summary figure and return it as an image.

  The left column shows the training set and the right column the
  validation set (see plot_lfads_timeseries for the per-row contents). The
  validation column is skipped when valid_bxtxd or valid_model_vals is None.
  The figure is always closed before returning, even if plotting raises.

  Args:
    train_bxtxd, train_model_vals, train_ext_input_bxtxi, train_truth_bxtxd:
      training data, model outputs, optional external inputs and optional
      ground truth, forwarded to plot_lfads_timeseries.
    valid_bxtxd, valid_model_vals, valid_ext_input_bxtxi, valid_truth_bxtxd:
      the validation counterparts.
    bidx: batch index to show; None averages over the batch axis.
    cf: conversion factor applied to poisson rates.
    output_dist: 'poisson' or 'gaussian'.

  Returns:
    uint8 numpy array of shape (height, width, 3) holding the RGB pixels of
    the rendered figure.
  """
  f = plt.figure(figsize=(18, 20), tight_layout=True)
  try:
    plot_lfads_timeseries(train_bxtxd, train_model_vals,
                          train_ext_input_bxtxi,
                          truth_bxtxn=train_truth_bxtxd,
                          conversion_factor=cf, bidx=bidx,
                          output_dist=output_dist, col_title='Train')
    if valid_bxtxd is not None and valid_model_vals is not None:
      plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
                            valid_ext_input_bxtxi,
                            truth_bxtxn=valid_truth_bxtxd,
                            conversion_factor=cf, bidx=bidx,
                            output_dist=output_dist,
                            subplot_cidx=1, col_title='Valid')

    # Rasterise and read the pixels back. buffer_rgba replaces the
    # deprecated tostring_rgb, and np.frombuffer replaces the deprecated
    # binary np.fromstring. Drop the alpha channel and copy so the result
    # is contiguous and writable.
    f.canvas.draw()
    width, height = f.canvas.get_width_height()
    rgba = np.frombuffer(f.canvas.buffer_rgba(), dtype=np.uint8)
    image_hxwx3 = rgba.reshape((height, width, 4))[:, :, :3].copy()
  finally:
    plt.close(f)

  return image_hxwx3