from scipy.misc import imresize  # requires SciPy < 1.3 (imresize was removed in later releases)
from skimage.color import rgb2gray
from multiprocessing import Pool, Manager
from collections import deque
import gym
import numpy as np
import argparse

# -----
parser = argparse.ArgumentParser(description='Training model')
parser.add_argument('--game', default='Breakout-v0', help='OpenAI gym environment name', dest='game', type=str)
parser.add_argument('--processes', default=4, help='Number of processes that generate experience for agent',
                    dest='processes', type=int)
parser.add_argument('--lr', default=0.0001, help='Learning rate', dest='learning_rate', type=float)
parser.add_argument('--batch_size', default=20, help='Batch size to use during training', dest='batch_size', type=int)
parser.add_argument('--swap_freq', default=10000, help='Number of frames before swapping network weights',
                    dest='swap_freq', type=int)
parser.add_argument('--checkpoint', default=0, help='Frame count to resume training from', dest='checkpoint', type=int)
parser.add_argument('--save_freq', default=250000, help='Number of frames between saving weights', dest='save_freq',
                    type=int)
parser.add_argument('--eps_decay', default=4000000,
                    help='Number of frames needed to decay epsilon to the lowest value', dest='eps_decay', type=int)
parser.add_argument('--lr_decay', default=80000000,
                    help='Number of frames needed to decay lr to the lowest value', dest='lr_decay', type=int)
parser.add_argument('--queue_size', default=256, help='Size of queue holding agent experience', dest='queue_size',
                    type=int)
parser.add_argument('--n_step', default=5, help='Number of steps used for n-step Q-learning returns', dest='n_step', type=int)
parser.add_argument('--th_comp_fix', default=True,
                    help='Sets different Theano compiledir for each process', dest='th_fix',
                    type=lambda s: str(s).lower() in ('true', '1', 'yes'))  # plain type=bool would treat any non-empty string as True
# -----
args = parser.parse_args()


# -----


def build_network(input_shape, output_shape):
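    """Builds the small DQN-style network: two convolutional layers (16 filters of 8x8,
    stride 4; 32 filters of 4x4, stride 2), a 256-unit dense layer, and one linear
    output per action (the Q-values). Input is channels-first."""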
    from keras.models import Model
    from keras.layers import Input, Conv2D, Flatten, Dense

    x = Input(shape=input_shape)
    h = Conv2D(16, kernel_size=(8, 8), strides=(4, 4), activation='relu', data_format='channels_first')(x)
    h = Conv2D(32, kernel_size=(4, 4), strides=(2, 2), activation='relu', data_format='channels_first')(h)
    h = Flatten()(h)
    h = Dense(256, activation='relu')(h)
    v = Dense(output_shape, activation='linear')(h)
    return Model(inputs=x, outputs=v)


# -----

class LearningAgent(object):
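    """Central learner: consumes (state, action, n-step return) samples produced by the
    acting processes and trains the Q-network on them in mini-batches."""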
    def __init__(self, action_space, batch_size=32, screen=(84, 84), swap_freq=200):
        from keras.optimizers import RMSprop
        # -----
        self.screen = screen
        self.input_depth = 1
        self.past_range = 3
        self.observation_shape = (self.input_depth * self.past_range,) + self.screen
        self.batch_size = batch_size

        self.action_value = build_network(self.observation_shape, action_space.n)
        self.action_value.compile(optimizer=RMSprop(clipnorm=1.), loss='mse')

        self.losses = deque(maxlen=25)
        self.q_values = deque(maxlen=25)
        self.swap_freq = swap_freq
        self.swap_counter = self.swap_freq
        self.unroll = np.arange(self.batch_size)
        self.frames = 0

    def learn(self, last_observations, actions, rewards, learning_rate=0.001):
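        """Performs one gradient step. The current predictions are used as targets for all
        actions except the ones actually taken, whose targets are replaced by the n-step
        returns computed by the actors. Returns True whenever swap_freq frames have passed
        and the actors should refresh their weights."""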
        self.action_value.optimizer.lr.set_value(learning_rate)
        frames = len(last_observations)
        self.frames += frames
        # -----
        targets = self.action_value.predict_on_batch(last_observations)
        # -----
        targets[self.unroll, actions] = rewards
        # -----
        loss = self.action_value.train_on_batch(last_observations, targets)
        self.losses.append(loss)
        self.q_values.append(np.mean(targets))
        print('\rIter: %8d; Lr: %8.7f; Loss: %7.4f; Min: %7.4f; Max: %7.4f; Avg: %7.4f --- Q-value; Min: %7.4f; Max: %7.4f; Avg: %7.4f' % (
            self.frames, learning_rate, loss, min(self.losses), max(self.losses), np.mean(self.losses),
            np.min(self.q_values), np.max(self.q_values), np.mean(self.q_values)), end='')
        self.swap_counter -= frames
        if self.swap_counter < 0:
            self.swap_counter += self.swap_freq
            return True
        return False


def learn_proc(global_frame, mem_queue, weight_dict):
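    """Learner process: pulls experience from mem_queue, trains the network, publishes
    updated weights through weight_dict, and periodically saves them to disk. Keras is
    imported lazily (inside build_network / LearningAgent) so that THEANO_FLAGS set here
    takes effect before Theano is loaded."""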
    import os
    pid = os.getpid()
    if args.th_fix:
        os.environ['THEANO_FLAGS'] = 'floatX=float32,device=gpu,nvcc.fastmath=False,lib.cnmem=0,' + \
                                     'compiledir=th_comp_learn'
    # -----
    save_freq = args.save_freq
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    checkpoint = args.checkpoint
    lr_decay = args.lr_decay
    # -----
    env = gym.make(args.game)
    agent = LearningAgent(env.action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
    # -----
    if checkpoint > 0:
        print(' %5d> Loading weights from file' % (pid,))
        agent.action_value.load_weights('model-%d.h5' % (checkpoint,))
        # -----
    weight_dict['update'] = 0
    weight_dict['weights'] = agent.action_value.get_weights()
    print(' %5d> Setting weights in dict' % (pid,))
    # -----
    last_obs = np.zeros((batch_size,) + agent.observation_shape)
    actions = np.zeros(batch_size, dtype=np.int32)
    rewards = np.zeros(batch_size)
    # -----
    idx = 0
    agent.frames = checkpoint
    save_counter = checkpoint % save_freq + save_freq
    while True:
        # -----
        last_obs[idx, ...], actions[idx], rewards[idx] = mem_queue.get()
        idx = (idx + 1) % batch_size
        if idx == 0:
            lr = max(0.000000001, learning_rate * (1. - agent.frames / lr_decay))
            updated = agent.learn(last_obs, actions, rewards, learning_rate=lr)
            global_frame.value = agent.frames
            if updated:
                # print(' %5d> Updating weights in dict' % (pid,))
                weight_dict['weights'] = agent.action_value.get_weights()
                weight_dict['update'] += 1
        # -----
        save_counter -= 1
        if save_counter % save_freq == 0:
            agent.action_value.save_weights('model-%d.h5' % (agent.frames,), overwrite=True)


class ActingAgent(object):
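    """Actor: follows an epsilon-greedy policy with its own copy of the Q-network over a
    stack of the last three grayscale frames, accumulates n-step returns, and pushes
    (state, action, return) samples to the shared queue."""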
    def __init__(self, action_space, screen=(84, 84), n_step=8, discount=0.99):
        from keras.optimizers import RMSprop
        # -----
        self.screen = screen
        self.input_depth = 1
        self.past_range = 3
        self.observation_shape = (self.input_depth * self.past_range,) + self.screen

        self.action_value = build_network(self.observation_shape, action_space.n)
        self.action_value.compile(optimizer=RMSprop(clipnorm=1.), loss='mse')

        self.action_space = action_space
        self.observations = np.zeros(self.observation_shape)
        self.last_observations = np.zeros_like(self.observations)
        # -----
        self.n_step_observations = deque(maxlen=n_step)
        self.n_step_actions = deque(maxlen=n_step)
        self.n_step_rewards = deque(maxlen=n_step)
        self.n_step = n_step
        self.discount = discount
        self.counter = 0

    def init_episode(self, observation):
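        # Fill the frame stack with copies of the first observation of the episode.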
        for _ in range(self.past_range):
            self.save_observation(observation)

    def reset(self):
        self.counter = 0
        self.n_step_observations.clear()
        self.n_step_actions.clear()
        self.n_step_rewards.clear()

    def sars_data(self, action, reward, observation, terminal, mem_queue):
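        """Stores one transition (with reward clipped to [-1, 1]) and, every n_step steps
        or at episode end, bootstraps from max Q of the latest state (zero if terminal)
        and pushes the discounted n-step return for each buffered transition onto the
        shared queue."""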
        self.save_observation(observation)
        reward = np.clip(reward, -1., 1.)
        # -----
        self.n_step_observations.appendleft(self.last_observations)
        self.n_step_actions.appendleft(action)
        self.n_step_rewards.appendleft(reward)
        # -----
        self.counter += 1
        if terminal or self.counter >= self.n_step:
            r = 0.
            if not terminal:
                r = np.max(self.action_value.predict(self.observations[None, ...]))
            for i in range(self.counter):
                r = self.n_step_rewards[i] + self.discount * r
                mem_queue.put((self.n_step_observations[i], self.n_step_actions[i], r))
            self.reset()

    def choose_action(self, epsilon=0.0):
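        # Epsilon-greedy: random action with probability epsilon, otherwise greedy w.r.t. Q.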
        if np.random.random() < epsilon:
            return self.action_space.sample()
        else:
            return np.argmax(self.action_value.predict(self.observations[None, ...]))

    def save_observation(self, observation):
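        # Keep a reference to the current stack before np.roll allocates a new array,
        # then shift the stack and overwrite the oldest slot with the new frame.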
        self.last_observations = self.observations[...]
        self.observations = np.roll(self.observations, -self.input_depth, axis=0)
        self.observations[-self.input_depth:, ...] = self.transform_screen(observation)

    def transform_screen(self, data):
        return rgb2gray(imresize(data, self.screen))[None, ...]


def generate_experience_proc(global_frame, mem_queue, weight_dict, no, epsilon):
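    """Actor process number `no`: plays episodes with an epsilon that decays linearly from
    1.0 to `epsilon` over eps_decay global frames, feeds n-step samples into mem_queue,
    periodically pulls fresh weights from weight_dict, and aborts episodes where the same
    action repeats too long."""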
    import os
    pid = os.getpid()
    if args.th_fix:
        os.environ['THEANO_FLAGS'] = 'floatX=float32,device=gpu,nvcc.fastmath=True,lib.cnmem=0,' + \
                                     'compiledir=th_comp_act_' + str(no)
    # -----
    batch_size = args.batch_size
    # -----
    print(' %5d> Process started with %6.3f' % (pid, epsilon))
    # -----
    env = gym.make(args.game)
    agent = ActingAgent(env.action_space, n_step=args.n_step)

    if args.checkpoint > 0:
        print(' %5d> Loading weights from file' % (pid,))
        agent.action_value.load_weights('model-%d.h5' % (args.checkpoint,))
    else:
        import time
        while 'weights' not in weight_dict:
            time.sleep(0.1)
        agent.action_value.set_weights(weight_dict['weights'])
        print(' %5d> Loaded weights from dict' % (pid,))

    best_score, last_update, frames = 0, 0, 0
    avg_score = deque(maxlen=20)
    stop_decay = global_frame.value > args.eps_decay
    decayed_epsilon = epsilon  # fallback so choose_action has a value once decay has already finished

    while True:
        done = False
        episode_reward = 0
        last_op, op_count = 0, 0
        observation = env.reset()
        agent.init_episode(observation)

        # -----
        while not done:
            frames += 1
            if not stop_decay:
                frame_tmp = global_frame.value
                decayed_epsilon = max(epsilon, epsilon + (1. - epsilon) * (
                                        args.eps_decay - frame_tmp) / args.eps_decay)
                stop_decay = frame_tmp > args.eps_decay
            # -----
            action = agent.choose_action(decayed_epsilon)
            observation, reward, done, _ = env.step(action)
            episode_reward += reward
            best_score = max(best_score, episode_reward)
            # -----
            agent.sars_data(action, reward, observation, done, mem_queue)
            # -----
            if action == last_op:
                op_count += 1
            else:
                op_count, last_op = 0, action
            # -----
            if op_count > 100:
                agent.reset()  # reset agent memory
                break
            # -----
            if frames % 2000 == 0:
                avg = np.mean(avg_score) if len(avg_score) > 0 else 0.
                print(' %5d> Epsilon: %9.6f; Best score: %4d; Avg: %9.3f' % (
                    pid, decayed_epsilon, best_score, avg))
            if frames % batch_size == 0:
                update = weight_dict.get('update', 0)
                if update > last_update:
                    last_update = update
                    # print(' %5d> Getting weights from dict' % (pid,))
                    agent.action_value.set_weights(weight_dict['weights'])
        # -----
        avg_score.append(episode_reward)


def init_worker():
    import signal
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def main():
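    """Starts args.processes actor processes plus one learner process; they communicate
    through a managed experience queue, a shared weight dict, and a shared frame counter."""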
    manager = Manager()
    weight_dict = manager.dict()
    global_frame = manager.Value('i', args.checkpoint)
    mem_queue = manager.Queue(args.queue_size)

    eps = [0.1, 0.01, 0.5]  # final epsilon values, cycled across the actor processes
    pool = Pool(args.processes + 1, init_worker)

    try:
        for i in range(args.processes):
            pool.apply_async(generate_experience_proc,
                             args=(global_frame, mem_queue, weight_dict, i, eps[i % len(eps)]))

        pool.apply_async(learn_proc, args=(global_frame, mem_queue, weight_dict))

        pool.close()
        pool.join()

    except KeyboardInterrupt:
        pool.terminate()
        pool.join()


if __name__ == "__main__":
    main()