import numpy as np
import os
import copy
from math import ceil
from collections import Counter


def resplit(train, facts, no_link_percent):
    """Randomly re-split the union of `train` and `facts` into two sets
    with the same sizes as the inputs.

    If `no_link_percent` is 0, the union is simply shuffled and cut.
    Otherwise, any triplet whose (head, tail) entity pair occurs more than
    once in the union (counting both directions) is demoted to the facts
    side with probability `no_link_percent`, and the split is rebalanced so
    the output sizes still match the input sizes exactly.

    Args:
        train: list of (relation, head, tail) triplets.
        facts: list of (relation, head, tail) triplets.
        no_link_percent: float in [0, 1].

    Returns:
        (new_train, new_facts) with len(new_train) == len(train) and
        len(new_facts) == len(facts).
    """
    num_train = len(train)
    num_facts = len(facts)
    combined = train + facts  # renamed from `all` to avoid shadowing the builtin
    if no_link_percent == 0.:
        np.random.shuffle(combined)
        new_train = combined[:num_train]
        new_facts = combined[num_train:]
    else:
        link_cntr = Counter()
        for tri in combined:
            link_cntr[(tri[1], tri[2])] += 1
        tmp_train = []
        tmp_facts = []
        for tri in combined:
            # Entity pairs linked by more than one triplet (in either
            # direction) may be hidden as facts; unique links stay in train.
            if link_cntr[(tri[1], tri[2])] + link_cntr[(tri[2], tri[1])] > 1:
                if np.random.random() < no_link_percent:
                    tmp_facts.append(tri)
                else:
                    tmp_train.append(tri)
            else:
                tmp_train.append(tri)
        # Rebalance so the output sizes equal the input sizes.
        if len(tmp_train) > num_train:
            np.random.shuffle(tmp_train)
            new_train = tmp_train[:num_train]
            new_facts = tmp_train[num_train:] + tmp_facts
        else:
            np.random.shuffle(tmp_facts)
            num_to_fill = num_train - len(tmp_train)
            new_train = tmp_train + tmp_facts[:num_to_fill]
            new_facts = tmp_facts[num_to_fill:]
    assert len(new_train) == num_train
    assert len(new_facts) == num_facts
    return new_train, new_facts


