import torch import torch.nn as nn import torch.nn.init as init import numpy as np import re import const def get_non_pad_mask(seq): assert seq.dim() == 2 return seq.ne(const.PAD).type(torch.float).unsqueeze(-1) def get_padding_mask(x): return x.eq(0) def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): def cal_angle(position, hid_idx): return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) def get_posi_angle_vec(position): return [cal_angle(position, hid_j) for hid_j in range(d_hid)] sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) if padding_idx is not None: sinusoid_table[padding_idx] = 0. return torch.FloatTensor(sinusoid_table) def get_attn_key_pad_mask(seq_k, seq_q, byte=False): len_q = seq_q.size(1) padding_mask = seq_k.eq(const.PAD) padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) if byte: return padding_mask.byte() return padding_mask def get_subsequent_mask(seq): sz_b, len_s = seq.size() subsequent_mask = torch.triu( torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) return subsequent_mask def is_chinese_char(c): if ((c >= 0x4E00 and c <= 0x9FFF) or (c >= 0x3400 and c <= 0x4DBF) or (c >= 0x20000 and c <= 0x2A6DF) or (c >= 0x2A700 and c <= 0x2B73F) or (c >= 0x2B740 and c <= 0x2B81F) or (c >= 0x2B820 and c <= 0x2CEAF) or (c >= 0xF900 and c <= 0xFAFF) or (c >= 0x2F800 and c <= 0x2FA1F)): return True return False def split_char(text): text = "".join([w for w in text.split()]) step, words = 0, [] un_chinese = "" while step < len(text): if is_chinese_char(ord(text[step])): words.append(text[step]) step += 1 else: while step < len(text): if is_chinese_char(ord(text[step])): words.append(un_chinese.lower()) un_chinese = "" break un_chinese += text[step] step += 1 if un_chinese: return words + [un_chinese.lower()] return words def texts2idx(texts, word2idx): return [[word2idx[word] if word in word2idx else const.UNK for word in text] for text in texts] def find_index(text, word): stop_index = text.index(const.WORD[const.EOS]) if word in text[stop_index:]: idx = text.index(word, stop_index) else: idx = text.index(word) text[idx] = "@@@" return idx def find_text_index(q_words, new_tgt_words): word_map, q_words = {}, q_words.copy() t_index = np.zeros(len(new_tgt_words), dtype=int) for index, word in enumerate(new_tgt_words): if word in q_words: pointer = find_index(q_words, word) t_index[index] = pointer word_map[word] = pointer elif word in word_map: t_index[index] = word_map[word] else: raise Exception( f"invalid word {word} from {''.join(q_words)} {''.join(new_tgt_words)}") return t_index