#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import sys
import os
import signal
import time
import tensorflow as tf
import multiprocessing as mp
import time
import threading
import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import numpy as np
import six
from six.moves import queue

from ..models._common import disable_layer_logging
from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
from ..utils import logger
from ..utils.timer import *
from ..utils.serialize import *
from ..utils.concurrency import *

__all__ = ['SimulatorProcess', 'SimulatorMaster',
        'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
        'TransitionExperience', 'WeightSync']

try:
    import zmq
except ImportError:
    logger.warn("Error in 'import zmq'. RL simulator won't be available.")
    __all__ = []

class TransitionExperience(object):
    """ A transition of state, or experience"""
    def __init__(self, state, action, reward, **kwargs):
        """ kwargs: whatever other attribute you want to save"""
        self.state = state
        self.action = action
        self.reward = reward
        for k, v in six.iteritems(kwargs):
            setattr(self, k, v)

class SimulatorProcessBase(mp.Process):
    __metaclass__ = ABCMeta

    def __init__(self, idx):
        super(SimulatorProcessBase, self).__init__()
        self.idx = int(idx)
        self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')

    @abstractmethod
    def _build_player(self):
        pass


class SimulatorProcessStateExchange(SimulatorProcessBase):
    """
    A process that simulates a player and communicates to master to
    send states and receive the next action
    """
    __metaclass__ = ABCMeta

    def __init__(self, idx, pipe_c2s, pipe_s2c):
        """
        :param idx: idx of this process
        """
        super(SimulatorProcessStateExchange, self).__init__(idx)
        self.c2s = pipe_c2s
        self.s2c = pipe_s2c

    def run(self):
        player = self._build_player()
        context = zmq.Context()
        c2s_socket = context.socket(zmq.PUSH)
        c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
        c2s_socket.set_hwm(2)
        c2s_socket.connect(self.c2s)

        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
        #s2c_socket.set_hwm(5)
        s2c_socket.connect(self.s2c)

        state = player.current_state()
        reward, isOver = 0, False
        ts = 0

        while True:
            c2s_socket.send(dumps(
                (self.identity, state, reward, isOver, ts, True)),
                copy=False)
            #t.grel here we get the action
            (action, ts, isAlive) = loads(s2c_socket.recv(copy=False).bytes)
            if not isAlive:
                c2s_socket.send(dumps(
                    (self.identity, 0, 0, 0, 0, False)),
                    copy=False)
                print("closing thread : {}".format(self.identity))
                break

            reward, isOver = player.action(action)
            state = player.current_state()

# compatibility
SimulatorProcess = SimulatorProcessStateExchange

