import time
import socket
import zlib
import struct

import pysyncobj.pickle as pickle
import pysyncobj.win_inet_pton
from .poller import POLL_EVENT_TYPE
from .monotonic import monotonic as monotonicTime


class CONNECTION_STATE:
    """Lifecycle states of a TcpConnection."""

    # No usable socket: initial state without a socket, and the state after
    # disconnect() or a failed connect().
    DISCONNECTED = 0
    # connect() has been issued on a non-blocking socket; completion is
    # detected later from a poller event.
    CONNECTING = 1
    # The socket is established and messages can be exchanged.
    CONNECTED = 2


def _getAddrType(addr):
    """Return the socket address family of the textual address ``addr``.

    Returns ``socket.AF_INET`` when ``socket.inet_aton`` accepts ``addr`` and
    ``socket.AF_INET6`` when ``socket.inet_pton(AF_INET6, ...)`` accepts it.
    Host names are not resolved here.

    Raises:
        Exception: ``'unknown address type'`` if neither parser accepts it.
    """
    try:
        socket.inet_aton(addr)
        return socket.AF_INET
    except socket.error:
        pass
    try:
        socket.inet_pton(socket.AF_INET6, addr)
        return socket.AF_INET6
    except socket.error:
        pass
    raise Exception('unknown address type')


