""" Implements a custom protocol for sending and receiving line delineated messages. For blocking sockets, time-out is required to avoid DoS attacks when talking to a misbehaving or malicious third party. The benefit of this class is it makes communication with the P2P network easy to code without having to depend on threads and hence on mutexes (which are hard to use correctly.) In practice, a connection to a node on the P2P network would be done using the default options of this class and the connection would periodically be polled for replies. The processing of replies would automatically break once the socket indicated it would block and to prevent a malicious node from sending replies as fast as it could - there would be a max message limit per check period. Quirks: * send_line will block until the entire line has been sent even if the socket has been set to non-blocking to make things easier. If you need a non-blocking way to send a line: use send(). Note that you will have to check for the number of bytes sent and resend if needed just like the real send function. * connect has the same behaviour as above to make things simpler (so will block regardless of whether socket is in non-blocking mode or not.) If you want to bypass this behaviour you can always connect the socket outside this class and then pass it to set_socket. Otherwise, all functions in this class behave how you would expect them to (depending on whether you're using non-blocking mode or blocking mode.) It's assumed that all blocking operations have a timeout by default. This can't be disabled. Todo: test various functions under connection exit. Timeouts are needed for non-blocking too under conditions where you attempt to send all / recv all. """ import errno import platform import socket import ssl import sys import time from pyp2p.lib import get_lan_ip, parse_exception, log_exception from pyp2p.lib import encode_str error_log_path = "error.log" class Sock: def __init__(self, addr=None, port=None, blocking=0, timeout=5, interface="default", use_ssl=0, debug=0): self.nonce = None self.nonce_buf = u"" self.reply_filter = None self.buf = b"" self.max_buf = 1024 * 1024 # 1 MB. self.max_chunks = 1024 # Prevents spamming of multiple short messages. self.chunk_size = 1024 * 4 self.replies = [] self.blocking = blocking self.timeout = timeout self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # self.s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.use_ssl = use_ssl self.alive = time.time() self.unl = None if self.use_ssl: self.s = ssl.wrap_socket(self.s) self.connected = 0 self.interface = interface self.delimiter = b"\r\n" self.debug = debug # Set keep alive. # self.set_keep_alive(self.s) # Connect socket. if addr is not None and port is not None: # Set a timeout for blocking operations so they don't DoS program. # Disabled after connect if non-blocking is set. # (Connect is so far always blocking regardless of blocking mode.) self.s.settimeout(5) self.connect(addr, port) else: self.set_blocking(self.blocking, self.timeout) def debug_print(self, msg): if self.debug: msg = "> " + str(msg) print(msg) def set_keep_alive(self, sock, after_idle_sec=5, interval_sec=60, max_fails=5): """ This function instructs the TCP socket to send a heart beat every n seconds to detect dead connections. It's the TCP equivalent of the IRC ping-pong protocol and allows for better cleanup / detection of dead TCP connections. It activates after 1 second (after_idle_sec) of idleness, then sends a keepalive ping once every 3 seconds(interval_sec), and closes the connection after 5 failed ping (max_fails), or 15 seconds """ # OSX if platform.system() == "Darwin": # scraped from /usr/include, not exported by python's socket module TCP_KEEPALIVE = 0x10 sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec) if platform.system() == "Windows": sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 10000, 3000)) if platform.system() == "Linux": sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, after_idle_sec) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval_sec) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, max_fails) def set_blocking(self, blocking, timeout=5): if self.s is None: return # Update blocking mode. self.s.setblocking(blocking) # Adjust timeout if needed. if blocking: if timeout is not None: self.s.settimeout(timeout) # Update blocking status. self.timeout = timeout self.blocking = blocking def set_sock(self, s): self.close() # Close old socket. self.s = s self.set_blocking(self.blocking, self.timeout) # self.s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # Set keep alive. # self.set_keep_alive(self.s) # Save addr + port. try: addr, port = self.s.getpeername() self.addr = addr self.port = port self.connected = 1 except: self.connected = 0 def reconnect(self): if not self.connected: if self.addr is not None and self.port is not None: try: return self.connect(self.addr, self.port) except: self.connected = 0 # Blocking (regardless of socket mode.) def connect(self, addr, port): # Save addr and port so socket can be reconnected. self.addr = addr self.port = port # No socket detected. if self.s is None: self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if self.use_ssl: self.s = ssl.wrap_socket(self.s) # Make connection from custom interface. if self.interface != "default": try: # Todo: fix this to use static ips from Net src_ip = get_lan_ip(self.interface) self.s.bind((src_ip, 0)) except socket.error as e: if e.errno != 98: raise e try: self.s.connect((addr, int(port))) self.connected = 1 self.set_blocking(self.blocking, self.timeout) except Exception as e: self.debug_print("Connect failed") error = parse_exception(e) self.debug_print(error) log_exception(error_log_path, error) raise socket.error("Socket connect failed.") def close(self): self.connected = 0 # Attempt graceful shutdown. try: try: self.s.shutdown(1) except: pass self.s.close() except: pass self.s = None def parse_buf(self, encoding="unicode"): """ Since TCP is a stream-orientated protocol, responses aren't guaranteed to be complete when they arrive. The buffer stores all the data and this function splits the data into replies based on the new line delimiter. """ buf_len = len(self.buf) replies = [] reply = b"" chop = 0 skip = 0 i = 0 buf_len = len(self.buf) for i in range(0, buf_len): ch = self.buf[i:i + 1] if skip: skip -= 1 i += 1 continue nxt = i + 1 if nxt < buf_len: if ch == b"\r" and self.buf[nxt:nxt + 1] == b"\n": # Append new reply. if reply != b"": if encoding == "unicode": replies.append(encode_str(reply, encoding)) else: replies.append(reply) reply = b"" # Truncate the whole buf if chop is out of bounds. chop = nxt + 1 skip = 1 i += 1 continue reply += ch i += 1 # Truncate buf. if chop: self.buf = self.buf[chop:] return replies # Blocking or non-blocking. def get_chunks(self, fixed_limit=None, encoding="unicode"): """ This is the function which handles retrieving new data chunks. It's main logic is avoiding a recv call blocking forever and halting the program flow. To do this, it manages errors and keeps an eye on the buffer to avoid overflows and DoS attacks. http://stackoverflow.com/questions/16745409/what-does-pythons-socket-recv-return-for-non-blocking-sockets-if-no-data-is-r http://stackoverflow.com/questions/3187565/select-and-ssl-in-python """ # Socket is disconnected. if not self.connected: return # Recv chunks until network buffer is empty. repeat = 1 wait = 0.2 chunk_no = 0 max_buf = self.max_buf max_chunks = self.max_chunks if fixed_limit is not None: max_buf = fixed_limit max_chunks = fixed_limit while repeat: chunk_size = self.chunk_size while True: # Don't exceed buffer size. buf_len = len(self.buf) if buf_len >= max_buf: break remaining = max_buf - buf_len if remaining < chunk_size: chunk_size = remaining # Don't allow non-blocking sockets to be # DoSed by multiple small replies. if chunk_no >= max_chunks and not self.blocking: break try: chunk = self.s.recv(chunk_size) except socket.timeout as e: self.debug_print("Get chunks timed out.") self.debug_print(e) # Timeout on blocking sockets. err = e.args[0] self.debug_print(err) if err == "timed out": repeat = 0 break except ssl.SSLError as e: # Will block on non-blocking SSL sockets. if e.errno == ssl.SSL_ERROR_WANT_READ: self.debug_print("SSL_ERROR_WANT_READ") break else: self.debug_print("Get chunks ssl error") self.close() return except socket.error as e: # Will block on nonblocking non-SSL sockets. err = e.args[0] if err == errno.EAGAIN or err == errno.EWOULDBLOCK: break else: # Connection closed or other problem. self.debug_print("get chunks other closing") self.close() return else: if chunk == b"": self.close() return # Avoid decoding errors. self.buf += chunk # Otherwise the loop will be endless. if self.blocking: break # Used to avoid DoS of small packets. chunk_no += 1 # Repeat is already set -- manual skip. if not repeat: break else: repeat = 0 # Block until there's a full reply or there's a timeout. if self.blocking: if fixed_limit is None: # Partial response. if self.delimiter not in self.buf: repeat = 1 time.sleep(wait) def reply_callback(self, callback): self.reply_callback = callback # Called to check for replies and update buffers. def update(self): self.get_chunks() self.replies += self.parse_buf() # Execute callbacks on replies. if self.reply_filter is not None: replies = [] for reply in self.replies: if not self.reply_filter(reply): replies.append(u"") else: replies.append(reply) self.replies = replies # Blocking or non-blocking. def send(self, msg, send_all=0, timeout=5, encoding="ascii"): # Update timeout. if timeout != self.timeout and self.blocking: self.set_blocking(self.blocking, timeout) try: # Convert to bytes Python 2 & 3 # The caller should ensure correct encoding. if type(msg) == type(u""): msg = encode_str(msg, "ascii") # Work out stop time. if send_all: future = time.time() + (timeout or self.timeout) else: future = 0 repeat = 1 total_sent = 0 msg_len = len(msg) while repeat: repeat = 0 while True: # Attempt to send all. # This won't work if the network buffer is already full. try: bytes_sent = self.s.send( msg[total_sent:self.chunk_size]) except socket.timeout as e: err = e.args[0] if err == "timed out": return 0 except socket.error as e: err = e.args[0] if err == errno.EAGAIN or err == errno.EWOULDBLOCK: break else: # Connection closed or other problem. self.debug_print("Con send closing other") self.close() return 0 # Connection broken. if not bytes_sent or bytes_sent is None: self.close() return 0 # How much has been sent? total_sent += bytes_sent # Avoid looping forever. if self.blocking and not send_all: break # Everything sent. if total_sent >= msg_len: break # Don't block. if not send_all: break # Avoid 100% CPU. time.sleep(0.001) # Avoid looping forever. if send_all: if time.time() >= future: repeat = 0 break # Send the rest if blocking: if total_sent < msg_len and send_all: repeat = 1 return total_sent except Exception as e: self.debug_print("Con send: " + str(e)) error = parse_exception(e) log_exception(error_log_path, error) self.close() # Blocking or non-blocking. def recv(self, n, encoding="unicode", timeout=5): # Sanity checking. assert n # Update timeout. if timeout != self.timeout and self.blocking: self.set_blocking(self.blocking, timeout) try: # Get data. self.get_chunks(n, encoding=encoding) # Return the current buffer. ret = self.buf # Reset the old buffer. self.buf = b"" # Return results. if encoding == "unicode": ret = encode_str(ret, encoding) return ret except Exception as e: self.debug_print("Recv closign e" + str(e)) error = parse_exception(e) log_exception(error_log_path, error) self.close() if encoding == "unicode": return u"" else: return b"" # Sends a new message delimitered by a new line. # Blocking: blocks until entire line is sent for simplicity. def send_line(self, msg, timeout=5): # Sanity checking. assert (len(msg)) # Not connected. if not self.connected: return 0 # Update timeout. if timeout != self.timeout and self.blocking: self.set_blocking(self.blocking, timeout) try: # Convert to bytes Python 2 & 3 if type(msg) == type(u""): msg = encode_str(msg, "ascii") # Convert delimiter to bytes. msg += self.delimiter """ The inclusion of the send_all flag makes this function behave like a blocking socket for the purposes of sending a full line even if the socket is non-blocking. It's assumed that lines will be small and if the network buffer is full this code won't end up as a bottleneck. (Otherwise you would have to check the number of bytes returned every time you sent a line which is quite annoying.) """ ret = self.send(msg, send_all=1, timeout=timeout) return ret except Exception as e: self.debug_print("Send line closing" + str(e)) error = parse_exception(e) log_exception(error_log_path, error) self.close() return 0 # Receives a new message delimited by a new line. # Blocking or non-blocking. def recv_line(self, timeout=5): # Socket is disconnected. if not self.connected: return u"" # Update timeout. if timeout != self.timeout and self.blocking: self.set_blocking(self.blocking, timeout) # Return existing reply. if len(self.replies): temp = self.replies[0] self.replies = self.replies[1:] return temp try: future = time.time() + (timeout or self.timeout) while True: self.update() # Socket is disconnected. if not self.connected: return u"" # Non-blocking. if not ((not len(self.replies) or len( self.buf) >= self.max_buf) and self.blocking): break # Timeout elapsed. if time.time() >= future and self.blocking: break # Avoid 100% CPU. time.sleep(0.002) if len(self.replies): temp = self.replies[0] self.replies = self.replies[1:] return temp return u"" except Exception as e: self.debug_print("recv line error") error = parse_exception(e) self.debug_print(error) log_exception(error_log_path, error) """ These functions here make the class behave like a list. The list is a collection of replies received from the socket. Every iteration also has the bonus of checking for any new replies so it is very easy, for example to do: for replies in sock: To process replies without handling networking boilerplate. """ def __len__(self): self.update() return len(self.replies) def __getitem__(self, key): self.update() return self.replies[key] def __setitem__(self, key, value): self.update() self.replies[key] = value def __delitem__(self, key): self.update() del self.replies[key] def pop_reply(self): # Get replies. replies = [] for reply in self.replies: replies.append(reply) if len(replies): # Put replies back in the queue. self.replies = replies[1:] # Return the first reply. return replies[0] else: return None def __iter__(self): try: # Get replies. self.update() # Return replies. return iter(self.replies) finally: # Clear old replies. self.replies = [] def __reversed__(self): return self.__iter__() if __name__ == "__main__": """ s = Sock("158.69.201.105", 8540) exit() s.send_line("SOURCE TCP") while 1: for reply in s: print(reply) time.sleep(0.5) # print(s.recv_line()) # print("yes") """