# Copyright 2018 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic oracle agent for StreetLearn."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging

import time
import numpy as np
import pygame

from streetlearn.python.environment import courier_game
from streetlearn.python.environment import default_config
from streetlearn.python.environment import streetlearn

# Command-line flags configuring the environment and the oracle agent.
FLAGS = flags.FLAGS
# The same width/height pair is used for both the view and the graph image.
flags.DEFINE_integer('width', 400, 'Observation and map width.')
flags.DEFINE_integer('height', 400, 'Observation and map height.')
flags.DEFINE_integer('field_of_view', 60, 'Field of view.')
flags.DEFINE_integer('graph_zoom', 1, 'Zoom level.')
# Base rotation step; loop() applies either 1x or 3x this amount per action.
flags.DEFINE_float('horizontal_rot', 22.5, 'Horizontal rotation step (deg).')
flags.DEFINE_string('dataset_path', None, 'Dataset path.')
# An empty start_pano makes main() request the full graph.
flags.DEFINE_string('start_pano', '',
                     'Pano at root of partial graph (default: full graph).')
# Used by main() as both the minimum and the maximum graph depth.
flags.DEFINE_integer('graph_depth', 200, 'Depth of the pano graph.')
# Used by main() as both the frame cap and the goal timeout.
flags.DEFINE_integer('frame_cap', 1000, 'Number of frames / episode.')
# When set, loop() appends one line of reward statistics per episode to it.
flags.DEFINE_string('stats_path', None, 'Statistics path.')
flags.DEFINE_float('proportion_of_panos_with_coins', 0, 'Proportion of coins.')
# dataset_path defaults to None, so it has to be supplied explicitly.
flags.mark_flag_as_required('dataset_path')


# Bearing tolerance used by loop(): while the bearing to the next pano lies
# within +/- TOL_BEARING the agent moves forward, otherwise it rotates.
# Presumably expressed in degrees, like the horizontal_rot flag it is added
# to in loop() -- TODO confirm.
TOL_BEARING = 30


def interleave(array, w, h):
  """Convert a planar RGB buffer into a pixel-interleaved image array.

  The input is read as three consecutive colour planes of w * h values each,
  every plane being stored row by row.

  Args:
    array: A numpy array holding the planar image; it must be reshapeable to
      three planes of w * h values.
    w: Width of the image.
    h: Height of the image.
  Returns:
    An array of shape (w, h, 3), indexed as [x, y, channel], that does not
    share memory with the input.
  """
  planes = array.reshape(3, h, w)
  # Stacking along a new trailing axis places the three channel values of
  # each pixel next to each other, producing an (h, w, 3) image ...
  pixels = np.stack((planes[0], planes[1], planes[2]), axis=-1)
  # ... whose two spatial axes are then exchanged, giving (w, h, 3).
  return pixels.swapaxes(0, 1)

