import glob
import numpy as np
import os
from PIL import Image

import data
import models
import utils
from utils.visualizer import Visualizer

def save_images(prediction, gt, latent, save_dir, step):
  pose, components = latent['pose'].data.cpu(), latent['components'].data.cpu()
  batch_size, n_frames_total = prediction.shape[:2]
  n_components = components.shape[2]
  for i in range(batch_size):
    filename = '{:05d}.png'.format(step)
    y = gt[i, ...]
    rows = [y]
    if n_components > 1:
      for j in range(n_components):
        p = pose[i, :, j, :]
        comp = components[i, :, j, ...]
        if pose.size(-1) == 3:
          comp = utils.draw_components(comp, p)
        rows.append(utils.to_numpy(comp))
    x = prediction[i, ...]
    rows.append(x)
    # Make a grid of 4 x n_frames_total images
    image = np.concatenate(rows, axis=2).squeeze(1)
    image = np.concatenate([image[i] for i in range(n_frames_total)], axis=1)
    image = (image * 255).astype(np.uint8)
    # Save image
    Image.fromarray(image).save(os.path.join(save_dir, filename))
    step += 1

  return step

def evaluate(opt, dloader, model, use_saved_file=False):
  # Visualizer
  if hasattr(opt, 'save_visuals') and opt.save_visuals:
    vis = Visualizer(os.path.join(opt.ckpt_path, 'tb_test'))
  else:
    opt.save_visuals = False

  model.setup(is_train=False)
  metric = utils.Metrics()
  results = {}

  if hasattr(opt, 'save_all_results') and opt.save_all_results:
    save_dir = os.path.join(opt.ckpt_path, 'results')
    os.makedirs(save_dir, exist_ok=True)
  else:
    opt.save_all_results = False

  # Hacky
  is_bouncing_balls = ('bouncing_balls' in opt.dset_name) and opt.n_components == 4
  if is_bouncing_balls:
    dloader.dataset.return_positions = True
    saved_positions = os.path.join(opt.ckpt_path, 'positions.npy') if use_saved_file else ''
    velocity_metric = utils.VelocityMetrics(saved_positions)

  count = 0
  for step, data in enumerate(dloader):
    if not is_bouncing_balls:
      input, gt = data
    else:
      input, gt, positions = data
    output, latent = model.test(input, gt)
    pred = output[:, opt.n_frames_input:, ...]
    metric.update(gt, pred)

    if opt.save_all_results:
      gt = np.concatenate([input.numpy(), gt.numpy()], axis=1)
      prediction = utils.to_numpy(output)
      count = save_images(prediction, gt, latent, save_dir, count)

    if is_bouncing_balls:
      # Calculate position and velocity from pose
      pose = latent['pose'].data.cpu()
      velocity_metric.update(positions, pose, opt.n_frames_input)

    if (step + 1) % opt.log_every == 0:
      print('{}/{}'.format(step + 1, len(dloader)))
      if opt.save_visuals:
        vis.add_images(model.get_visuals(), step, prefix='test')

  # BCE, MSE
  results.update(metric.get_scores())

  if is_bouncing_balls:
    # Don't break the original code
    dloader.dataset.return_positions = False
    results.update(velocity_metric.get_scores())

  return results

def main():
  opt, logger, vis = utils.build(is_train=False)

  dloader = data.get_data_loader(opt)
  print('Val dataset: {}'.format(len(dloader.dataset)))
  model = models.get_model(opt)

  for epoch in opt.which_epochs:
    # Load checkpoint
    if epoch == -1:
      # Find the latest checkpoint
      checkpoints = glob.glob(os.path.join(opt.ckpt_path, 'net*.pth'))
      assert len(checkpoints) > 0
      epochs = [int(filename.split('_')[-1][:-4]) for filename in checkpoints]
      epoch = max(epochs)
    logger.print('Loading checkpoints from {}, epoch {}'.format(opt.ckpt_path, epoch))
    model.load(opt.ckpt_path, epoch)

    results = evaluate(opt, dloader, model)
    for metric in results:
      logger.print('{}: {}'.format(metric, results[metric]))

if __name__ == '__main__':
  main()