#!/usr/bin/env python

"""
multi-step prediction using E2C
results are kind of poor...
"""

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
from data.plane_data2 import PlaneData, get_params

A=B=40 # height/width (pixels) of the plane-task observation images

x_dim,u_dim,T=get_params()
z_dim=2 # latent space dimensionality
eps=1e-9 # numerical stability

def orthogonal_initializer(scale = 1.1):
  """
  From Lasagne and Keras. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
  """
  def _initializer(shape, dtype=tf.float32):
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    # pick the one with the correct shape
    q = u if u.shape == flat_shape else v
    q = q.reshape(shape)
    print('Warning -- You have opted to use the orthogonal_initializer function')
    return tf.constant(scale * q[:shape[0], :shape[1]], dtype=tf.float32)
  return _initializer
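
# Illustrative usage note (not relied on below): the returned closure plugs into
# tf.get_variable just like the built-in initializers, e.g.
#   w = tf.get_variable("w", [128, 64], initializer=orthogonal_initializer(1.1))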

class NormalDistribution(object):
  """
  Represents a multivariate normal distribution parameterized by
  N(mu, Cov). If the covariance matrix is diagonal, Cov = diag(sigma^2). Otherwise,
  Cov = A * diag(sigma^2) * A^T, where A = (I + v*r^T).
  """
  def __init__(self,mu,sigma,logsigma,v=None,r=None):
    self.mu=mu
    self.sigma=sigma # either stdev diagonal itself, or stdev diagonal from decomposition
    self.logsigma=logsigma
    dim=mu.get_shape()
    if v is None:
      v=tf.constant(0.,shape=dim)
    if r is None:
      r=tf.constant(0.,shape=dim)
    self.v=v
    self.r=r
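
# For reference, a minimal numpy sketch (not used by the graph below) of the full
# covariance encoded by (sigma, v, r): Cov = A * diag(sigma^2) * A^T with A = I + v*r^T.
def _full_covariance_np(sigma, v, r):
  # sigma, v, r: 1-D numpy arrays of length z_dim for a single example
  A = np.eye(len(sigma)) + np.outer(v, r)
  return A.dot(np.diag(np.square(sigma))).dot(A.T)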

def linear(x,output_dim):
  w=tf.get_variable("w", [x.get_shape()[1], output_dim], initializer=orthogonal_initializer(1.1))
  b=tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0))
  return tf.matmul(x,w)+b

def ReLU(x,output_dim, scope):
  # helper function for implementing stacked ReLU layers
  with tf.variable_scope(scope):
    return tf.nn.relu(linear(x,output_dim))

def encode(x,share=None):
  with tf.variable_scope("encoder",reuse=share):
    for l in range(3):
      x=ReLU(x,150,"l"+str(l))
    return linear(x,2*z_dim)

def KLGaussian(Q,N):
  # Q, N are instances of NormalDistribution
  # implements the KL divergence KL(Q||N) derived in Appendix A.1
  # Q ~ Normal(mu0, A*diag(sigma0^2)*A^T) with A = I + v*r^T, N ~ Normal(mu1, diag(sigma1^2))
  # returns the per-example divergence in nats (natural log rather than log2), shape (batch,)
  sum=lambda x: tf.reduce_sum(x,1) # convenience fn for summing over features (columns)
  k=float(Q.mu.get_shape()[1].value) # dimension of distribution
  mu0,v,r,mu1=Q.mu,Q.v,Q.r,N.mu
  s02,s12=tf.square(Q.sigma),tf.square(N.sigma)+eps
  #vr=sum(v*r)
  a=sum(s02*(1.+2.*v*r)/s12) + sum(tf.square(v)/s12)*sum(tf.square(r)*s02) # trace term
  b=sum(tf.square(mu1-mu0)/s12) # difference-of-means term
  c=2.*(sum(N.logsigma-Q.logsigma) - tf.log(1.+sum(v*r))) # ratio-of-determinants term. 
  return 0.5*(a+b-k+c)#, a, b, c
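
# For reference, the closed form assembled above (cf. Appendix A.1), with
# Sigma_0 = diag(sigma_0^2), Sigma_1 = diag(sigma_1^2), A = I + v*r^T:
#   KL(Q||N) = 0.5 * [ tr(Sigma_1^-1 * A * Sigma_0 * A^T)
#                      + (mu_1 - mu_0)^T * Sigma_1^-1 * (mu_1 - mu_0)
#                      - k + ln det(Sigma_1) - ln det(A * Sigma_0 * A^T) ]
# where ln det(A * Sigma_0 * A^T) = 2*sum(log sigma_0) + 2*ln(1 + v^T r)
# by the matrix determinant lemma.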

def sampleNormal(mu,sigma):
  # reparameterization trick: z = mu + sigma*eps, eps ~ N(0, I); sigma is the diagonal stdev
  n01=tf.random_normal(sigma.get_shape(), mean=0, stddev=1)
  return mu+sigma*n01

def sampleQ_phi(h_enc,share=None):
  with tf.variable_scope("sampleQ_phi",reuse=share):
    mu,log_sigma=tf.split(1,2,linear(h_enc,z_dim*2)) # split into mean and log-stdev (diagonal)
    sigma=tf.exp(log_sigma)
    return sampleNormal(mu,sigma), NormalDistribution(mu, sigma, log_sigma)

def transition(h,share=None):
  # compute A,B,o linearization matrices
  with tf.variable_scope("trans",reuse=share):
    for l in range(2):
      h=ReLU(h,100,"l"+str(l))
    with tf.variable_scope("A"):
      v,r=tf.split(1,2,linear(h,z_dim*2))
      v1=tf.expand_dims(v,-1) # (batch, z_dim, 1)
      rT=tf.expand_dims(r,1) # batch, 1, z_dim
      I=tf.diag([1.]*z_dim)
      A=(I+tf.batch_matmul(v1,rT)) # A = I + v*r^T, shape (batch, z_dim, z_dim); I is broadcast over the batch
    with tf.variable_scope("B"):
      B=linear(h,z_dim*u_dim)
      B=tf.reshape(B,[-1,z_dim,u_dim])
    with tf.variable_scope("o"):
      o=linear(h,z_dim)
    return A,B,o,v,r

def sampleQ_psi(z,u,Q_phi,share=None):
  A,B,o,v,r=transition(z,share)
  with tf.variable_scope("sampleQ_psi"):
    mu_t=tf.expand_dims(Q_phi.mu,-1) # batch,z_dim,1
    Amu=tf.squeeze(tf.batch_matmul(A,mu_t), [-1])
    u=tf.expand_dims(u,-1) # batch,u_dim,1
    Bu=tf.squeeze(tf.batch_matmul(B,u),[-1])
    Q_psi=NormalDistribution(Amu+Bu+o,Q_phi.sigma,Q_phi.logsigma, v, r)
    # the actual z_next sample is generated by deterministically transforming z_t
    z=tf.expand_dims(z,-1)
    Az=tf.squeeze(tf.batch_matmul(A,z),[-1])
    z_next=Az+Bu+o
    return z_next,Q_psi#,(A,B,o,v,r) # debugging
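
# Minimal numpy sketch of the deterministic latent step above for a single example
# (reference only, not used by the graph): z_next = (I + v*r^T) z + B u + o
def _latent_step_np(z, u, v, r, B_t, o):
  A = np.eye(len(z)) + np.outer(v, r)
  return A.dot(z) + B_t.dot(u) + o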

def decode(z,share=None):
  with tf.variable_scope("decoder",reuse=share):
    for l in range(2):
      z=ReLU(z,200,"l"+str(l))
    return linear(z,x_dim)

def binary_crossentropy(t,o):
  # elementwise Bernoulli log-likelihood (positive); recons_loss negates its sum
  return t*tf.log(o+eps) + (1.0-t)*tf.log(1.0-o+eps)

def recons_loss(x,x_recons):
  with tf.variable_scope("Lx"):
    ll=tf.reduce_sum(binary_crossentropy(x,x_recons),1) # sum across features
    return -ll # negative log-likelihood

def latent_loss(Q):
  with tf.variable_scope("Lz"):
    mu2=tf.square(Q.mu)
    sigma2=tf.square(Q.sigma)
    # KL(Q_phi || N(0, I)) regularization term of the variational bound
    return -0.5*tf.reduce_sum(1+2*Q.logsigma-mu2-sigma2,1)
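
# For reference, this is the standard closed-form KL between a diagonal Gaussian
# and the unit Gaussian prior:
#   KL(N(mu, diag(sigma^2)) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
# where log(sigma^2) = 2*logsigma above.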

def sampleP_theta(h_dec,share=None):
  # return the Bernoulli means p = sigmoid(W*h_dec) of p(x|z); the mean (not a sample) is used as the reconstruction
  with tf.variable_scope("P_theta",reuse=share):
    p=linear(h_dec,x_dim)
    return tf.sigmoid(p) # mean of bernoulli distribution

# BUILD NETWORK
batch_size=64


# sequence length for multi-step training (overrides the T returned by get_params)
T=10

# inputs: one placeholder per time step
xs=[tf.placeholder(tf.float32, [batch_size, x_dim], name="x%d"%t) for t in range(T)]
us=[tf.placeholder(tf.float32, [batch_size, u_dim], name="u%d"%t) for t in range(T-1)]

x_recons=[0]*(T-1) # we'll apply the VAE reconstruction loss at each of the first T-1 steps
x_predicts=[0]*(T-1)

# arrays for storing the T-1 per-step reconstruction / prediction losses
losses=[0]*(T-1)
kl_losses=[0]*(T-1)
lbound_losses=[0]*(T-1)
print("sequence length T=%d" % T)

DO_SHARE=None
for t in range(T-1):
  if t==0:
    x=xs[0] # we are given the first one
  else:
    x=x_predicts[t-1] # re-use our hallucinations to predict
  #x=xs[t]
  h_enc=encode(x,share=DO_SHARE)
  z,Q_phi=sampleQ_phi(h_enc,share=DO_SHARE)
  # reconstitute x_t
  h_dec=decode(z,share=DO_SHARE)
  x_recons[t]=sampleP_theta(h_dec,share=DO_SHARE)
  # compute the linearized dynamics and predict the next latent state
  z_predict,Q_psi=sampleQ_psi(z,us[t],Q_phi,share=DO_SHARE)
  DO_SHARE=True # share from now on
  h_dec_predict=decode(z_predict,share=DO_SHARE)
  x_predicts[t]=sampleP_theta(h_dec_predict,share=DO_SHARE)
  # encode next true data
  h_enc_next=encode(xs[t+1],share=DO_SHARE)
  z_next,Q_phi_next=sampleQ_phi(h_enc_next,share=DO_SHARE)

  # add loss terms
  L_x=recons_loss(xs[t],x_recons[t])
  L_x_next=recons_loss(xs[t+1],x_predicts[t])
  L_z=latent_loss(Q_phi)
  L_bound=L_x+L_x_next+L_z
  KL=KLGaussian(Q_psi,Q_phi_next)
  lambd=1.0 # weight on the KL(Q_psi || Q_phi_next) consistency term
  
  lbound_losses[t]=tf.reduce_mean(L_bound)
  kl_losses[t]=tf.reduce_mean(lambd*KL)
  losses[t]=tf.reduce_mean(L_bound+lambd*KL) # average loss over minibatch to single scalar


loss=tf.add_n(losses) # sum all the losses to compute the sequence loss
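
# The optimized objective is therefore (each term averaged over the minibatch)
#   loss = sum_{t=0}^{T-2} [ L_x(x_t) + L_x_next(x_{t+1}) + L_z(Q_phi_t) + lambd * KL(Q_psi_t || Q_phi_{t+1}) ]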

for v in tf.all_variables():
  print("%s : %s" % (v.name, v.get_shape()))

with tf.variable_scope("Optimizer"):
  learning_rate=1e-4
  optimizer=tf.train.AdamOptimizer(learning_rate, beta1=0.1, beta2=0.1) # beta2=0.1
  train_op=optimizer.minimize(loss)

saver = tf.train.Saver(max_to_keep=100) # keep up to 100 checkpoint files

ckpt_file="/ltmp/e2c-plane"

# summaries
tf.scalar_summary("loss", loss)

# optionally monitor per-step losses, e.g. in case the late-step KL (Lz_8) behaves very differently from the first-step KL (Lz_0)
#for t in range(T-1):
#  tf.scalar_summary("Lb_%d" % (t), lbound_losses[t])
#  tf.scalar_summary("Lz_%d" % (t), kl_losses[t])
#  tf.scalar_summary("loss_%d" % (t), losses[t])

# tf.scalar_summary("L_x", tf.reduce_mean(L_x))
# tf.scalar_summary("L_x_next", tf.reduce_mean(L_x_next))
# tf.scalar_summary("L_z", tf.reduce_mean(L_z))
# tf.scalar_summary("KL",tf.reduce_mean(KL))
all_summaries = tf.merge_all_summaries()

# TRAIN
if __name__=="__main__":
  init=tf.initialize_all_variables()
  sess=tf.InteractiveSession()
  sess.run(init)
  # WRITER
  writer = tf.train.SummaryWriter("/ltmp/e2c", sess.graph_def)

  dataset=PlaneData("data/plane1.npz","data/env1.png")
  dataset.initialize()

  # resume training from a previously trained single-step E2C checkpoint
  saver.restore(sess, "/ltmp/e2c-plane-single1.ckpt")

  # test to make sure samples are actually trajectories

  def getimgs(x):
    # lay out one sample of the batch (index 20) across time as a padded horizontal filmstrip
    padsize=1
    padval=.5
    ph=B+2*padsize
    pw=A+2*padsize
    img=np.ones((ph,len(x)*pw))*padval
    for t in range(len(x)):
      startc=t*pw+padsize
      img[padsize:padsize+B, startc:startc+A]=x[t][20,:].reshape((A,B))
    return img
  (x_vals,u_vals)=dataset.sample_seq(batch_size,T)
  plt.matshow(getimgs(x_vals),cmap=plt.cm.gray,vmin=0,vmax=1)
  plt.show()

  train_iters=2e5 # 200K iterations
  for i in range(int(train_iters)):
    (x_vals,u_vals)=dataset.sample_seq(batch_size,T,replace=False)
    feed_dict={}
    for t in range(T):
      feed_dict[xs[t]] = x_vals[t]
    for t in range(T-1):
      feed_dict[us[t]] = u_vals[t]

    results=sess.run([loss,all_summaries,train_op],feed_dict)
    if i%1000==0:
      print("iter=%d : Loss: %f" % (i,results[0]))
      if i>2000:
        writer.add_summary(results[1], i)
    if (i%100==0 and i < 1000) or (i % 1000 == 0):
      saver.save(sess,ckpt_file+"-%05d"%(i)+".ckpt")

  # save variables
  #print("Model saved in file: %s" % saver.save(sess,ckpt_file+".ckpt"))

  sess.close()