class Data(object):
    """Loader and batcher for knowledge-base triplet datasets.

    Reads relation/entity vocabularies plus train/valid/test triplet files
    from `folder`, converts triplets to numeric (relation, head, tail)
    tuples, and builds per-relation sparse "matrix databases" of facts used
    as the reasoning background for each split.
    """

    def __init__(self, folder, seed, type_check, domain_size, no_extra_facts):
        np.random.seed(seed)
        self.seed = seed
        self.type_check = type_check
        self.domain_size = domain_size
        self.use_extra_facts = not no_extra_facts
        self.query_include_reverse = True

        self.relation_file = os.path.join(folder, "relations.txt")
        self.entity_file = os.path.join(folder, "entities.txt")
        self.relation_to_number, self.entity_to_number = self._numerical_encode()
        self.number_to_entity = {v: k for k, v in self.entity_to_number.items()}
        self.num_relation = len(self.relation_to_number)
        self.num_query = self.num_relation * 2  # each relation plus its inverse
        self.num_entity = len(self.entity_to_number)

        self.test_file = os.path.join(folder, "test.txt")
        self.train_file = os.path.join(folder, "train.txt")
        self.valid_file = os.path.join(folder, "valid.txt")
        # A single facts.txt means all splits share one background DB.
        if os.path.isfile(os.path.join(folder, "facts.txt")):
            self.facts_file = os.path.join(folder, "facts.txt")
            self.share_db = True
        else:
            self.train_facts_file = os.path.join(folder, "train_facts.txt")
            self.test_facts_file = os.path.join(folder, "test_facts.txt")
            self.share_db = False

        self.test, self.num_test = self._parse_triplets(self.test_file)
        self.train, self.num_train = self._parse_triplets(self.train_file)
        if os.path.isfile(self.valid_file):
            self.valid, self.num_valid = self._parse_triplets(self.valid_file)
        else:
            # No valid.txt: carve ~10% of train out as the validation split.
            self.valid, self.train = self._split_valid_from_train()
            self.num_valid = len(self.valid)
            self.num_train = len(self.train)

        if self.share_db:
            self.facts, self.num_fact = self._parse_triplets(self.facts_file)
            self.matrix_db = self._db_to_matrix_db(self.facts)
            self.matrix_db_train = self.matrix_db
            self.matrix_db_test = self.matrix_db
            self.matrix_db_valid = self.matrix_db
            if self.use_extra_facts:
                # Augment the shared DB with the train triplets themselves.
                extra_mdb = self._db_to_matrix_db(self.train)
                self.augmented_mdb = self._combine_two_mdbs(self.matrix_db,
                                                            extra_mdb)
                self.augmented_mdb_valid = self.augmented_mdb
                self.augmented_mdb_test = self.augmented_mdb
        else:
            self.train_facts, self.num_train_fact \
                = self._parse_triplets(self.train_facts_file)
            self.test_facts, self.num_test_fact \
                = self._parse_triplets(self.test_facts_file)
            self.matrix_db_train = self._db_to_matrix_db(self.train_facts)
            self.matrix_db_test = self._db_to_matrix_db(self.test_facts)
            # Valid uses the train facts; reuse the DB instead of rebuilding
            # an identical one from the same triplets.
            self.matrix_db_valid = self.matrix_db_train

        if self.type_check:
            self.domains_file = os.path.join(folder, "stats/domains.txt")
            self.domains = self._parse_domains_file(self.domains_file)
            # Sort by relation so each batch contains a single relation
            # (see _next_batch's major_relation logic).
            self.train = sorted(self.train, key=lambda x: x[0])
            self.test = sorted(self.test, key=lambda x: x[0])
            self.valid = sorted(self.valid, key=lambda x: x[0])
            self.num_operator = 2 * self.domain_size
        else:
            self.domains = None
            self.num_operator = 2 * self.num_relation

        # Get rules for queries and their inverses appearing in train and
        # test. (zip(*x)[0] is not subscriptable on Python 3; use set
        # comprehensions over the first tuple element instead.)
        self.query_for_rules = list(
            {t[0] for t in self.train}
            | {t[0] for t in self.test}
            | {t[0] for t in self._augment_with_reverse(self.train)}
            | {t[0] for t in self._augment_with_reverse(self.test)})
        self.parser = self._create_parser()

    def _create_parser(self):
        """Create a parser that maps numbers to queries and operators given
        queries."""
        assert self.num_query == 2 * len(self.relation_to_number) \
            == 2 * self.num_relation
        parser = {"query": {}, "operator": {}}
        number_to_relation = {value: key
                              for key, value in self.relation_to_number.items()}
        for key, value in self.relation_to_number.items():
            parser["query"][value] = key
            parser["query"][value + self.num_relation] = "inv_" + key
        for query in range(self.num_relation):
            d = {}
            if self.type_check:
                # Operators are restricted to the relations in this query's
                # domain; the second half of the id space is their inverses.
                for i, o in enumerate(self.domains[query]):
                    d[i] = number_to_relation[o]
                    d[i + self.domain_size] = "inv_" + number_to_relation[o]
            else:
                for k, v in number_to_relation.items():
                    d[k] = v
                    d[k + self.num_relation] = "inv_" + v
            # A query and its inverse share the same operator vocabulary.
            parser["operator"][query] = d
            parser["operator"][query + self.num_relation] = d
        return parser

    def _parse_domains_file(self, file_name):
        """Parse per-relation domains; pad or default with random relations
        so every relation maps to exactly `domain_size` distinct relations."""
        result = {}
        with open(file_name, "r") as f:
            for line in f:
                l = line.strip().split(",")
                l = [self.relation_to_number[i] for i in l]
                relation = l[0]
                this_domain = l[1:1 + self.domain_size]
                if len(this_domain) < self.domain_size:
                    # Fill in blanks with random relations not yet in the domain.
                    num_remain = self.domain_size - len(this_domain)
                    remains = [i for i in range(self.num_relation)
                               if i not in this_domain]
                    pads = np.random.choice(remains, num_remain, replace=False)
                    this_domain += list(pads)
                this_domain.sort()
                assert len(set(this_domain)) == self.domain_size
                assert len(this_domain) == self.domain_size
                result[relation] = this_domain
        for r in range(self.num_relation):
            if r not in result:
                result[r] = np.random.choice(range(self.num_relation),
                                             self.domain_size,
                                             replace=False)
        return result

    def _numerical_encode(self):
        """Read the vocab files and map each relation/entity name to an id."""
        relation_to_number = {}
        with open(self.relation_file) as f:
            for line in f:
                l = line.strip().split()
                assert len(l) == 1
                relation_to_number[l[0]] = len(relation_to_number)
        entity_to_number = {}
        with open(self.entity_file) as f:
            for line in f:
                l = line.strip().split()
                assert len(l) == 1
                entity_to_number[l[0]] = len(entity_to_number)
        return relation_to_number, entity_to_number

    def _parse_triplets(self, file_name):
        """Convert (head, relation, tail) to (relation, head, tail)."""
        output = []
        with open(file_name) as f:
            for line in f:
                l = line.strip().split("\t")
                assert len(l) == 3
                output.append((self.relation_to_number[l[1]],
                               self.entity_to_number[l[0]],
                               self.entity_to_number[l[2]]))
        return output, len(output)

    def _split_valid_from_train(self):
        """Hold out ~10% of train triplets as a validation split."""
        valid = []
        new_train = []
        for fact in self.train:
            dice = np.random.uniform()
            if dice < 0.1:
                valid.append(fact)
            else:
                new_train.append(fact)
        np.random.shuffle(new_train)
        return valid, new_train

    def _db_to_matrix_db(self, db):
        """Convert a triplet list to a per-relation sparse-matrix layout:
        {relation: ([[head, tail], ...], [value, ...], [num_entity, num_entity])}.
        Each relation starts with a dummy ([0, 0], 0.) entry."""
        matrix_db = {r: ([[0, 0]], [0.], [self.num_entity, self.num_entity])
                     for r in range(self.num_relation)}
        for fact in db:
            rel, head, tail = fact[0], fact[1], fact[2]
            matrix_db[rel][0].append([head, tail])
            matrix_db[rel][1].append(1.)
        return matrix_db

    def _combine_two_mdbs(self, mdbA, mdbB):
        """Assume mdbA and mdbB contain distinct elements."""
        new_mdb = {}
        for key, value in mdbA.items():
            new_mdb[key] = value
        for key, value in mdbB.items():
            try:
                value_A = mdbA[key]
                # Concatenate indices and values; shape is shared.
                new_mdb[key] = [value_A[0] + value[0],
                                value_A[1] + value[1],
                                value_A[2]]
            except KeyError:
                new_mdb[key] = value
        return new_mdb

    def _count_batch(self, samples, batch_size):
        """Number of batches when batches never mix relations (type_check)."""
        relations = [sample[0] for sample in samples]
        relations_counts = Counter(relations)
        num_batches = [ceil(1. * x / batch_size)
                       for x in relations_counts.values()]
        return int(sum(num_batches))

    def reset(self, batch_size):
        """Reset batching cursors and (re)compute batch counts per split."""
        self.batch_size = batch_size
        self.train_start = 0
        self.valid_start = 0
        self.test_start = 0
        if not self.type_check:
            # Floor division: `/` here was Python-2 integer division and
            # silently becomes a float on Python 3.
            self.num_batch_train = self.num_train // batch_size + 1
            self.num_batch_valid = self.num_valid // batch_size + 1
            self.num_batch_test = self.num_test // batch_size + 1
        else:
            self.num_batch_train = self._count_batch(self.train, batch_size)
            self.num_batch_valid = self._count_batch(self.valid, batch_size)
            self.num_batch_test = self._count_batch(self.test, batch_size)

    def train_resplit(self, no_link_percent):
        """Re-split train/facts (shared-DB mode) and rebuild the train DB."""
        new_train, new_facts = resplit(self.train, self.facts, no_link_percent)
        self.train = new_train
        self.matrix_db_train = self._db_to_matrix_db(new_facts)

    #########################################################################

    def _subset_of_matrix_db(self, matrix_db, domain):
        """Re-index matrix_db so that domain[i] becomes key i."""
        subset_matrix_db = {}
        for i, r in enumerate(domain):
            subset_matrix_db[i] = matrix_db[r]
        return subset_matrix_db

    def _augment_with_reverse(self, triplets):
        """Interleave each triplet with its inverse-query counterpart
        (relation shifted by num_relation, head/tail swapped)."""
        augmented = []
        for triplet in triplets:
            augmented += [triplet,
                          (triplet[0] + self.num_relation,
                           triplet[2],
                           triplet[1])]
        return augmented

    def _next_batch(self, start, size, samples):
        """Return (next_start, batch, batch_ids) for the cursor at `start`.

        With type_check, the batch is truncated at the first relation change
        (samples are assumed sorted by relation). The cursor wraps to 0 at
        the end of the epoch.
        """
        assert start < size
        end = min(start + self.batch_size, size)
        if self.type_check:
            this_batch_tmp = samples[start:end]
            major_relation = this_batch_tmp[0][0]  # assume sorted by relations
            batch_size = next(
                (i for i in range(len(this_batch_tmp))
                 if this_batch_tmp[i][0] != major_relation),
                len(this_batch_tmp))
            end = start + batch_size
            assert end <= size
        next_start = end % size
        this_batch = samples[start:end]
        if self.query_include_reverse:
            this_batch = self._augment_with_reverse(this_batch)
        this_batch_id = range(start, end)
        return next_start, this_batch, this_batch_id

    def _triplet_to_feed(self, triplets):
        """Unzip triplets into parallel (queries, heads, tails) tuples."""
        queries, heads, tails = zip(*triplets)
        return queries, heads, tails

    def next_test(self):
        """Next test batch: ((queries, heads, tails), matrix_db)."""
        self.test_start, this_batch, _ = self._next_batch(self.test_start,
                                                          self.num_test,
                                                          self.test)
        if self.share_db and self.use_extra_facts:
            matrix_db = self.augmented_mdb_test
        else:
            matrix_db = self.matrix_db_test
        if self.type_check:
            query = this_batch[0][0]
            matrix_db = self._subset_of_matrix_db(matrix_db,
                                                  self.domains[query])
        return self._triplet_to_feed(this_batch), matrix_db

    def next_valid(self):
        """Next validation batch: ((queries, heads, tails), matrix_db)."""
        self.valid_start, this_batch, _ = self._next_batch(self.valid_start,
                                                           self.num_valid,
                                                           self.valid)
        if self.share_db and self.use_extra_facts:
            matrix_db = self.augmented_mdb_valid
        else:
            matrix_db = self.matrix_db_valid
        if self.type_check:
            query = this_batch[0][0]
            matrix_db = self._subset_of_matrix_db(matrix_db,
                                                  self.domains[query])
        return self._triplet_to_feed(this_batch), matrix_db

    def next_train(self):
        """Next train batch. With extra facts, the batch's own triplets are
        excluded from the augmented DB to avoid answer leakage."""
        self.train_start, this_batch, this_batch_id = self._next_batch(
            self.train_start, self.num_train, self.train)
        if self.share_db and self.use_extra_facts:
            extra_facts = [fact for i, fact in enumerate(self.train)
                           if i not in this_batch_id]
            extra_mdb = self._db_to_matrix_db(extra_facts)
            matrix_db = self._combine_two_mdbs(extra_mdb, self.matrix_db_train)
        else:
            matrix_db = self.matrix_db_train
        if self.type_check:
            query = this_batch[0][0]
            matrix_db = self._subset_of_matrix_db(matrix_db,
                                                  self.domains[query])
        return self._triplet_to_feed(this_batch), matrix_db


