import zmq
import json
import numpy as np
import time


class ParameterClient():
    """Client for a ZeroMQ parameter server.

    Each instance owns a DEALER socket connected to ``tcp://localhost:5570``
    whose identity is the decimal string of ``_id``.  Every request starts with
    a JSON header frame carrying an ``op`` field; array payloads follow it as a
    JSON metadata frame (``dtype``, ``shape``) plus a raw-bytes frame (see
    ``_send_array``).

    The client sends ``register`` on construction.  Use it as a context
    manager so that leaving the ``with`` block sends ``exit`` and releases the
    socket and the ZeroMQ context.
    """

    def __init__(self, _id):
        """Connect to the server and send a ``register`` message.

        Args:
            _id: Integer client id; formatted with ``%d`` to build the socket
                identity, so it must be an integer.
        """
        print('Connecting to ParamServer ...')
        context = zmq.Context()
        socket = context.socket(zmq.DEALER)
        identity = u'%d' % _id
        socket.identity = identity.encode('ascii')
        socket.connect('tcp://localhost:5570')
        print('Client %s started' % (identity))
        # The poller is used by get_value_by_rows to wait for reply frames.
        self._poll = zmq.Poller()
        self._poll.register(socket, zmq.POLLIN)
        self._socket = socket
        self._context = context
        self._register()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Send ``exit`` to the server, then close the socket and context.

        The socket and context are released even if sending ``exit`` fails.
        Exceptions raised inside the ``with`` block are not suppressed.
        """
        print('Client %s exit' % (self._socket.identity))
        try:
            self._exit()
        finally:
            # NOTE(review): with the socket's default linger, term() may wait
            # until queued messages (including 'exit') are delivered — confirm
            # this is acceptable when the server is unreachable.
            self._socket.close()
            self._context.term()

    def _register(self):
        """Announce this client to the server (``op='register'``)."""
        self._socket.send_json(dict(op='register'))

    def _exit(self):
        """Tell the server this client is leaving (``op='exit'``)."""
        self._socket.send_json(dict(op='exit'))

    @staticmethod
    def _prepare_array(data):
        """Return ``(metadata, array)`` ready to be sent over the wire.

        ``data`` may be a numpy array or anything ``np.asarray`` accepts, such
        as a plain list of row indices.  The returned array is C-contiguous,
        so its raw buffer matches ``metadata`` (``dtype`` string and ``shape``
        tuple), which is what ``_decode_array`` expects on the other side.
        """
        array = np.asarray(data, order='C')
        metadata = dict(dtype=str(array.dtype), shape=array.shape)
        return metadata, array

    @staticmethod
    def _decode_array(metadata, buf):
        """Rebuild an array from a metadata dict and its raw-bytes frame.

        The result is a read-only view over ``buf`` (``np.frombuffer``).
        """
        flat = np.frombuffer(buf, dtype=metadata['dtype'])
        return flat.reshape(metadata['shape'])

    def _send_array(self, data, flags=0, copy=True, track=False):
        """Send ``data`` as a JSON metadata frame followed by its raw bytes.

        Args:
            data: numpy array or array-like (lists are converted).
            flags: ZeroMQ send flags for the data frame; pass ``zmq.SNDMORE``
                when more frames follow.  The metadata frame always carries
                ``SNDMORE`` so that it is never the last frame.
            copy, track: Forwarded to ``socket.send``.

        Returns:
            Whatever ``socket.send`` returns for the data frame.
        """
        metadata, array = self._prepare_array(data)
        self._socket.send_json(metadata, flags | zmq.SNDMORE)
        return self._socket.send(array, flags, copy=copy, track=track)

    def add_matrix(self, mid, shape):
        """Ask the server to create matrix ``mid`` with the given ``shape``.

        ``shape`` is JSON-serialized, so pass a list or tuple of ints.
        """
        self._socket.send_json(dict(op='add_matrix', mid=mid, shape=shape))

    def set_matrix(self, mid, data, force=False):
        """Send the full contents of matrix ``mid``.

        Frames: JSON header ``{op, mid, force}`` followed by ``data`` as
        encoded by ``_send_array``.  How the server uses ``force`` is not
        visible from the client.
        """
        msg = dict(op='set_matrix', mid=mid, force=force)
        self._socket.send_json(msg, zmq.SNDMORE)
        self._send_array(data)

    def get_value_by_rows(self, mid, rows):
        """Fetch the given rows of matrix ``mid`` and block until they arrive.

        Sends a JSON header plus ``rows`` (list or numpy array), then waits
        for two reply frames: a JSON metadata frame and a raw-bytes frame.

        Returns:
            A read-only numpy array with the dtype and shape from the reply
            metadata.
        """
        msg = dict(op='get_value_by_rows', mid=mid)
        self._socket.send_json(msg, zmq.SNDMORE)
        self._send_array(rows)
        meta = None
        while True:
            # Poll in 1 s slices; there is no overall timeout.
            ready = dict(self._poll.poll(1000))
            if self._socket not in ready:
                continue
            frame = self._socket.recv()
            if meta is None:
                meta = json.loads(frame)
            else:
                return self._decode_array(meta, frame)

    def set_value_by_rows(self, mid, rows, data):
        """Overwrite rows ``rows`` of matrix ``mid`` with ``data``.

        Frames: JSON header, then ``rows`` and ``data`` each as
        metadata + raw bytes.
        """
        msg = dict(op='set_value_by_rows', mid=mid)
        self._socket.send_json(msg, zmq.SNDMORE)
        self._send_array(rows, zmq.SNDMORE)
        self._send_array(data)

    def update_params(self, dic):
        """Send an ``update_params`` request whose fields come from ``dic``.

        The ``op`` field is always ``'update_params'``, even if ``dic``
        contains an ``op`` key.

        Raises:
            ValueError: if ``dic`` is empty.
        """
        if not dic:
            raise ValueError('update_params needs at least one parameter')
        msg = dict(dic, op='update_params')
        self._socket.send_json(msg)

    def update_by_rows(self, mid, rows, data, skip_decay=False):
        """Send an update for rows ``rows`` of matrix ``mid``.

        Frames: JSON header ``{op, mid, skip_decay}``, then ``rows`` and
        ``data`` each as metadata + raw bytes.  How the server applies the
        update and ``skip_decay`` is not visible from the client.
        """
        msg = dict(op='update_by_rows', mid=mid, skip_decay=skip_decay)
        self._socket.send_json(msg, zmq.SNDMORE)
        self._send_array(rows, zmq.SNDMORE)
        self._send_array(data)

    def snapshot(self, path):
        """Request a server-side snapshot written to ``path``."""
        self._socket.send_json(dict(op='snapshot', path=path))

    def load(self, path):
        """Alias of ``resume``: sends ``op='resume'`` with ``path``.

        NOTE(review): there is no distinct 'load' op on the wire; confirm
        this aliasing is intended.
        """
        self.resume(path)

    def resume(self, path):
        """Ask the server to resume from ``path`` (``op='resume'``)."""
        self._socket.send_json(dict(op='resume', path=path))


if __name__ == '__main__':
    # Demo against a parameter server listening on tcp://localhost:5570.
    num_class, fdim = 10, 256
    # The with-statement sends 'exit' and closes each client's socket and
    # context on the way out, even if one of the requests below raises.
    with ParameterClient(0) as client0, ParameterClient(1) as client1:
        # Give both clients time to finish setting up before issuing requests.
        time.sleep(3)
        client0.add_matrix(mid='0', shape=[num_class, fdim])
        # Rows are passed as numpy arrays; they are sent with dtype and shape.
        weights = client0.get_value_by_rows(mid='0', rows=np.array([0, 1, 2, 3]))
        print('fetched rows: shape=%s dtype=%s' % (weights.shape, weights.dtype))
        client0.update_by_rows(mid='0', rows=np.array([0, 1, 2, 3]),
                               data=np.ones([4, fdim]))
        client1.update_by_rows(mid='0', rows=np.array([2, 3, 4, 5]),
                               data=np.ones([4, fdim]))