#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: remote.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
"""Send and receive DataFlow datapoints over ZeroMQ PUSH/PULL sockets."""

from ..utils import logger

# zmq is optional: without it the module still imports, but exports nothing.
# NOTE: serve_data / RemoteData reference `zmq` directly and will fail with a
# NameError if they are called after the import above has failed.
try:
    import zmq
except ImportError:
    logger.warn("Error in 'import zmq'. remote feature won't be available")
    __all__ = []
else:
    __all__ = ['serve_data', 'RemoteData']

from .base import DataFlow
from .common import RepeatedData
from ..utils.serialize import dumps, loads


def serve_data(ds, addr):
    """Serve the datapoints of ``ds`` on a ZeroMQ PUSH socket bound to ``addr``.

    The function loops forever; it only leaves through an exception
    (e.g. KeyboardInterrupt), at which point the socket and context are
    released.

    Args:
        ds: the DataFlow to serve. It is wrapped in ``RepeatedData(ds, -1)``
            (presumably "repeat indefinitely" -- confirm against RepeatedData).
        addr: a ZeroMQ endpoint to bind to, e.g. ``"tcp://127.0.0.1:8877"``.
    """
    ctx = zmq.Context()
    socket = ctx.socket(zmq.PUSH)
    try:
        # Socket setup lives inside the try block so that a failing bind()
        # (address in use, malformed address, ...) still runs the cleanup
        # below instead of leaking the socket and the context.
        socket.set_hwm(10)
        socket.bind(addr)
        ds = RepeatedData(ds, -1)
        ds.reset_state()
        logger.info("Serving data at {}".format(addr))
        while True:
            for dp in ds.get_data():
                # copy=False asks pyzmq to send the buffer without copying it.
                socket.send(dumps(dp), copy=False)
    finally:
        # LINGER=0: do not wait for unsent messages when closing.
        socket.setsockopt(zmq.LINGER, 0)
        socket.close()
        if not ctx.closed:
            ctx.destroy(0)


class RemoteData(DataFlow):
    """A DataFlow that yields datapoints received from a ZeroMQ PULL socket."""

    def __init__(self, addr):
        """Connect a PULL socket to ``addr``.

        Args:
            addr: the ZeroMQ endpoint to connect to, e.g. the address a
                ``serve_data`` process is bound to.
        """
        self.ctx = zmq.Context()
        self.socket = self.ctx.socket(zmq.PULL)
        self.socket.set_hwm(10)
        self.socket.connect(addr)

    def get_data(self):
        """Yield deserialized datapoints forever, blocking on each receive."""
        while True:
            dp = loads(self.socket.recv(copy=False))
            yield dp


if __name__ == '__main__':
    import sys
    from tqdm import tqdm
    from .raw import FakeData

    addr = "tcp://127.0.0.1:8877"
    # Run with the argument 'serve' to act as the sender; anything else
    # (including no argument at all) runs the receiving side.
    mode = sys.argv[1] if len(sys.argv) > 1 else ''
    if mode == 'serve':
        ds = FakeData([(128, 244, 244, 3)], 1000)
        serve_data(ds, addr)
    else:
        ds = RemoteData(addr)
        logger.info("Each DP is 73.5MB")
        with tqdm(total=10000) as pbar:
            for _ in ds.get_data():
                pbar.update()