"""Wrapper class for ApeX based distributed workers.

- Author: Chris Yoon
- Contact: chris.yoon@medipixel.io
"""
import argparse
from collections import deque
from typing import Dict

import numpy as np
import pyarrow as pa
import ray
import zmq

from rl_algorithms.common.abstract.distributed_worker import (
    DistributedWorker,
    DistributedWorkerWrapper,
)
from rl_algorithms.utils.config import ConfigDict


@ray.remote(num_cpus=1)
class ApeXWorkerWrapper(DistributedWorkerWrapper):
    """Wrapper class for ApeX based distributed workers.

    Attributes:
        hyper_params (ConfigDict): worker hyper-parameters
        update_step (int): tracker for the learner's update step
        use_n_step (bool): whether n-step transitions are used
        scores (Dict[int, list]): episode scores collected per update step
        sub_socket (zmq.Socket): subscriber socket for receiving params from learner
        push_socket (zmq.Socket): push socket for sending experience to global buffer

    """

    def __init__(
        self, worker: DistributedWorker, args: argparse.Namespace, comm_cfg: ConfigDict
    ):
        DistributedWorkerWrapper.__init__(self, worker, args, comm_cfg)
        self.update_step = 0
        self.hyper_params = self.worker.hyper_params
        self.use_n_step = self.hyper_params.n_step > 1
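        # Episode scores grouped by the learner update step at which they were collected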
        self.scores = dict()

        self.worker._init_env()

    # pylint: disable=attribute-defined-outside-init
    def init_communication(self):
        """Initialize sockets connecting worker-learner, worker-buffer."""
        # for receiving params from learner
        ctx = zmq.Context()
        self.sub_socket = ctx.socket(zmq.SUB)
        self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
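        # Limit the inbound queue to two param messages; further broadcasts are
        # dropped by ZMQ until the worker has consumed the queued ones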
        self.sub_socket.setsockopt(zmq.RCVHWM, 2)
        self.sub_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.learner_worker_port}")

        # for sending replay data to buffer
        self.push_socket = ctx.socket(zmq.PUSH)
        self.push_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.worker_buffer_port}")

    def send_data_to_buffer(self, replay_data):
        """Send replay data to global buffer."""
        replay_data_id = pa.serialize(replay_data).to_buffer()
        self.push_socket.send(replay_data_id)

    def recv_params_from_learner(self):
        """Get new params and sync. return True if success, False otherwise."""
        received = False
        try:
            new_params_id = self.sub_socket.recv(zmq.DONTWAIT)
            received = True
        except zmq.Again:
            # The learner has not published new params yet; do not block
            pass

        if received:
            new_param_info = pa.deserialize(new_params_id)
            update_step, new_params = new_param_info
            self.update_step = update_step
            self.worker.synchronize(new_params)

            # Add new entry for scores dict
            self.scores[self.update_step] = []

    def compute_priorities(self, experience: Dict[str, np.ndarray]):
        """Compute priority values (TD error) of collected experience."""
        return self.worker.compute_priorities(experience)

    def collect_data(self) -> dict:
        """Fill and return local buffer."""
        local_memory = dict(states=[], actions=[], rewards=[], next_states=[], dones=[])
        local_memory_keys = local_memory.keys()
        if self.use_n_step:
            nstep_queue = deque(maxlen=self.hyper_params.n_step)

        while len(local_memory["states"]) < self.hyper_params.local_buffer_max_size:
            state = self.worker.env.reset()
            done = False
            score = 0
            num_steps = 0
            while not done:
                if self.args.worker_render:
                    self.worker.env.render()
                num_steps += 1
                action = self.select_action(state)
                next_state, reward, done, _ = self.step(action)
                transition = (state, action, reward, next_state, int(done))
                if self.use_n_step:
                    nstep_queue.append(transition)
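                    # Once n transitions are queued, collapse them into a single n-step transition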
                    if self.hyper_params.n_step == len(nstep_queue):
                        nstep_exp = self.preprocess_nstep(nstep_queue)
                        for entry, key in zip(nstep_exp, local_memory_keys):
                            local_memory[key].append(entry)
                else:
                    for entry, key in zip(transition, local_memory_keys):
                        local_memory[key].append(entry)

                state = next_state
                score += reward

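                # Poll the learner for new params after every environment step (non-blocking)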
                self.recv_params_from_learner()

            self.scores[self.update_step].append(score)

            if self.args.worker_verbose:
                print(
                    f"[TRAIN] [Worker {self.worker.rank}] "
                    + f"Update step: {self.update_step}, Score: {score}, "
                    + f"Epsilon: {self.worker.epsilon:.5f}"
                )

        for key in local_memory_keys:
            local_memory[key] = np.array(local_memory[key])

        return local_memory

    def run(self) -> Dict[int, float]:
        """Run main worker loop."""
        self.scores[self.update_step] = []
        while self.update_step < self.args.max_update_step:
            experience = self.collect_data()
            priority_values = self.compute_priorities(experience)
            worker_data = [experience, priority_values]
            self.send_data_to_buffer(worker_data)

        mean_scores_per_ep_step = self.compute_mean_scores(self.scores)
        return mean_scores_per_ep_step

    @staticmethod
    def compute_mean_scores(scores: Dict[int, list]) -> Dict[int, float]:
        """Reduce each update step's score list to its mean, dropping empty lists."""
        for step in list(scores):
            if scores[step]:
                scores[step] = np.mean(scores[step])
            else:
                # Delete empty score list
                # made when network is updated before termination of episode
                scores.pop(step)
        return scores
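

# Usage sketch (illustrative only; the actual launcher lives outside this module).
# Assuming `workers` is a list of DistributedWorker instances and `args` / `comm_cfg`
# are the parsed arguments and communication config used elsewhere in the repo:
#
#   worker_actors = [
#       ApeXWorkerWrapper.remote(worker, args, comm_cfg) for worker in workers
#   ]
#   for actor in worker_actors:
#       actor.init_communication.remote()
#   worker_scores = ray.get([actor.run.remote() for actor in worker_actors])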