""" helpers for deserializing Thrift messages """

from struct import unpack

from .thrift_struct import ThriftStruct
from .util import to_bytes

from thrift.Thrift import TMessageType
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TCompactProtocol import TCompactProtocol
from thrift.protocol.TJSONProtocol import TJSONProtocol
from thrift.transport import TTransport


class ThriftMessage(object):
    """ A deserialized Thrift message: envelope (method, type, seqid),
    arguments and an optional header that precedes the envelope. """

    def __init__(self, method, mtype, seqid, args, header=(), length=-1):
        """
        :param method: the method name read from the message envelope
        :param mtype: the message type (read() passes the string form
                      produced by message_type_to_str)
        :param seqid: the sequence id read from the message envelope
        :param args: a ThriftStruct instance with the message's fields
        :param header: the struct that preceded the message, if any
        :param length: size of the message in bytes, -1 when unknown
        :raises ValueError: if args is not a ThriftStruct instance
        """
        self._method = method
        self._type = mtype
        self._seqid = seqid
        if not isinstance(args, ThriftStruct):
            raise ValueError('args must be a ThriftStruct instance')
        self._args = args
        self._header = header  # finagle-thrift prepends this to each call
        self._length = length

    def __len__(self):
        """ in bytes

        Note: Python rejects negative lengths, so len() raises ValueError
        when the message was created without a length (the default is -1).
        Use bytes_length to read the raw value without that restriction.
        """
        return self._length

    @property
    def bytes_length(self):
        """
        for ThriftStruct, __len__ means something different so lets
        add this other property to unify the way to refer to bytes

        Reads the stored value directly (instead of going through len())
        so that an unknown length is reported as -1 rather than raising.
        """
        return self._length

    @property
    def method(self):
        """ the method name """
        return self._method

    @property
    def type(self):
        """ the message type, as given to the constructor """
        return self._type

    @property
    def seqid(self):
        """ the sequence id """
        return self._seqid

    @property
    def args(self):
        """ the ThriftStruct holding the message's fields """
        return self._args

    @property
    def header(self):
        """ the header given to the constructor """
        return self._header

    def __str__(self):
        return 'method=%s, type=%s, seqid=%s, header=%s, fields=%s' % (
            self.method, self.type, self.seqid, self.header, str(self.args))

    @property
    def as_dict(self):
        """ a dict with the message's attributes; 'length' is -1 when the
        size in bytes is unknown """
        return {
            'method': self.method,
            'type': self.type,
            'seqid': self.seqid,
            'header': self.header,
            'args': self.args.as_dict,
            # not len(self): that raises when the length is unknown (-1)
            'length': self._length,
        }

    MAX_METHOD_LENGTH = 70

    # For Binary, this is i32 + str + i32
    # For Compact, the empty ping() gets through in 8 bytes
    MIN_MESSAGE_SIZE = 8

    # some sane defaults to keep memory usage tight
    MAX_FIELDS = 1000
    MAX_LIST_SIZE = 1000000
    MAX_MAP_SIZE = 1000000
    MAX_SET_SIZE = 1000000

    @classmethod
    def read(cls,
             data,
             protocol=None,
             fallback_protocol=TBinaryProtocol,
             finagle_thrift=False,
             max_fields=MAX_FIELDS,
             max_list_size=MAX_LIST_SIZE,
             max_map_size=MAX_MAP_SIZE,
             max_set_size=MAX_SET_SIZE,
             read_values=False):
        """ tries to deserialize a message, might fail if data is missing

        :param data: the buffer holding the serialized message
        :param protocol: the protocol class to use; detected from data
                         (see detect_protocol) when None
        :param fallback_protocol: protocol class used when detection fails
        :param finagle_thrift: if True, first try to read a header struct
                               ahead of the message
        :param max_fields: forwarded to ThriftStruct.read
        :param max_list_size: forwarded to ThriftStruct.read
        :param max_map_size: forwarded to ThriftStruct.read
        :param max_set_size: forwarded to ThriftStruct.read
        :param read_values: forwarded to ThriftStruct.read
        :returns: a (ThriftMessage, consumed bytes) tuple
        :raises ValueError: if there isn't enough data, the method name is
                            missing, too long or contains characters
                            outside of the printable ASCII range, or the
                            message type is unknown
        """

        # do we have enough data?
        if len(data) < cls.MIN_MESSAGE_SIZE:
            raise ValueError('not enough data: %d' % len(data))

        if protocol is None:
            protocol = cls.detect_protocol(data, fallback_protocol)

        trans = TTransport.TMemoryBuffer(data)
        proto = protocol(trans)

        # finagle-thrift prepends a RequestHeader
        #
        # See: http://git.io/vsziG
        header = None
        if finagle_thrift:
            try:
                header = ThriftStruct.read(
                    proto,
                    max_fields,
                    max_list_size,
                    max_map_size,
                    max_set_size,
                    read_values)
            except Exception:
                # deliberately best-effort: reset stream, maybe it's not
                # finagle-thrift
                trans = TTransport.TMemoryBuffer(data)
                proto = protocol(trans)

        # unpack the message
        method, mtype, seqid = proto.readMessageBegin()
        mtype = cls.message_type_to_str(mtype)

        if not method or method.isspace() or method.startswith(' '):
            raise ValueError('no method name')

        if len(method) > cls.MAX_METHOD_LENGTH:
            raise ValueError('method name too long')

        # we might have made it until this point by mere chance, so filter out
        # suspicious method names (anything outside of printable ASCII)
        if any(not 33 <= ord(char) <= 126 for char in method):
            raise ValueError('invalid method name: %r' % method)

        args = ThriftStruct.read(
            proto,
            max_fields,
            max_list_size,
            max_map_size,
            max_set_size,
            read_values)

        proto.readMessageEnd()

        # Note: this is a bit fragile, the right thing would be to count bytes
        #       as we read them (i.e.: when calling readI32, etc).
        msglen = trans._buffer.tell()

        return cls(method, mtype, seqid, args, header, msglen), msglen

    @classmethod
    def detect_protocol(cls, data, default=None):
        """ TODO: support fbthrift, finagle-thrift, finagle-mux, CORBA

        :param data: the buffer holding the serialized message
        :param default: protocol class returned when none is recognized
        :returns: TCompactProtocol, TBinaryProtocol, TJSONProtocol or default
        :raises ValueError: if no protocol is recognized and default is None
        """
        if cls.is_compact_protocol(data):
            return TCompactProtocol
        elif cls.is_binary_protocol(data):
            return TBinaryProtocol
        elif cls.is_json_protocol(data):
            return TJSONProtocol

        if default is None:
            raise ValueError('Unknown protocol')

        return default

    COMPACT_PROTOCOL_ID = 0x82

    @classmethod
    def is_compact_protocol(cls, data):
        """ True if the first byte of data is COMPACT_PROTOCOL_ID """
        if len(data) < 1:
            return False
        result, = unpack('!B', data[:1])
        return result == cls.COMPACT_PROTOCOL_ID

    BINARY_PROTOCOL_VERSION_MASK = -65536  # 0xffff0000
    BINARY_PROTOCOL_VERSION_1 = -2147418112  # 0x80010000

    @classmethod
    def is_binary_protocol(cls, data):
        """ True if the first 4 bytes of data, read as a big-endian signed
        int and masked with BINARY_PROTOCOL_VERSION_MASK, equal
        BINARY_PROTOCOL_VERSION_1 """
        if len(data) < 4:
            return False
        val, = unpack('!i', data[0:4])
        if val >= 0:
            return False
        version = val & cls.BINARY_PROTOCOL_VERSION_MASK
        return version == cls.BINARY_PROTOCOL_VERSION_1

    @classmethod
    def is_json_protocol(cls, data):
        """ True if data starts with b'[1' """
        # FIXME: more elaborate parsing would make this more robust
        # Only look at the first two bytes: this avoids copying the whole
        # buffer and works for bytes/bytearray as well as memoryviews.
        prefix = data[:2]
        if hasattr(prefix, 'tobytes'):
            prefix = prefix.tobytes()
        return bytes(prefix) == b'[1'

    @staticmethod
    def message_type_to_str(mtype):
        """ maps a TMessageType value to its lowercase name

        :raises ValueError: if mtype isn't a known message type
        """
        if mtype == TMessageType.CALL:
            return 'call'
        elif mtype == TMessageType.REPLY:
            return 'reply'
        elif mtype == TMessageType.EXCEPTION:
            return 'exception'
        elif mtype == TMessageType.ONEWAY:
            return 'oneway'
        else:
            raise ValueError('Unknown message type: %s' % mtype)