import os
import itertools
import numpy as np
import tensorflow as tf
from utils import *
from collections import Counter
from nltk.tokenize import TreebankWordTokenizer

# Marker inserted in place of every newline so that line boundaries survive
# whitespace tokenization and can later be used to split the corpus back
# into per-line sequences.
EOS_TOKEN = "_eos_"


class TextReader(object):
    """Word-level reader for a train.txt / valid.txt / test.txt corpus.

    If ``vocab.pkl`` is missing under ``data_path``, the vocabulary is built
    from ``train.txt``: words are ranked by descending frequency, with ties
    broken by the word itself. Each split is converted to a list of per-line
    word-id arrays, which is cached next to the text file as
    ``<split>.txt.npy``. If ``vocab.pkl`` exists, the vocabulary and the
    cached arrays are loaded instead. The save/load helpers presumably come
    from ``utils`` (star import); verify their formats there.

    Attributes:
        vocab: dict mapping word -> integer id.
        idx2word: inverse mapping, id -> word.
        vocab_size: number of entries in ``vocab``.
        train_data, valid_data, test_data: per-line word-id arrays, built
            here or as returned by ``load_npy``.
    """

    def __init__(self, data_path):
        """Load or build the vocabulary and id data found under ``data_path``."""
        train_path = os.path.join(data_path, "train.txt")
        valid_path = os.path.join(data_path, "valid.txt")
        test_path = os.path.join(data_path, "test.txt")
        vocab_path = os.path.join(data_path, "vocab.pkl")

        if os.path.exists(vocab_path):
            self._load(vocab_path, train_path, valid_path, test_path)
        else:
            self._build_vocab(train_path, vocab_path)
            self.train_data = self._file_to_data(train_path)
            self.valid_data = self._file_to_data(valid_path)
            self.test_data = self._file_to_data(test_path)

        self.idx2word = {v: k for k, v in self.vocab.items()}
        self.vocab_size = len(self.vocab)

    def _read_text(self, file_path):
        """Return the file contents with each newline replaced by " _eos_ "."""
        with open(file_path) as f:
            return f.read().replace("\n", " %s " % EOS_TOKEN)

    def _build_vocab(self, file_path, vocab_path):
        """Build ``self.vocab`` from ``file_path`` and persist it to ``vocab_path``.

        Ids follow descending frequency, with ties broken by the word itself,
        so the most frequent word gets id 0. An empty file yields an empty
        vocabulary.
        """
        counter = Counter(self._read_text(file_path).split())
        count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
        words = [word for word, _ in count_pairs]
        self.vocab = dict(zip(words, range(len(words))))
        save_pkl(vocab_path, self.vocab)

    def _file_to_data(self, file_path):
        """Convert ``file_path`` into one word-id array per line and cache it.

        The text is split on EOS_TOKEN, so the chunk after the final newline
        (usually blank) becomes an empty array. Words missing from the
        vocabulary map to ``None``, which yields an object-dtype array.
        """
        texts = self._read_text(file_path).split(EOS_TOKEN)
        # list(...) is required: on Python 3 ``map`` is a lazy iterator, and
        # np.array(iterator) would be a 0-d object array rather than ids.
        data = [
            np.array(list(map(self.vocab.get, text.split()))) for text in texts
        ]
        save_npy(file_path + ".npy", data)
        return data

    def _load(self, vocab_path, train_path, valid_path, test_path):
        """Load the pickled vocabulary and the cached ``*.txt.npy`` id data."""
        self.vocab = load_pkl(vocab_path)
        self.train_data = load_npy(train_path + ".npy")
        self.valid_data = load_npy(valid_path + ".npy")
        self.test_data = load_npy(test_path + ".npy")

    def get_data_from_type(self, data_type):
        """Return the split named ``data_type``: "train", "valid" or "test".

        Raises:
            Exception: if ``data_type`` is not one of the three split names.
        """
        if data_type == "train":
            raw_data = self.train_data
        elif data_type == "valid":
            raw_data = self.valid_data
        elif data_type == "test":
            raw_data = self.test_data
        else:
            raise Exception(" [!] Unknown data type: %s" % data_type)
        return raw_data

    def onehot(self, data, min_length=None):
        """Return the bag-of-words count vector for the word ids in ``data``.

        Despite the name, this is a count vector built with ``np.bincount``:
        entry ``i`` is the number of occurrences of id ``i``. The length is at
        least ``min_length``, which defaults to ``vocab_size``. Empty input
        yields an all-zero vector of length ``min_length``.
        """
        if min_length is None:
            min_length = self.vocab_size
        data = np.asarray(data)
        if data.size == 0:
            # An empty np.array([]) has float dtype, which np.bincount may
            # refuse to cast to integers; the answer is all zeros anyway.
            return np.zeros(min_length, dtype=np.intp)
        return np.bincount(data, minlength=min_length)

    def iterator(self, data_type="train"):
        """Return an endless iterator of ``[counts, ids]`` items for a split.

        Empty lines are skipped. ``itertools.cycle`` saves each item on the
        first pass and replays the saved items in order afterwards.
        """
        raw_data = self.get_data_from_type(data_type)
        # len() rather than ``data != []``: comparing an ndarray to a list is
        # elementwise. That drops one-word lines and can raise on a shape
        # mismatch with newer NumPy.
        return itertools.cycle(
            [self.onehot(data), data] for data in raw_data if len(data) > 0
        )

    def get(self, text=None):
        """Return ``(counts, ids)`` for ``text``.

        Args:
            text: a raw string, which is lower-cased and tokenized with
                TreebankWordTokenizer, or a sequence of already-tokenized
                words. Defaults to ``["medical"]``.

        Raises:
            Exception: if any word is not in the vocabulary; the message lists
                the unknown words.
        """
        if text is None:
            text = ["medical"]
        if isinstance(text, str):
            text = TreebankWordTokenizer().tokenize(text.lower())

        unknowns = [word for word in text if self.vocab.get(word) is None]
        if unknowns:
            raise Exception(" [!] unknown words: %s" % ",".join(unknowns))

        data = np.array([self.vocab[word] for word in text])
        return self.onehot(data), data

    def random(self, data_type="train"):
        """Return ``(counts, ids)`` for a uniformly chosen line of a split.

        The chosen line may be empty (e.g. the trailing chunk after the last
        newline); ``onehot`` then returns an all-zero vector.
        """
        raw_data = self.get_data_from_type(data_type)
        idx = np.random.randint(len(raw_data))
        data = raw_data[idx]
        return self.onehot(data), data