class TcpConnection(object):
    """Non-blocking, poller-driven TCP connection exchanging pickled messages.

    Each message on the wire is a 4-byte length prefix (``struct`` format
    ``'i'``) followed by the zlib-compressed pickle of the message, optionally
    passed through ``encryptor``.

    Public attributes (reset to None by ``disconnect()``):
        sendRandKey: when truthy, every outgoing message is wrapped as
            ``(sendRandKey, message)``.
        recvRandKey: when truthy, every incoming message must be a pair whose
            first item equals ``recvRandKey``; otherwise the connection is
            dropped.
        encryptor: optional object with ``encrypt(data)`` / ``decrypt(data)``.
    """

    def __init__(self, poller, onMessageReceived=None, onConnected=None, onDisconnected=None,
                 socket=None, timeout=10.0, sendBufferSize=2 ** 13, recvBufferSize=2 ** 13):
        """Create a connection, optionally wrapping an already connected socket.

        Args:
            poller: object providing ``subscribe(fd, callback, eventMask)`` and
                ``unsubscribe(fd)``.
            onMessageReceived: called with each decoded incoming message.
            onConnected: called with no arguments when an outgoing connect
                completes.
            onDisconnected: called with no arguments by ``disconnect()`` when
                the connection was not already disconnected.
            socket: an existing socket to adopt. When given, the connection
                starts in the CONNECTED state and is subscribed to the poller
                immediately. Note that inside this method the name shadows the
                ``socket`` module.
            timeout: seconds; a poller event arriving more than this long after
                the last read activity drops the connection.
            sendBufferSize: value for SO_SNDBUF on sockets made by connect().
            recvBufferSize: value for SO_RCVBUF on sockets made by connect();
                also the chunk size passed to ``recv``.
        """
        self.sendRandKey = None
        self.recvRandKey = None
        self.encryptor = None

        self.__readBuffer = bytes()
        self.__writeBuffer = bytes()
        self.__lastReadTime = monotonicTime()
        self.__timeout = timeout
        self.__poller = poller

        # Store callbacks and sizes before subscribing so the object is fully
        # initialised by the time the poller knows about it.
        self.__onMessageReceived = onMessageReceived
        self.__onConnected = onConnected
        self.__onDisconnected = onDisconnected
        self.__sendBufferSize = sendBufferSize
        self.__recvBufferSize = recvBufferSize

        self.__socket = socket
        if socket is not None:
            self.__fileno = socket.fileno()
            self.__state = CONNECTION_STATE.CONNECTED
            self.__poller.subscribe(self.__fileno,
                                    self.__processConnection,
                                    POLL_EVENT_TYPE.READ | POLL_EVENT_TYPE.WRITE | POLL_EVENT_TYPE.ERROR)
        else:
            self.__state = CONNECTION_STATE.DISCONNECTED
            self.__fileno = None

    def setOnConnectedCallback(self, onConnected):
        """Replace the callback invoked when an outgoing connect completes."""
        self.__onConnected = onConnected

    def setOnMessageReceivedCallback(self, onMessageReceived):
        """Replace the callback invoked with each decoded incoming message."""
        self.__onMessageReceived = onMessageReceived

    def setOnDisconnectedCallback(self, onDisconnected):
        """Replace the callback invoked by ``disconnect()``."""
        self.__onDisconnected = onDisconnected

    def connect(self, host, port):
        """Start a non-blocking connect to ``(host, port)``.

        ``host`` must be a literal IPv4/IPv6 address (see ``_getAddrType``).
        Both buffers are cleared and the read timer is restarted.

        Returns:
            True if the connect was started (state becomes CONNECTING and the
            socket is subscribed to the poller); False if ``connect`` failed
            immediately with an error other than EINPROGRESS/EWOULDBLOCK.

        Raises:
            Exception: if ``host`` is not a recognised address literal.
        """
        self.__state = CONNECTION_STATE.DISCONNECTED
        self.__fileno = None
        self.__socket = socket.socket(_getAddrType(host), socket.SOCK_STREAM)
        self.__socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.__sendBufferSize)
        self.__socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.__recvBufferSize)
        self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.__socket.setblocking(0)
        self.__readBuffer = bytes()
        self.__writeBuffer = bytes()
        self.__lastReadTime = monotonicTime()

        try:
            self.__socket.connect((host, port))
        except socket.error as e:
            # EINPROGRESS / EWOULDBLOCK are the normal outcome of connect() on
            # a non-blocking socket; anything else is a real failure.
            if e.errno not in (socket.errno.EINPROGRESS, socket.errno.EWOULDBLOCK):
                # Do not leave the unusable socket open behind a DISCONNECTED
                # state: close it now instead of leaking the descriptor.
                self.__socket.close()
                self.__socket = None
                return False
        self.__fileno = self.__socket.fileno()
        self.__state = CONNECTION_STATE.CONNECTING
        self.__poller.subscribe(self.__fileno,
                                self.__processConnection,
                                POLL_EVENT_TYPE.READ | POLL_EVENT_TYPE.WRITE | POLL_EVENT_TYPE.ERROR)
        return True

    def send(self, message):
        """Serialize ``message``, queue it and try to write it immediately.

        The message is silently dropped when the connection is DISCONNECTED
        (there is no socket to write to, and ``connect()`` would discard the
        write buffer anyway). Socket errors while writing drop the connection
        via ``disconnect()`` rather than raising.
        """
        if self.__state == CONNECTION_STATE.DISCONNECTED:
            return
        if self.sendRandKey:
            message = (self.sendRandKey, message)
        data = zlib.compress(pickle.dumps(message), 3)
        if self.encryptor:
            data = self.encryptor.encrypt(data)
        # NOTE: 'i' uses the native byte order and size; it is kept as-is
        # because it defines the wire format shared with existing peers.
        data = struct.pack('i', len(data)) + data
        self.__writeBuffer += data
        # NOTE(review): if this write is only partial, the remainder is flushed
        # on the next send() or WRITE event; WRITE interest is not re-registered
        # here -- confirm that is acceptable for callers.
        self.__trySendBuffer()

    def fileno(self):
        """Return the subscribed file descriptor, or None when there is none."""
        return self.__fileno

    def disconnect(self):
        """Tear the connection down and reset it to the DISCONNECTED state.

        Calls ``onDisconnected`` first (only if the state was not already
        DISCONNECTED), then clears the keys/encryptor, closes the socket,
        unsubscribes from the poller and empties both buffers.
        """
        # NOTE: the callback runs before the state changes, so a callback that
        # calls disconnect() again would recurse; order kept for compatibility.
        if self.__onDisconnected is not None and self.__state != CONNECTION_STATE.DISCONNECTED:
            self.__onDisconnected()
        self.sendRandKey = None
        self.recvRandKey = None
        self.encryptor = None
        if self.__socket is not None:
            self.__socket.close()
            self.__socket = None
        if self.__fileno is not None:
            self.__poller.unsubscribe(self.__fileno)
            self.__fileno = None
        self.__writeBuffer = bytes()
        self.__readBuffer = bytes()
        self.__state = CONNECTION_STATE.DISCONNECTED

    def getSendBufferSize(self):
        """Return the number of bytes queued but not yet written to the socket."""
        return len(self.__writeBuffer)

    def __processConnection(self, descr, eventType):
        """Poller callback: handle ``eventType`` reported for descriptor ``descr``."""
        poller = self.__poller
        # Event for a descriptor this connection no longer owns: drop the
        # stale subscription.
        if descr != self.__fileno:
            poller.unsubscribe(descr)
            return

        if eventType & POLL_EVENT_TYPE.ERROR:
            self.disconnect()
            return

        if monotonicTime() - self.__lastReadTime > self.__timeout:
            self.disconnect()
            return

        if eventType & POLL_EVENT_TYPE.READ or eventType & POLL_EVENT_TYPE.WRITE:
            if self.__socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR):
                self.disconnect()
                return

            if self.__state == CONNECTION_STATE.CONNECTING:
                # First event without a pending socket error: the connect
                # has completed.
                if self.__onConnected is not None:
                    self.__onConnected()
                # The callback may have called disconnect(); do not mark a
                # torn-down connection (socket is None) as CONNECTED.
                if self.__state == CONNECTION_STATE.DISCONNECTED:
                    return
                self.__state = CONNECTION_STATE.CONNECTED
                self.__lastReadTime = monotonicTime()
                return

        if eventType & POLL_EVENT_TYPE.WRITE:
            self.__trySendBuffer()
            if self.__state == CONNECTION_STATE.DISCONNECTED:
                return
            # Keep WRITE interest only while there is still data to flush.
            event = POLL_EVENT_TYPE.READ | POLL_EVENT_TYPE.ERROR
            if len(self.__writeBuffer) > 0:
                event |= POLL_EVENT_TYPE.WRITE
            poller.subscribe(descr, self.__processConnection, event)

        if eventType & POLL_EVENT_TYPE.READ:
            self.__tryReadBuffer()
            if self.__state == CONNECTION_STATE.DISCONNECTED:
                return

            while True:
                message = self.__processParseMessage()
                if message is None:
                    break
                if self.__onMessageReceived is not None:
                    self.__onMessageReceived(message)
                # The callback may have disconnected us.
                if self.__state == CONNECTION_STATE.DISCONNECTED:
                    return

    def __trySendBuffer(self):
        """Write queued data until the buffer is empty or the socket stops accepting."""
        while self.__processSend():
            pass

    def __processSend(self):
        """Attempt one ``send``; return True if bytes were written and more may follow."""
        if not self.__writeBuffer:
            return False
        try:
            res = self.__socket.send(self.__writeBuffer)
            if res < 0:
                self.disconnect()
                return False
            if res == 0:
                return False
            self.__writeBuffer = self.__writeBuffer[res:]
            return True
        except socket.error as e:
            # EAGAIN / EWOULDBLOCK: the socket cannot take more right now.
            if e.errno not in (socket.errno.EAGAIN, socket.errno.EWOULDBLOCK):
                self.disconnect()
            return False

    def __tryReadBuffer(self):
        """Read everything currently available and restart the read timer."""
        while self.__processRead():
            pass
        self.__lastReadTime = monotonicTime()

    def __processRead(self):
        """Attempt one ``recv``; return True if data was appended to the read buffer."""
        try:
            incoming = self.__socket.recv(self.__recvBufferSize)
        except socket.error as e:
            # EAGAIN / EWOULDBLOCK: nothing more to read for now.
            if e.errno not in (socket.errno.EAGAIN, socket.errno.EWOULDBLOCK):
                self.disconnect()
            return False
        if self.__socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR):
            self.disconnect()
            return False
        # An empty read means the peer closed the connection.
        if not incoming:
            self.disconnect()
            return False
        self.__readBuffer += incoming
        return True

    def __processParseMessage(self):
        """Decode and return the next complete message from the read buffer.

        Returns None when no complete message is buffered yet, and also when
        the frame is invalid (negative length, decrypt/decompress/unpickle
        failure, or random-key mismatch), in which case the connection is
        dropped via ``disconnect()``.
        """
        if len(self.__readBuffer) < 4:
            return None
        l = struct.unpack('i', self.__readBuffer[:4])[0]
        # A negative length can never be a valid frame; reject it instead of
        # slicing the buffer with a negative bound.
        if l < 0:
            self.disconnect()
            return None
        if len(self.__readBuffer) - 4 < l:
            return None
        data = self.__readBuffer[4:4 + l]
        valid = True
        try:
            if self.encryptor:
                data = self.encryptor.decrypt(data)
            # SECURITY NOTE: pickle.loads on data received from the network can
            # execute arbitrary code; this is only safe with trusted peers.
            message = pickle.loads(zlib.decompress(data))
            if self.recvRandKey:
                randKey, message = message
                # Explicit comparison rather than ``assert`` so the check is
                # still enforced when Python runs with -O.
                valid = randKey == self.recvRandKey
        except Exception:
            # Any decoding problem means the stream is corrupt or hostile.
            valid = False
        if not valid:
            self.disconnect()
            return None
        self.__readBuffer = self.__readBuffer[4 + l:]
        return message

    @property
    def state(self):
        """Current ``CONNECTION_STATE`` value of this connection."""
        return self.__state