class DataPlus(Data):
    """Data variant where queries are word sequences (natural-language-ish)
    rather than single relations; examples are ([query_word_ids], head, tail)."""

    def __init__(self, folder, seed):
        np.random.seed(seed)
        self.seed = seed
        self.kb_relation_file = os.path.join(folder, "kb_relations.txt")
        self.kb_entity_file = os.path.join(folder, "kb_entities.txt")
        self.query_vocab_file = os.path.join(folder, "query_vocabs.txt")
        self.kb_relation_to_number = self._numerical_encode(self.kb_relation_file)
        self.kb_entity_to_number = self._numerical_encode(self.kb_entity_file)
        self.query_vocab_to_number = self._numerical_encode(self.query_vocab_file)

        self.test_file = os.path.join(folder, "test.txt")
        self.train_file = os.path.join(folder, "train.txt")
        self.valid_file = os.path.join(folder, "valid.txt")
        self.facts_file = os.path.join(folder, "facts.txt")
        self.test, self.num_test = self._parse_examples(self.test_file)
        self.train, self.num_train = self._parse_examples(self.train_file)
        self.valid, self.num_valid = self._parse_examples(self.valid_file)
        self.facts, self.num_fact = self._parse_facts(self.facts_file)
        # Flattened set of all examples, used by is_true for membership tests.
        self.all_exams = set(tuple(q + [h, t])
                             for (q, h, t) in self.train + self.test + self.valid)

        self.num_word = len(self.test[0][0])  # assumes fixed-length queries
        self.num_vocab = len(self.query_vocab_to_number)
        self.num_relation = len(self.kb_relation_to_number)
        self.num_operator = 2 * self.num_relation
        self.num_entity = len(self.kb_entity_to_number)

        self.matrix_db = self._db_to_matrix_db(self.facts)
        self.matrix_db_train = self.matrix_db
        self.matrix_db_test = self.matrix_db
        self.matrix_db_valid = self.matrix_db

        self.type_check = False
        self.domain_size = None
        self.use_extra_facts = False
        self.query_include_reverse = False
        self.share_db = False

        self.parser = self._create_parser()
        self.query_for_rules = [list(q) for q in
                                set(tuple(q) for (q, _, _) in
                                    self.test + self.train)]

    def _numerical_encode(self, file_name):
        """Map each line of the file to its 0-based line number."""
        with open(file_name, "r") as f:
            lines = [l.strip() for l in f.readlines()]
        return {line: i for i, line in enumerate(lines)}

    def _parse_examples(self, file_name):
        """Parse (head \\t query-words \\t tail) lines into
        [[word_ids], head_id, tail_id] examples."""
        with open(file_name, "r") as f:
            lines = [l.strip().split("\t") for l in f.readlines()]
        triplets = [[[self.query_vocab_to_number[w] for w in l[1].split(",")],
                     self.kb_entity_to_number[l[0]],
                     self.kb_entity_to_number[l[2]]]
                    for l in lines]
        return triplets, len(triplets)

    def _parse_facts(self, file_name):
        """Parse (head \\t relation \\t tail) lines into
        [relation_id, head_id, tail_id] facts."""
        with open(file_name, "r") as f:
            lines = [l.strip().split("\t") for l in f.readlines()]
        facts = [[self.kb_relation_to_number[l[1]],
                  self.kb_entity_to_number[l[0]],
                  self.kb_entity_to_number[l[2]]]
                 for l in lines]
        return facts, len(facts)

    def _create_parser(self):
        """Parser mapping word-id sequences to query strings and operator
        ids to relation names ("inv_"-prefixed for inverses)."""
        parser = {"operator": {}}
        number_to_relation = {value: key
                              for key, value in self.kb_relation_to_number.items()}
        number_to_query_vocab = {value: key
                                 for key, value in self.query_vocab_to_number.items()}
        parser["query"] = lambda ws: ",".join(
            [number_to_query_vocab[w] for w in ws]) + " "
        d = {}
        for k, v in number_to_relation.items():
            d[k] = v
            d[k + self.num_relation] = "inv_" + v
        parser["operator"] = d
        return parser

    def is_true(self, q, h, t):
        """Whether (q, h, t) appears in train/valid/test."""
        return tuple(q + [h, t]) in self.all_exams