def loop(env, screen):
  """Main loop of the oracle agent.

  Each iteration renders the current observation, handles window events,
  steps the environment with the previously chosen action and then picks the
  next action from the bearing to the next pano reported by the environment.
  The function returns when the window is closed or Escape is pressed.

  Args:
    env: Environment object providing action_spec(), observation(), step()
      and goto(), as used below.
    screen: pygame surface the observation images are drawn to.
  """
  width, height = FLAGS.width, FLAGS.height
  rot_step = FLAGS.horizontal_rot
  action_spec = env.action_spec()
  action = np.array([0, 0, 0, 0])
  sum_rewards = 0
  sum_rewards_at_goal = 0
  previous_goal_id = None
  # Number of visits to each pano since the goal last changed.
  seen_pano_ids = {}
  while True:
    # Draw the view image and the graph image into the window. Both arrays
    # are (width, height, 3), so joining them on axis 1 yields an array of
    # shape (width, 2 * height, 3).
    observation = env.observation()
    view_image = interleave(observation['view_image'], width, height)
    graph_image = interleave(observation['graph_image'], width, height)
    screen_buffer = np.concatenate((view_image, graph_image), axis=1)
    pygame.surfarray.blit_array(screen, screen_buffer)
    pygame.display.update()

    for event in pygame.event.get():
      if (event.type == pygame.QUIT or
          (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE)):
        return
      # Pressing 'p' saves a timestamped screenshot of the window.
      if event.type == pygame.KEYDOWN and event.key == pygame.K_p:
        filename = time.strftime('oracle_agent_%Y%m%d_%H%M%S.bmp')
        pygame.image.save(screen, filename)

    # Take a step given the previous action and record the reward.
    _, reward, done, info = env.step(action)
    sum_rewards += reward
    # A positive reward that comes with a new goal id is counted as a goal
    # reward. The ids are compared by value: an identity test (`is not`)
    # would treat equal ids held in distinct objects as a goal change.
    current_goal_id = info['current_goal_id']
    if reward > 0 and current_goal_id != previous_goal_id:
      sum_rewards_at_goal += reward
      seen_pano_ids = {}
    previous_goal_id = current_goal_id
    if done:
      print('Episode reward: {}'.format(sum_rewards))
      if FLAGS.stats_path:
        # One tab-separated line per episode: total reward, goal reward.
        with open(FLAGS.stats_path, 'a') as f:
          f.write(str(sum_rewards) + '\t' + str(sum_rewards_at_goal) + '\n')
      sum_rewards = 0
      sum_rewards_at_goal = 0

    # Determine the next pano and bearing to that pano.
    current_pano_id = info['current_pano_id']
    next_pano_id = info['next_pano_id']
    bearing = info['bearing_to_next_pano']
    logging.info('Current pano: %s, next pano %s at %f',
                 current_pano_id, next_pano_id, bearing)

    # Maintain the count of pano visits, in case the agent gets stuck.
    seen_pano_ids[current_pano_id] = seen_pano_ids.get(current_pano_id, 0) + 1

    # Bearing-based navigation: rotate towards the next pano until it lies
    # within the tolerance, using a triple step when it is far off.
    if bearing > TOL_BEARING:
      if bearing > TOL_BEARING + 2 * rot_step:
        action = 3 * rot_step * action_spec['horizontal_rotation']
      else:
        action = rot_step * action_spec['horizontal_rotation']
    elif bearing < -TOL_BEARING:
      if bearing < -TOL_BEARING - 2 * rot_step:
        action = -3 * rot_step * action_spec['horizontal_rotation']
      else:
        action = -rot_step * action_spec['horizontal_rotation']
    else:
      action = action_spec['move_forward']

      # Sometimes, two panos B and C are close to each other, which causes
      # cyclic loops: A -> C -> A -> C -> A... whereas the agent wants to go
      # A -> B. There is a simple strategy to get out of that A - C loop:
      # detect that A has been visited a large number of times in the current
      # trajectory, then instead of moving forward A -> B and ending up in C,
      # directly jump to B. First, we check if the agent has spent more time
      # in a pano than required to make a full U-turn...
      if seen_pano_ids[current_pano_id] > (180.0 / rot_step):
        # ... then we teleport to the desired location and turn randomly.
        logging.info('Teleporting from %s to %s', current_pano_id, next_pano_id)
        # np.random.randint excludes its upper bound, so 360 is needed to
        # draw from the whole range 0..359.
        _ = env.goto(next_pano_id, np.random.randint(360))

def main(argv):
  """Build the courier-game environment from the flags and run the agent.

  Args:
    argv: Positional command-line arguments handed over by absl; unused.
  """
  del argv  # Unused.
  overrides = dict(
      width=FLAGS.width,
      height=FLAGS.height,
      field_of_view=FLAGS.field_of_view,
      # The graph image shares the dimensions of the view image.
      graph_width=FLAGS.width,
      graph_height=FLAGS.height,
      graph_zoom=FLAGS.graph_zoom,
      # The same flag bounds both the time to reach a goal and the episode.
      goal_timeout=FLAGS.frame_cap,
      frame_cap=FLAGS.frame_cap,
      # Without a start pano the full graph is requested.
      full_graph=(FLAGS.start_pano == ''),
      start_pano=FLAGS.start_pano,
      min_graph_depth=FLAGS.graph_depth,
      max_graph_depth=FLAGS.graph_depth,
      proportion_of_panos_with_coins=FLAGS.proportion_of_panos_with_coins,
      action_spec='streetlearn_fast_rotate',
      observations=['view_image', 'graph_image', 'yaw', 'pitch'],
  )
  settings = default_config.ApplyDefaults(overrides)
  courier = courier_game.CourierGame(settings)
  environment = streetlearn.StreetLearn(FLAGS.dataset_path, settings, courier)
  environment.reset()

  pygame.init()
  # The window is twice the observation height so that it can hold both the
  # view image and the graph image.
  window_size = (FLAGS.width, FLAGS.height * 2)
  window = pygame.display.set_mode(window_size)
  loop(environment, window)

# Script entry point: hand main() to absl, which parses the command-line flags
# before invoking it.
if __name__ == '__main__':
  app.run(main)