class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce action for each simulator, as well as
        defining callbacks when a transition or an episode is finished.
    """
    __metaclass__ = ABCMeta

    class ClientState(object):
        def __init__(self):
            self.memory = []    # list of Experience

    def __init__(self, pipe_c2s, pipe_s2c, simulator_procs, pid):
        super(SimulatorMaster, self).__init__()
        self.daemon = True

        self.context = zmq.Context()

        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.bind(pipe_c2s)
        self.c2s_socket.set_hwm(10)
        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.bind(pipe_s2c)
        self.s2c_socket.set_hwm(10)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=1)

        self.simulator_procs = simulator_procs
        self.killed_threads = 0
        self.pid = pid

        def f():
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure socket get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)

    def run(self):
        self.clients = defaultdict(self.ClientState)
        while True:
            bytes = self.c2s_socket.recv(copy=False).bytes
            msg = loads(bytes)
            ident, state, reward, isOver, ts, isAlive = msg
            client = self.clients[ident]

            if not isAlive:
                self.killed_threads += 1
                print("killed : {}, waiting for {}".format(self.killed_threads, self.simulator_procs))
                if self.killed_threads == self.simulator_procs:
                    self.M.isDone = True
                    break
                continue

            # check if reward&isOver is valid
            # in the first message, only state is valid
            if len(client.memory) > 0:
                client.memory[-1].reward = reward
                if isOver:
                    self._on_episode_over((ident, ts))
                else:
                    self._on_datapoint((ident, ts))
            # feed state and return action
            self._on_state(state, (ident, ts))

        print("MasterSimulator is out, peace")
        time.sleep(10)
        os.kill(self.pid, signal.SIGKILL)

    @abstractmethod
    def _on_state(self, state, ident):
        """response to state sent by ident. Preferrably an async call"""

    @abstractmethod
    def _on_episode_over(self, client):
        """ callback when the client just finished an episode.
            You may want to clear the client's memory in this callback.
        """

    def _on_datapoint(self, client):
        """ callback when the client just finished a transition
        """

    def __del__(self):
        self.context.destroy(linger=0)

class SimulatorProcessDF(SimulatorProcessBase):
    """ A simulator which contains a forward model itself, allowing
    it to produce data points directly """
    def __init__(self, idx, pipe_c2s):
        super(SimulatorProcessDF, self).__init__(idx)
        self.pipe_c2s = pipe_c2s

    def run(self):
        self.player = self._build_player()

        self.ctx = zmq.Context()
        self.c2s_socket = self.ctx.socket(zmq.PUSH)
        self.c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
        self.c2s_socket.set_hwm(5)
        self.c2s_socket.connect(self.pipe_c2s)

        self._prepare()
        for dp in self.get_data():
            self.c2s_socket.send(dumps(dp), copy=False)

    @abstractmethod
    def _prepare(self):
        pass

    @abstractmethod
    def get_data(self):
        pass

class SimulatorProcessSharedWeight(SimulatorProcessDF):
    """ A simulator process with an extra thread waiting for event,
    and take shared weight from shm.

    Start me under some CUDA_VISIBLE_DEVICES set!
    """
    def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
        super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
        self.condvar = condvar
        self.shared_dic = shared_dic
        self.pred_config = pred_config

    def _prepare(self):
        disable_layer_logging()
        self.predictor = OfflinePredictor(self.pred_config)
        with self.predictor.graph.as_default():
            vars_to_update = self._params_to_update()
            self.sess_updater = SessionUpdate(
                    self.predictor.session, vars_to_update)
        # TODO setup callback for explore?
        self.predictor.graph.finalize()

        self.weight_lock = threading.Lock()

        # start a thread to wait for notification
        def func():
            self.condvar.acquire()
            while True:
                self.condvar.wait()
                self._trigger_evt()
        self.evt_th = threading.Thread(target=func)
        self.evt_th.daemon = True
        self.evt_th.start()

    def _trigger_evt(self):
        with self.weight_lock:
            self.sess_updater.update(self.shared_dic['params'])
            logger.info("Updated.")

    def _params_to_update(self):
        # can be overwritten to update more params
        return tf.trainable_variables()

class WeightSync(Callback):
    """ Sync weight from main process to shared_dic and notify"""
    def __init__(self, condvar, shared_dic):
        self.condvar = condvar
        self.shared_dic = shared_dic

    def _setup_graph(self):
        self.vars = self._params_to_update()

    def _params_to_update(self):
        # can be overwritten to update more params
        return tf.trainable_variables()

    def _before_train(self):
        self._sync()
    def _trigger_epoch(self):
        self._sync()

    def _sync(self):
        logger.info("Updating weights ...")
        dic = {v.name: v.eval() for v in self.vars}
        self.shared_dic['params'] = dic
        self.condvar.acquire()
        self.condvar.notify_all()
        self.condvar.release()

if __name__ == '__main__':
    import random
    from tensorpack.RL import NaiveRLEnvironment
    class NaiveSimulator(SimulatorProcess):
        def _build_player(self):
            return NaiveRLEnvironment()
    class NaiveActioner(SimulatorActioner):
        def _get_action(self, state):
            time.sleep(1)
            return random.randint(1, 12)
        def _on_episode_over(self, client):
            #print("Over: ", client.memory)
            client.memory = []
            client.state = 0

    name = 'ipc://whatever'
    procs = [NaiveSimulator(k, name) for k in range(10)]
    [k.start() for k in procs]

    th = NaiveActioner(name)
    ensure_proc_terminate(procs)
    th.start()

    import time
    time.sleep(100)