""" some nlp process utilty functions """ import io import re import sys import time import logging import numpy as np from itertools import groupby logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') logger = logging.getLogger('nlp_toolkit') global special_tokens special_tokens = set(['s_', 'lan_', 'ss_']) # [1, ['a', 'b], [True, False]] ---> [1, 'a', 'b', True, False] def flatten_gen(x): for i in x: if isinstance(i, list) or isinstance(i, tuple): for inner_i in i: yield inner_i else: yield i # judge char type ['cn', 'en', 'num', 'other'] def char_type(word): for char in word: unicode_char = ord(char) if unicode_char >= 19968 and unicode_char <= 40869: yield (char, 'cn') elif unicode_char >= 65 and unicode_char <= 122: yield (char, 'en') elif unicode_char >= 48 and unicode_char <= 57: yield (char, 'num') else: yield (char, 'other') # split word into chars def split_cn_en(word): new_word = [c for c in char_type(word)] new_word_len = len(new_word) tmp = '' for ix, item in enumerate(new_word): if item[1] in {'en', 'num'}: if ix < new_word_len - 1: if new_word[ix+1][1] == item[1]: tmp += item[0] else: tmp += item[0] yield tmp tmp = '' else: tmp += item[0] yield tmp else: yield item[0] # reassign token labels according new tokens def extract_char(word_list, label_list=None, use_seg=False): if label_list: for word, label in zip(word_list, label_list): # label = label.strip('#') single_check = word in special_tokens or not re.search(r'[^a-z0-9]+', word) if len(word) == 1 or single_check: if use_seg: yield (word, label, 'S') else: yield (word, label) else: try: new_word = list(split_cn_en(word)) word_len = len(new_word) if label == 'O': new_label = ['O'] * word_len elif label.startswith('I'): new_label = [label] * word_len else: label_i = 'I' + label[1:] if label.startswith('B'): new_label = [label] + [label_i] * (word_len - 1) elif label.startswith('E'): new_label = [label_i] * (word_len - 1) + [label] if use_seg: seg_tag = ['M'] * word_len seg_tag[0] = 'B' seg_tag[-1] = 'E' for x, y, z in zip(new_word, new_label, seg_tag): yield (x, y, z) else: for x, y in zip(new_word, new_label): yield (x, y) except Exception as e: print(e) print(list(zip(word_list, label_list))) sys.exit() else: for word in word_list: single_check = word in special_tokens or not re.search(r'[^a-z0-9]+', word) if len(word) == 1 or single_check: if use_seg: yield (word, 'S') else: yield (word) else: new_word = list(split_cn_en(word)) if use_seg: seg_tag = ['M'] * len(new_word) seg_tag[0] = 'B' seg_tag[-1] = 'E' for x, y in zip(new_word, seg_tag): yield (x, y) else: for x in new_word: yield x # get radical token by chars def get_radical(d, char_list): return [d[char] if char in d else '<unk>' for char in char_list] def word2char(word_list, label_list=None, task_type='', use_seg=False, radical_dict=None): """ convert basic token from word to char non-chinese word will not be simply splitted into char sequences e.g. 
"machine02" will be splitted into "machine" and "02" """ if task_type == 'classification': assert label_list is None assert radical_dict is None assert use_seg is False return [char for word in word_list for char in list(split_cn_en(word))] elif task_type == 'sequence_labeling': results = list( zip(*[item for item in extract_char(word_list, label_list, use_seg)])) if label_list: if use_seg: chars, new_labels, seg_tags = results assert len(chars) == len(new_labels) == len(seg_tags) else: chars, new_labels = results assert len(chars) == len(new_labels) new_result = {'token': chars, 'label': new_labels} else: if use_seg: chars, seg_tags = results assert len(chars) == len(seg_tags) else: chars = results new_result = {'token': chars} if use_seg: new_result['seg'] = seg_tags if radical_dict: new_result['radical'] = get_radical(radical_dict, chars) return new_result else: logger.error('invalid task type') sys.exit() def shorten_word(word): """ Shorten groupings of 3+ identical consecutive chars to 2, e.g. '!!!!' --> '!!' """ # must have at least 3 char to be shortened if len(word) < 3: return word # find groups of 3+ consecutive letters letter_groups = [list(g) for k, g in groupby(word)] triple_or_more = [''.join(g) for g in letter_groups if len(g) >= 3] if len(triple_or_more) == 0: return word # replace letters to find the short word short_word = word for trip in triple_or_more: short_word = short_word.replace(trip, trip[0] * 2) return short_word # Command line arguments are cast to bool type def boolean_string(s): if s not in {'False', 'True'}: raise ValueError('Not a valid boolean string') return s == 'True' # decorator to time a function def timer(function): def log_time(): start_time = time.time() function() elapsed = time.time() - start_time logger.info('Function "{name}" finished in {time:.2f} s'.format(name=function.__name__, time=elapsed)) return log_time() # generate small embedding files according given vocabs def gen_small_embedding(vocab_file, embed_file, output_file): vocab = set([word.strip() for word in open(vocab_file, encoding='utf8')]) print('total vocab: ', len(vocab)) fin = io.open(embed_file, 'r', encoding='utf-8', newline='\n', errors='ignore') try: n, d = map(int, fin.readline().split()) except Exception: print('please make sure the embed file is gensim-formatted') def gen(): for line in fin: token = line.rstrip().split(' ', 1)[0] if token in vocab: yield line result = [line for line in gen()] rate = 1 - len(result) / len(vocab) print('oov rate: {:4.2f}%'.format(rate * 100)) with open(output_file, 'w', encoding='utf8') as fout: fout.write(str(len(result)) + ' ' + str(d) + '\n') for line in result: fout.write(line) # load embeddings from text file def load_vectors(fname, vocab): fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') _, d = map(int, fin.readline().split()) data = {} for line in fin: tokens = line.rstrip().split(' ') data[tokens[0]] = np.asarray(tokens[1:], dtype='float32') scale = 0.25 # scale = np.sqrt(3.0 / n_dim) embedding_matrix = np.random.uniform(-scale, scale, [len(vocab), d]) embedding_matrix[0] = np.zeros(d) cnt = 0 for word, i in vocab._token2id.items(): embedding_vector = data.get(word) if embedding_vector is not None: cnt += 1 embedding_matrix[i] = embedding_vector logger.info('OOV rate: {:04.2f} %'.format(1 - cnt / len(vocab._token2id))) return embedding_matrix, d def load_tc_data(fname, label_prefix='__label__', max_tokens_per_doc=256): def gen(): with open(fname, 'r', encoding='utf8') as fin: for line in fin: words = 


def load_tc_data(fname, label_prefix='__label__', max_tokens_per_doc=256):
    def gen():
        with open(fname, 'r', encoding='utf8') as fin:
            for line in fin:
                words = line.strip().split()
                if words:
                    nb_labels = 0
                    label_line = []
                    for word in words:
                        if word.startswith(label_prefix):
                            nb_labels += 1
                            label = word.replace(label_prefix, "")
                            label_line.append(label)
                        else:
                            break
                    text = words[nb_labels:]
                    if len(text) > max_tokens_per_doc:
                        text = text[:max_tokens_per_doc]
                    yield (text, label_line)

    texts, labels = zip(*[item for item in gen()])
    return texts, labels


def load_sl_data(fname, data_format='basic'):
    def process_conll(data):
        sents, labels = [], []
        tokens, tags = [], []
        for line in data:
            if line:
                token, tag = line.split('\t')
                tokens.append(token)
                tags.append(tag)
            else:
                sents.append(tokens)
                labels.append(tags)
                tokens, tags = [], []
        return sents, labels

    data = (line.strip() for line in open(fname, 'r', encoding='utf8'))
    if data_format == 'basic':
        texts, labels = zip(
            *[zip(*[item.rsplit('###', 1) for item in line.split('\t')])
              for line in data])
    elif data_format == 'conll':
        texts, labels = process_conll(data)
    else:
        logger.error('invalid data format for sequence labeling task')
        sys.exit()
    return texts, labels


def convert_seq_format(fin_name, fout_name, dest_format='conll'):
    if dest_format == 'conll':
        basic2conll(fin_name, fout_name)
    elif dest_format == 'basic':
        conll2basic(fin_name, fout_name)
    else:
        logger.warning('invalid data format')


def basic2conll(fin_name, fout_name):
    data = [line.strip() for line in open(fin_name, 'r', encoding='utf8')]
    with open(fout_name, 'w', encoding='utf8') as fout:
        for line in data:
            for item in line.split('\t'):
                # split on the last '###' so tokens containing '#' survive
                token, label = item.rsplit('###', 1)
                label = label.strip('#')
                fout.write(token + '\t' + label + '\n')
            fout.write('\n')


def conll2basic(fin_name, fout_name):
    data = [line.strip() for line in open(fin_name, 'r', encoding='utf8')]
    with open(fout_name, 'w', encoding='utf8') as fout:
        tmp = []
        for line in data:
            if line:
                token, label = line.split('\t')
                item = token + '###' + label
                tmp.append(item)
            else:
                new_line = '\t'.join(tmp) + '\n'
                fout.write(new_line)
                tmp = []
        # flush the last sentence if the file does not end with a blank line
        if tmp:
            fout.write('\t'.join(tmp) + '\n')
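

if __name__ == '__main__':
    # A small end-to-end sketch (added for illustration; the file names are
    # hypothetical): write one sentence in the 'basic' format, convert it
    # to CoNLL with basic2conll, then reload it with load_sl_data.
    import os
    import tempfile

    tmp_dir = tempfile.mkdtemp()
    basic_file = os.path.join(tmp_dir, 'demo.basic')
    conll_file = os.path.join(tmp_dir, 'demo.conll')
    with open(basic_file, 'w', encoding='utf8') as f:
        f.write('\t'.join(['北###B-LOC', '京###E-LOC', 'hello###O']) + '\n')
    basic2conll(basic_file, conll_file)
    texts, labels = load_sl_data(conll_file, data_format='conll')
    print(texts)   # [['北', '京', 'hello']]
    print(labels)  # [['B-LOC', 'E-LOC', 'O']]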