import pickle
import gym
from collections import deque
from PIL import Image
from gym import spaces
import imageio
import numpy as np
from multiprocessing import Process, Pipe
import horovod.tensorflow as hvd
import mpi4py.rc
mpi4py.rc.initialize = False  # prevent mpi4py from calling MPI_Init on import
from mpi4py import MPI

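# Module-level flag toggled by MyWrapper.batch_reset(): when True, ReplayResetEnv.reset()
# replays the full demo from its first action rather than from a sampled starting point.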
reset_for_batch = False

class MyWrapper(gym.Wrapper):
    def __init__(self, env):
        super(MyWrapper, self).__init__(env)
    def decrement_starting_point(self, nr_steps):
        return self.env.decrement_starting_point(nr_steps)
    def recursive_getattr(self, name):
        if hasattr(self, name):
            return getattr(self, name)
        else:
            return self.env.recursive_getattr(name)
    def batch_reset(self):
        global reset_for_batch
        reset_for_batch = True
        obs = self.env.reset()
        reset_for_batch = False
        return obs
    def reset(self):
        return self.env.reset()
    def step(self, action):
        return self.env.step(action)

    def step_async(self, actions):
        return self.env.step_async(actions)

    def step_wait(self):
        return self.env.step_wait()

    def reset_task(self):
        return self.env.reset_task()

    @property
    def num_envs(self):
        return self.env.num_envs

class VecFrameStack(MyWrapper):
    """
    Vectorized environment base class
    """
    def __init__(self, venv, nstack):
        self.venv = venv
        self.nstack = nstack
        wos = venv.observation_space # wrapped ob space
        low = np.repeat(wos.low, self.nstack, axis=-1)
        high = np.repeat(wos.high, self.nstack, axis=-1)
        self.stackedobs = np.zeros((venv.num_envs,)+low.shape, low.dtype)
        self._observation_space = spaces.Box(low=low, high=high)
        self._action_space = venv.action_space
    def step(self, vac):
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)
        where 'news' is a boolean vector indicating whether each element is new.
        """
        obs, rews, news, infos = self.venv.step(vac)
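        # shift the stack left by one observation's worth of channels, zero the stack for
        # environments that just started a new episode, then write the newest observation
        # into the last channel slots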
        self.stackedobs = np.roll(self.stackedobs, shift=-obs.shape[-1], axis=-1)
        for (i, new) in enumerate(news):
            if new:
                self.stackedobs[i] = 0
        self.stackedobs[..., -obs.shape[-1]:] = obs
        return self.stackedobs, rews, news, infos
    def reset(self):
        """
        Reset all environments
        """
        obs = self.venv.reset()
        self.stackedobs[...] = 0
        self.stackedobs[..., -obs.shape[-1]:] = obs
        return self.stackedobs
    @property
    def action_space(self):
        return self._action_space
    @property
    def observation_space(self):
        return self._observation_space
    def close(self):
        self.venv.close()
    @property
    def num_envs(self):
        return self.venv.num_envs

class ReplayResetEnv(MyWrapper):
    """
        Randomly resets to states from a replay
    """

    def __init__(self, env, demo_file_name, seed, reset_steps_ignored=512, workers_per_sp=4, frac_sample=0.2, game_over_on_life_loss=True):
        super(ReplayResetEnv, self).__init__(env)
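        # the demo pickle is expected to contain 'actions', 'rewards', 'checkpoints'
        # (emulator states restorable via restore_state) and 'checkpoint_action_nr'
        # (the action index at which each checkpoint was taken)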
        with open(demo_file_name, "rb") as f:
            dat = pickle.load(f)
        self.actions = dat['actions']
        rewards = dat['rewards']
        assert len(rewards) == len(self.actions)
        self.returns = np.cumsum(rewards)
        self.checkpoints = dat['checkpoints']
        self.checkpoint_action_nr = dat['checkpoint_action_nr']
        self.rng = np.random.RandomState(seed)
        self.reset_steps_ignored = reset_steps_ignored
        self.actions_to_overwrite = []
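        # spread workers over starting points near the end of the demo: groups of
        # `workers_per_sp` consecutive seeds share a starting point, and higher seeds
        # start earlier in the demo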
        self.starting_point = len(self.actions) - 1 - seed//workers_per_sp
        self.starting_point_current_ep = None
        self.frac_sample = frac_sample
        self.game_over_on_life_loss = game_over_on_life_loss

    def step(self, action):
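        # while demo actions queued by reset() remain, they override the policy's action
        # and the resulting transition is flagged as invalid (presumably so the learner
        # can ignore it)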
        if len(self.actions_to_overwrite) > 0:
            action = self.actions_to_overwrite.pop(0)
            valid = False
        else:
            valid = True
        prev_lives = self.env.unwrapped.ale.lives()
        obs, reward, done, info = self.env.step(action)
        self.action_nr += 1
        self.score += reward

        # game over on loss of life, to speed up learning
        if self.game_over_on_life_loss:
            lives = self.env.unwrapped.ale.lives()
            if lives < prev_lives and lives > 0:
                done = True

        # end the episode if we have achieved the final score, or if we're lagging the demo too much
        if self.score >= self.returns[-1]:
            self.extra_frames_counter -= 1
            if self.extra_frames_counter <= 0:
                done = True
                info['replay_reset.random_reset'] = True  # to distinguish from an actual game over
        elif self.action_nr > 50 and self.score < self.returns[np.minimum(len(self.returns)-1, self.action_nr-50)]:
            done = True

        # output flag to increase entropy if near the starting point of this episode
        if self.action_nr < self.starting_point + 100:
            info['increase_entropy'] = True

        if done:
            ep_info = {'l':self.action_nr, 'as_good_as_demo':(self.score >= self.returns[-1]),
                       'r':self.score, 'starting_point': self.starting_point_current_ep}
            info['episode'] = ep_info

        if not valid:
            info['replay_reset.invalid_transition'] = True

        return obs, reward, done, info

    def decrement_starting_point(self, nr_steps):
        if self.starting_point > 0:
            self.starting_point = int(np.maximum(self.starting_point - nr_steps, 0))

    def reset(self):
        obs = self.env.reset()
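        # number of extra frames the episode may continue after reaching the demo's final
        # score, sampled log-uniformly between 1 and ~1100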
        self.extra_frames_counter = int(np.exp(self.rng.rand()*7))

        if reset_for_batch:
            self.starting_point_current_ep = 0
            self.actions_to_overwrite = self.actions[:]
            self.action_nr = 0
            self.score = self.returns[0]
        else:

            if self.rng.rand() <= 1.-self.frac_sample:
                self.starting_point_current_ep = self.starting_point
            else:
                self.starting_point_current_ep = self.rng.randint(low=self.starting_point, high=len(self.actions))

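            # restore the latest emulator checkpoint that lies at least `reset_steps_ignored`
            # steps before the chosen starting point, replay demo actions directly in the ALE
            # up to that margin, and queue the remaining demo actions to be fed through step()
            # so any recurrent policy state can warm up on real transitions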
            start_action_nr = 0
            start_ckpt = None
            for nr, ckpt in zip(self.checkpoint_action_nr[::-1], self.checkpoints[::-1]):
                if nr <= (self.starting_point_current_ep - self.reset_steps_ignored):
                    start_action_nr = nr
                    start_ckpt = ckpt
                    break
            if start_action_nr > 0:
                self.env.unwrapped.restore_state(start_ckpt)
            nr_to_start_lstm = np.maximum(self.starting_point_current_ep - self.reset_steps_ignored, start_action_nr)
            if nr_to_start_lstm > start_action_nr:
                for a in self.actions[start_action_nr:nr_to_start_lstm]:
                    action = self.env.unwrapped._action_set[a]
                    self.env.unwrapped.ale.act(action)
            self.actions_to_overwrite = self.actions[nr_to_start_lstm:self.starting_point_current_ep]
            if nr_to_start_lstm > 0:
                obs = self.env.unwrapped._get_image()
            self.action_nr = nr_to_start_lstm
            self.score = self.returns[nr_to_start_lstm]

        return obs

class MaxAndSkipEnv(MyWrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        MyWrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip       = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        combined_info = {}
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            combined_info.update(info)
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, combined_info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

class ClipRewardEnv(MyWrapper):
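    """Clip rewards to their sign, i.e. to {-1, 0, +1}."""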
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        reward = np.sign(reward)
        return obs, reward, done, info

class EpsGreedyEnv(MyWrapper):
    def __init__(self, env, eps=0.01):
        MyWrapper.__init__(self, env)
        self.eps = eps

    def step(self, action):
        if np.random.uniform() < self.eps:
            action = np.random.randint(self.env.action_space.n)
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

class StickyActionEnv(MyWrapper):
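    """
    With probability `p`, repeats the previous action instead of the one chosen by the
    policy ("sticky actions"), a standard way of injecting stochasticity into ALE games.
    """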
    def __init__(self, env, p=0.25):
        MyWrapper.__init__(self, env)
        self.p = p
        self.last_action = 0

    def step(self, action):
        if np.random.uniform() < self.p:
            action = self.last_action
        self.last_action = action
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

class Box(gym.Space):
    """
    A box in R^n.
    I.e., each coordinate is bounded.
    Example usage:
    self.action_space = spaces.Box(low=-10, high=10, shape=(1,))
    """
    def __init__(self, low, high, shape=None, dtype=np.uint8):
        """
        Two kinds of valid input:
            Box(-1.0, 1.0, (3,4)) # low and high are scalars, and shape is provided
            Box(np.array([-1.0,-2.0]), np.array([2.0,4.0])) # low and high are arrays of the same shape
        """
        if shape is None:
            assert low.shape == high.shape
            self.low = low
            self.high = high
        else:
            assert np.isscalar(low) and np.isscalar(high)
            self.low = low + np.zeros(shape)
            self.high = high + np.zeros(shape)
        self.dtype = dtype
    def contains(self, x):
        return x.shape == self.shape and (x >= self.low).all() and (x <= self.high).all()
    def to_jsonable(self, sample_n):
        return np.array(sample_n).tolist()
    def from_jsonable(self, sample_n):
        return [np.asarray(sample) for sample in sample_n]
    @property
    def shape(self):
        return self.low.shape
    @property
    def size(self):
        return self.low.shape
    def __repr__(self):
        return "Box" + str(self.shape)
    def __eq__(self, other):
        return np.allclose(self.low, other.low) and np.allclose(self.high, other.high)

class WarpFrame(MyWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        MyWrapper.__init__(self, env)
        self.res = 84
        self.observation_space = Box(low=0, high=255, shape=(self.res, self.res, 1), dtype=np.uint8)

    def reshape_obs(self, obs):
        obs = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
        obs = np.array(Image.fromarray(obs).resize((self.res, self.res),
                                                   resample=Image.BILINEAR), dtype=np.uint8)
        return obs.reshape((self.res, self.res, 1))

    def reset(self):
        return self.reshape_obs(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.reshape_obs(obs), reward, done, info

class MyResizeFrame(MyWrapper):
    def __init__(self, env):
        """Warp frames to 105x80"""
        MyWrapper.__init__(self, env)
        self.res = (105, 80, 3)
        self.observation_space = Box(low=0, high=255, shape=self.res, dtype=np.uint8)

    def reshape_obs(self, obs):
        # PIL's resize takes (width, height); self.res is (height, width, channels)
        obs = np.array(Image.fromarray(obs).resize((self.res[1], self.res[0]),
                                                   resample=Image.BILINEAR), dtype=np.uint8)
        return obs.reshape(self.res)

    def reset(self):
        return self.reshape_obs(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.reshape_obs(obs), reward, done, info

class FireResetEnv(MyWrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        MyWrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs

class VideoWriter(MyWrapper):
    def __init__(self, env, file_prefix):
        MyWrapper.__init__(self, env)
        self.file_prefix = file_prefix
        self.video_writer = None
        self.counter = 0

    def process_frame(self, frame):
        # pad the 210x160 ALE frame to 224x160 with black borders top and bottom
        f_out = np.zeros((224, 160, 3), dtype=np.uint8)
        f_out[7:-7, :] = frame.astype(np.uint8)
        return f_out

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.video_writer.append_data(self.process_frame(obs))
        return obs, reward, done, info

    def reset(self):
        if self.video_writer is not None:
            self.video_writer.close()
            self.counter += 1
        self.video_writer = imageio.get_writer(self.file_prefix + str(self.counter) + '.mp4', mode='I', fps=120)
        return self.env.reset()

def my_wrapper(env, clip_rewards=True):
    assert 'NoFrameskip' in env.spec.id
    env = MaxAndSkipEnv(env, skip=4)
    if clip_rewards:
        # clip after frame-skipping so the sign is taken of the summed per-step reward
        env = ClipRewardEnv(env)
    if 'Pong' in env.spec.id:
        env = FireResetEnv(env)
    env = MyResizeFrame(env)
    return env
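
# A hypothetical sketch of how these wrappers might be composed into a training setup
# (the env id, demo filename and `make_env` helper are illustrative, not part of this module):
#
#   def make_env(rank, demo_file='demo.pkl'):
#       def _thunk():
#           env = gym.make('MontezumaRevengeNoFrameskip-v4')
#           env = ReplayResetEnv(env, demo_file, seed=rank)
#           return my_wrapper(env, clip_rewards=True)
#       return _thunk
#
#   nenvs = 64  # illustrative; enough workers that starting points span a useful range
#   venv = SubprocVecEnv([make_env(i) for i in range(nenvs)])
#   venv = ResetManager(venv)
#   venv = VecFrameStack(venv, nstack=4)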

class ResetManager(MyWrapper):
    def __init__(self, env):
        super(ResetManager, self).__init__(env)
        starting_points = self.env.recursive_getattr('starting_point')
        all_starting_points = flatten_lists(MPI.COMM_WORLD.allgather(starting_points))
        self.min_starting_point = min(all_starting_points)
        self.max_starting_point = max(all_starting_points)
        self.nrstartsteps = self.max_starting_point - self.min_starting_point
        assert self.nrstartsteps > 10
        self.max_max_starting_point = self.max_starting_point
        self.starting_point_success = np.zeros(self.max_starting_point+10000)
        self.counter = 0
        self.infos = []

    def proc_infos(self):
        epinfos = [info['episode'] for info in self.infos if 'episode' in info]

        if hvd.size() > 1:
            epinfos = flatten_lists(MPI.COMM_WORLD.allgather(epinfos))

        new_sp_wins = {}
        new_sp_counts = {}
        for epinfo in epinfos:
            sp = epinfo['starting_point']
            if sp in new_sp_counts:
                new_sp_counts[sp] += 1
                if epinfo['as_good_as_demo']:
                    new_sp_wins[sp] += 1
            else:
                new_sp_counts[sp] = 1
                if epinfo['as_good_as_demo']:
                    new_sp_wins[sp] = 1
                else:
                    new_sp_wins[sp] = 0

        for sp, wins in new_sp_wins.items():
            self.starting_point_success[sp] = np.float32(wins) / new_sp_counts[sp]

        # move starting point, ensuring at least 20% of workers are able to complete the demo
        csd = np.argwhere(np.cumsum(self.starting_point_success) / self.nrstartsteps >= 0.2)
        if len(csd) > 0:
            new_max_start = csd[0][0]
        else:
            new_max_start = np.minimum(self.max_starting_point + 100, self.max_max_starting_point)
        n_points_to_shift = self.max_starting_point - new_max_start
        self.decrement_starting_point(n_points_to_shift)
        self.infos = []

    def decrement_starting_point(self, n_points_to_shift):
        self.env.decrement_starting_point(n_points_to_shift)
        starting_points = self.env.recursive_getattr('starting_point')
        all_starting_points = flatten_lists(MPI.COMM_WORLD.allgather(starting_points))
        self.max_starting_point = max(all_starting_points)

    def set_max_starting_point(self, starting_point):
        n_points_to_shift = self.max_starting_point - starting_point
        self.decrement_starting_point(n_points_to_shift)

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        self.infos += infos
        self.counter += 1
        if self.counter > (self.max_max_starting_point - self.max_starting_point) / 2 and self.counter % 1024 == 0:
            self.proc_infos()
        return obs, rews, news, infos

    def step_wait(self):
        obs, rews, news, infos = self.env.step_wait()
        self.infos += infos
        self.counter += 1
        if self.counter > (self.max_max_starting_point - self.max_starting_point) / 2 and self.counter % 1024 == 0:
            self.proc_infos()
        return obs, rews, news, infos

def flatten_lists(listoflists):
    return [el for list_ in listoflists for el in list_]

def worker(remote, env_fn_wrapper):
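    """
    Subprocess loop: receives (command, data) tuples over `remote` and replies where
    appropriate. Supported commands: 'step', 'reset', 'close', 'get_spaces', 'get_history',
    'recursive_getattr' and 'decrement_starting_point'.
    """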
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        elif cmd == 'get_history':
            senv = env
            while not hasattr(senv, 'get_history'):
                senv = senv.env
            remote.send(senv.get_history(data))
        elif cmd == 'recursive_getattr':
            remote.send(env.recursive_getattr(data))
        elif cmd == 'decrement_starting_point':
            env.decrement_starting_point(data)
        else:
            raise NotImplementedError

class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x
    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)
    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)

class SubprocVecEnv(MyWrapper):
    def __init__(self, env_fns):
        """
        envs: list of gym environments to run in subprocesses
        """
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
            for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
        for p in self.ps:
            p.start()

        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()

    def step(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)

        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def get_history(self, nsteps):
        for remote in self.remotes:
            remote.send(('get_history', nsteps))
        results = [remote.recv() for remote in self.remotes]
        obs, acts, dones = zip(*results)
        obs = np.stack(obs)
        acts = np.stack(acts)
        dones = np.stack(dones)
        return obs, acts, dones

    def recursive_getattr(self, name):
        for remote in self.remotes:
            remote.send(('recursive_getattr',name))
        return [remote.recv() for remote in self.remotes]

    def decrement_starting_point(self, n):
        for remote in self.remotes:
            remote.send(('decrement_starting_point', n))

    def close(self):
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    @property
    def num_envs(self):
        return len(self.remotes)

class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)