import six import random as rnd import token from tokenize import tokenize from io import BytesIO SKP_WORD = '<sk>' RIG_WORD = '<]>' LFT_WORD = '<[>' class SCode(object): def __init__(self, init): self.token_list = None self.type_list = None if init is not None: if isinstance(init, list): self.set_by_list(init, None) elif isinstance(init, tuple): self.set_by_list(init[0], init[1]) elif isinstance(init, six.string_types): self.set_by_str(init) else: raise NotImplementedError def set_by_str(self, f): tk_list = list( tokenize(BytesIO(f.strip().encode('utf-8')).readline))[1:-1] self.token_list = [tk.string for tk in tk_list] self.type_list = [token.tok_name[tk.type] for tk in tk_list] # well-tokenized token list def set_by_list(self, token_list, type_list): self.token_list = list(token_list) if type_list is not None: self.type_list = list(type_list) def to_list(self): return self.token_list def __str__(self): return ' '.join(self.to_list()) def layout(self, add_skip=False): assert len(self.token_list) == len(self.type_list) r_list = [] for tk, tp in zip(self.token_list, self.type_list): if tp in ('OP', 'KEYWORD'): r_list.append(tk) elif tp in ('STRING',): if add_skip: s_list = tk.split(' ') r_list.extend( [LFT_WORD] + [SKP_WORD for __ in range(len(s_list) - 2)] + [RIG_WORD]) else: r_list.append(tp) # elif tp in ('NAME', 'NUMBER'): # if add_skip: # r_list.append(SKP_WORD) # else: # r_list.append(tp) else: r_list.append(tp) return r_list def target(self): assert len(self.token_list) == len(self.type_list) r_list = [] for tk, tp in zip(self.token_list, self.type_list): if tp in ('STRING',): s_list = tk.split(' ') r_list.extend([LFT_WORD] + s_list[1:-1] + [RIG_WORD]) else: r_list.append(tk) return r_list def norm(self, not_layout=False): return self def is_code_eq(t1, t2, not_layout=False): if isinstance(t1, SCode): t1 = str(t1) else: t1 = ' '.join(t1) if isinstance(t2, SCode): t2 = str(t2) else: t2 = ' '.join(t2) t1 = ['\"' if it in (RIG_WORD, LFT_WORD) else it for it in t1.split(' ')] t2 = ['\"' if it in (RIG_WORD, LFT_WORD) else it for it in t2.split(' ')] if len(t1) == len(t2): for tk1,tk2 in zip(t1,t2): # if not (tk1 == tk2 or tk1 == '<unk>' or tk2 == '<unk>'): if tk1 != tk2: return False return True else: return False return t1==t2 if __name__ == '__main__': for s in ("if base64d [ : 1 ] == b' _STR:0_ ' :".split(), "if base64d [ : 1 ] == b' _STR:0_ ' :".split(), "compressed = zlib . compress ( data )".split(), "compressed = zlib.compress(data)".split(),): t = SCode(s) print(1, t) print(2, t.to_list()) print(3, ' '.join(t.layout(add_skip=False))) print('\n')