# -*- coding: utf-8 -*- import zmq import json import h5py import logging import threading import numpy as np def init_gaussian(shape, order='C'): arr = 0.01 * np.random.randn(*shape).astype(np.float32) if order == 'F': return np.asfortranarray(arr) else: return arr def init_uniform(shape, low=-0.1, high=0.1, order='C'): assert low < high arr = np.random.uniform(low, high, shape).astype(np.float32) if order == 'F': return np.asfortranarray(arr) else: return arr def init_zeros(shape, order='C'): return np.zeros(shape, order=order).astype(np.float32) class Optim(dict): __getattr__ = dict.get __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ class ParameterServer(threading.Thread): def __init__(self): threading.Thread.__init__(self) def run(self): context = zmq.Context() frontend = context.socket(zmq.ROUTER) frontend.bind('tcp://*:5570') backend = context.socket(zmq.DEALER) backend.bind('inproc://backend') worker = ParameterWorker(context) worker.start() try: zmq.proxy(frontend, backend) except zmq.ContextTerminated: frontend.close() backend.close() class ParameterWorker(threading.Thread): def __init__(self, context, lr=0.01, weight_decay=1e-4, momentum=0.9): threading.Thread.__init__(self) self.context = context self.clients = {} self.grads = {} self.mtable = {} self.key, self.wkey, self.hkey = ['@meta@ps', 'w', 'h'] self.updater = 'sgd' self.support_keys = [self.wkey, self.hkey] assert lr > 0 assert weight_decay >= 0 assert momentum >= 0 self.optim = Optim() self.optim.lr = lr self.optim.weight_decay = weight_decay self.optim.momentum = momentum def run(self): self._socket = self.context.socket(zmq.DEALER) self._socket.connect('inproc://backend') print('Worker started') while True: self._recv() self._socket.close() # since `proxy` is daemon # we have to teminate it in a not graceful way if self.context: print('Terminate proxy ... ') time.sleep(1.) # make sure proxy pass the last msg to clients self.context.term() """ comm """ @staticmethod def _parse_json(x): return json.loads(x.decode('utf-8')) @staticmethod def _buf_to_ndarray(x, md): x = np.frombuffer(x, dtype=md['dtype']) return x.reshape(md['shape']) def _ready_for_update(self, mid): # TODO: following neglect the situation that a client sends the request twice for k in self.clients: if self.clients[k][mid] == 0: return False return True def _recv(self): # TODO: # 1. use logging to better log server info # 2. better exception handling print('waiting for message ... ') packet = self._socket.recv_multipart() if len(packet) == 2: ident, msg = packet msg = self._parse_json(msg) elif len(packet) == 4: ident, msg, meta, data = packet msg, meta = map(self._parse_json, [msg, meta]) if msg['op'] == 'set_matrix': msg['data'] = self._buf_to_ndarray(data, meta) else: msg['rows'] = self._buf_to_ndarray(data, meta) elif len(packet) == 6: ident, msg, rows_meta, rows_data, val_meta, val_data = packet msg, rows_meta, val_meta = map(self._parse_json, [msg, rows_meta, val_meta]) msg['rows'] = self._buf_to_ndarray(rows_data, rows_meta) msg['data'] = self._buf_to_ndarray(val_data, val_meta) else: raise RuntimeError('Unsupported msg type') self.handle(ident, msg) # try: # self.handle(ident, msg) # except Exception as err: # print('Error ', err) def _recv_array(self, flags=0, copy=True, track=False): ident = self._socket.recv_json(flags=flags) md = self._socket.recv_json(flags=flags) print("recv a numpy array {}".format(md)) msg = self._socket.recv(flags=flags, copy=copy, track=track) data = np.frombuffer(msg, dtype=md['dtype']) return data.reshape(md['shape']) def _send(self, ident, data): assert isinstance(data, np.ndarray) meta = {'dtype': str(data.dtype), 'shape': data.shape} self._socket.send(ident, zmq.SNDMORE) self._socket.send_json(meta) self._socket.send_multipart([ident, data]) """ public functions """ def handle(self, ident, msg): op = msg['op'] print('receiving {} from {}'.format(op, ident)) if op == 'register': self.clients[ident] = {} print('Client {} register. ({} clients in total).'.format( ident, len(self.clients))) elif op == 'exit': del self.clients[ident] print('Client {} exit. ({} clients in total).'.format( ident, len(self.clients))) elif op == 'add_matrix': mid = msg['mid'] self.add_matrix(mid, msg['shape'], init_uniform) self._reset_grad(mid) elif op == 'set_matrix': assert 'data' in msg force = False if 'force' in msg and msg['force']: force = True self.set_matrix(msg['mid'], msg['data'], force) print('the value of matrix {} has been set.'.format(msg['mid'])) elif op == 'get_value_by_rows': weights = self.get_value_by_rows(msg['mid'], msg['rows']) # send back to client self._send(ident, weights) elif op == 'set_value_by_rows': assert 'data' in msg self.set_value_by_rows(msg['mid'], msg['rows'], msg['data']) print('the value of {} rows of matrix {} has been set.'.format( len(msg['rows']), msg['mid'])) elif op == 'update_params': self.update_params(msg) elif op == 'update_by_rows': assert 'data' in msg mid = msg['mid'] # merge data from all clients assert len(msg['rows']) == len(msg['data']) self.clients[ident][mid] = 1 for i, r in enumerate(msg['rows']): if r not in self.grads[mid]: self.grads[mid][r] = np.array(msg['data'][i]) else: self.grads[mid][r] += np.array(msg['data'][i]) if self._ready_for_update(mid): print('updating') skip_decay = False if len(self.clients) == 1 and \ 'skip_decay' in msg and msg['skip_decay']: print('skipping weight decay') skip_decay = True self.update_by_rows(mid, np.array(list(self.grads[mid].keys())), np.array(list(self.grads[mid].values())), skip_decay=skip_decay) print("weight change", self.mtable[mid][self.wkey].mean()) # reset gradient self._reset_grad(mid) elif op == 'snapshot': self.snapshot(msg['path']) elif op == 'load': self.load(msg['path']) elif op == 'resume': self.resume(msg['path']) else: raise KeyError('Unknown operation') def add_matrix(self, mid, shape, init_func, his=True): mid = self._build_mtable(mid) # TODO: add `force` to rm already built matrix if mid is None: return self.mtable[mid][self.wkey] = init_func(shape) print(self.mtable[mid][self.wkey].shape) if his: self.mtable[mid][self.hkey] = init_zeros(shape) self._check_order(mid) def load_matrix(self, mid, h5group, his=True): mid = self._build_mtable(mid) if mid is None: return for key in h5group: if not key in self.support_keys: raise KeyError('The {} is not in the support list'.format(key)) self.mtable[mid][key] = np.asfortranarray(h5group[key]) if key == self.hkey and not his: self.mtable[mid][key].fill(0) self._check_order(mid) def get_value_by_rows(self, mid, rows): return self.mtable[mid][self.wkey][rows, :] def set_matrix(self, mid, data, force=False): """ Note that when you set value of weights directly, the history of SGD will automatically be set to zero. If `force` is not true, the shape of input data should be equal to the shape of existing weights. """ if not force: assert data.shape == self.mtable[mid][self.wkey].shape self.mtable[mid][self.wkey][:] = data if self.hkey in self.mtable[mid]: self.mtable[mid][self.hkey].fill(0) def set_value_by_rows(self, mid, rows, data): """ Note that when you set value of weights directly, the history of SGD will automatically be set to zero. """ self.mtable[mid][self.wkey][rows, :] = data if self.hkey in self.mtable[mid]: self.mtable[mid][self.hkey][rows, :].fill(0) def update_by_rows(self, mid, rows, grad, skip_decay=False): """ Note that the gradient from PyTorch is already conducted L2 regularization! That is, $grad += weight * weight\_decay$ has been applied to the grad. If you use `param.grad` from PyTorch Parameter, you don't have to regularize weights again. """ if self.optim.weight_decay > 0: if not skip_decay: grad = self._l2_regularize( self.mtable[mid][self.wkey][rows, :], grad) pass self._sgd_update(mid, rows, grad) def update_params(self, msg): for key in msg: if key == 'op': pass elif key not in self.optim: raise KeyError('Not supported key found: {}'.format(key)) else: val = msg[key] if key == 'lr': assert val > 0 else: assert val >= 0 self.optim[key] = val print("{} has been updated to {}".format(key, val)) def snapshot(self, path): with h5py.File(path, 'w') as f: print('snapshot to {}'.format(path)) ps = f.create_group(self.key) for key in self.mtable: midg = ps.create_group(str(key)) for k in self.mtable[key]: midg[k] = self.mtable[key][k][...] def resume(self, path, his=True): with h5py.File(path, 'r') as f: print('resume from {} with history={}'.format(path, his)) if self.key not in f.keys(): logging.warn('The model does not have {}'.format(self.key)) return ps = f[self.key] for key in ps.keys(): self.load_matrix(key, ps[key], his) def load(self, path): self.resume(path, his=False) """ private functions """ def _exists(self, mid): return (mid in self.mtable) def _build_mtable(self, mid): if isinstance(mid, str): pass elif isinstance(mid, int): mid = str(mid) elif isinstance(mid, unicode): mid = mid.encode('ascii', 'ignore') else: raise TypeError( 'The key({},{}) for Parameter Server should be str!'.format( mid, type(mid))) if not self._exists(mid): self.mtable[mid] = {} return mid else: return None def _check_order(self, mid): for key in self.mtable[mid]: if not self.mtable[mid][key].flags['C_CONTIGUOUS']: raise TypeError('np.darray should be C order!') def _reset_grad(self, mid): for k in self.clients: self.clients[k][mid] = 0 self.grads[mid] = {} def _l2_regularize(self, data, grad): grad += self.optim.weight_decay * data return grad def _sgd_update(self, mid, rows, grad): """ The algorithm here is compatible with PyTorch. Note that it is different from platform like Caffe. Detailed description can be found at https://pytorch.org/docs/stable/_modules/torch/optim/sgd.html#SGD """ if self.optim.momentum > 0: grad += self.mtable[mid][self.hkey][rows, :] * self.optim.momentum self.mtable[mid][self.hkey][rows, :] = grad self.mtable[mid][self.wkey][rows, :] -= self.optim.lr * grad if __name__ == "__main__": ps = ParameterServer() ps.start()