#!/usr/bin/env python
# -*- coding: utf-8 -*-

import fileinput
import functools
import multiprocessing
import os
import pickle
import re
import sys
import time
from collections import defaultdict
# queue.Empty is the exception raised by Manager().Queue().get_nowait()
from queue import Empty

import jellyfish
from breds.sentence import Sentence
from nltk import word_tokenize, bigrams
from nltk.corpus import stopwords
from whoosh import query
from whoosh.index import open_dir
from whoosh.query import spans

__author__ = "David S. Batista"
__email__ = "dsbatista@inesc-id.pt"

# relational words used in calculating the sets C and D with the proximity PMI
founded_unigrams = ['founder', 'co-founder', 'cofounder', 'co-founded',
                    'cofounded', 'founded', 'founders']
founded_bigrams = ['started by']

acquired_unigrams = ['owns', 'acquired', 'bought', 'acquisition']
acquired_bigrams = []

headquarters_unigrams = ['headquarters', 'headquartered', 'offices', 'office',
                         'building', 'buildings', 'factory', 'plant',
                         'compound']
headquarters_bigrams = ['based in', 'located in', 'main office',
                        'main offices', 'offices in', 'building in',
                        'office in', 'branch in', 'store in', 'firm in',
                        'factory in', 'plant in', 'head office',
                        'head offices', 'in central', 'in downtown',
                        'outskirts of', 'suburbs of']

employment_unigrams = ['chief', 'scientist', 'professor', 'biologist', 'ceo',
                       'CEO', 'employer']
employment_bigrams = []

bad_tokens = [",", "(", ")", ";", "''", "``", "'s", "-", "vs.", "v", "'", ":",
              ".", "--"]
stopwords_list = stopwords.words('english')
not_valid = bad_tokens + stopwords_list

# threshold for the proximity PMI
PMI = 0.7

# parameters for relationship extraction from Sentence
MAX_TOKENS_AWAY = 6
MIN_TOKENS_AWAY = 1
CONTEXT_WINDOW = 2

# DEBUG stuff
PRINT_NOT_FOUND = False

# stores all variations matched with the database
manager = multiprocessing.Manager()
all_in_database = manager.dict()


class ExtractedFact(object):

    def __init__(self, _e1, _e2, _score, _bef, _bet, _aft, _sentence,
                 _passive_voice):
        self.ent1 = _e1
        self.ent2 = _e2
        self.score = _score
        self.bef_words = _bef
        self.bet_words = _bet
        self.aft_words = _aft
        self.sentence = _sentence
        self.passive_voice = _passive_voice

    def __lt__(self, other):
        # Python 3 replacement for the original __cmp__: facts are ordered
        # by their confidence score, which is what sorted() relies on
        return self.score < other.score

    def __hash__(self):
        sig = hash(self.ent1) ^ hash(self.ent2) ^ hash(self.bef_words) ^ \
            hash(self.bet_words) ^ hash(self.aft_words) ^ \
            hash(self.score) ^ hash(self.sentence)
        return sig

    def __eq__(self, other):
        return self.ent1 == other.ent1 and \
            self.ent2 == other.ent2 and \
            self.score == other.score and \
            self.bef_words == other.bef_words and \
            self.bet_words == other.bet_words and \
            self.aft_words == other.aft_words and \
            self.sentence == other.sentence


# ###########################################
# Misc., Utils, parsing corpus into memory  #
# ###########################################

def timecall(f):
    @functools.wraps(f)
    def wrapper(*args, **kw):
        start = time.time()
        result = f(*args, **kw)
        end = time.time()
        print("Time taken: %.2f seconds" % (end - start))
        return result
    return wrapper


def is_acronym(entity):
    return len(entity.split()) == 1 and entity.isupper()


def process_corpus(queue, g_dash, e1_type, e2_type):
    count = 0
    added = 0
    while True:
        try:
            if count % 25000 == 0:
                print(multiprocessing.current_process(), "In Queue",
                      queue.qsize(), "Total added: ", added)
            line = queue.get_nowait()
            s = Sentence(line.strip(), e1_type, e2_type, MAX_TOKENS_AWAY,
                         MIN_TOKENS_AWAY, CONTEXT_WINDOW)
            for r in s.relationships:
                tokens = word_tokenize(r.between)
                if all(x in not_valid for x in tokens):
                    continue
                elif "," in tokens and tokens[0] != ',':
                    continue
                else:
                    g_dash.append(r)
                    added += 1
            count += 1
        except Empty:
            break


def process_output(data, threshold, rel_type):
    """
    Parses the file with the relationships extracted by the system;
    each relationship is transformed into an ExtractedFact instance.
    """
    system_output = list()
    for line in fileinput.input(data):

        if line.startswith('instance'):
            instance_parts, score = line.split("score:")
            e1, e2 = instance_parts.split("instance:")[1].strip().split('\t')

        if line.startswith('sentence'):
            sentence = line.split("sentence:")[1].strip()

        if line.startswith('pattern_bef:'):
            bef = line.split("pattern_bef:")[1].strip()

        if line.startswith('pattern_bet:'):
            bet = line.split("pattern_bet:")[1].strip()

        if line.startswith('pattern_aft:'):
            aft = line.split("pattern_aft:")[1].strip()

        if line.startswith('passive voice:'):
            tmp = line.split("passive voice:")[1].strip()
            if tmp == 'False':
                passive_voice = False
            elif tmp == 'True':
                passive_voice = True

        if line.startswith('\n') and float(score) >= threshold:
            if 'bef' not in locals():
                bef = ''
            if 'aft' not in locals():
                aft = ''
            if passive_voice is True and \
                    rel_type in ['acquired', 'headquarters']:
                r = ExtractedFact(e2, e1, float(score), bef, bet, aft,
                                  sentence, passive_voice)
            else:
                r = ExtractedFact(e1, e2, float(score), bef, bet, aft,
                                  sentence, passive_voice)

            if ("'s parent" in bet or 'subsidiary of' in bet or
                    bet == 'subsidiary') and rel_type == 'acquired':
                r = ExtractedFact(e2, e1, float(score), bef, bet, aft,
                                  sentence, passive_voice)

            system_output.append(r)

    fileinput.close()
    return system_output


def process_freebase(data, rel_type):
    # load relationships from Freebase and keep them in the same direction as
    # the output of the extraction system
    """
    # rel_type                  Gold standard directions
    founder_arg2_arg1           PER-ORG
    headquarters_arg1_arg2      ORG-LOC
    acquired_arg1_arg2          ORG-ORG
    contained_by_arg1_arg2      LOC-LOC
    """

    # store a tuple (entity1, entity2) in a dictionary
    database_1 = defaultdict(list)

    # store in a dictionary per relationship: dict['ent1'] = 'ent2'
    database_2 = defaultdict(list)

    # store in a dictionary per relationship: dict['ent2'] = 'ent1'
    database_3 = defaultdict(list)

    # regex used to clean entities
    numbered = re.compile(r'#[0-9]+$')

    # for the 'founder' relationships don't load these from Freebase, as it
    # lists countries (i.e., LOC entities) as founders and not persons
    founder_to_ignore = ['UNESCO', 'World Trade Organization',
                         'European Union', 'United Nations']

    for line in fileinput.input(data):
        if line.startswith('#'):
            continue
        try:
            e1, r, e2 = line.split('\t')
        except Exception:
            print(line)
            print(line.split('\t'))
            sys.exit()

        # ignore some entities which are Freebase identifiers or are ambiguous
        if e1.startswith('/') or e2.startswith('/'):
            continue
        if e1.startswith('m/') or e2.startswith('m/'):
            continue
        if re.search(numbered, e1) or re.search(numbered, e2):
            continue
        if e2.strip() in founder_to_ignore:
            continue
        else:
            if "(" in e1:
                e1 = re.sub(r"\(.*\)", "", e1).strip()
            if "(" in e2:
                e2 = re.sub(r"\(.*\)", "", e2).strip()

            if rel_type == 'founder' or rel_type == 'employer':
                database_1[(e2.strip(), e1.strip())].append(r)
                database_2[e2.strip()].append(e1.strip())
                database_3[e1.strip()].append(e2.strip())
            else:
                database_1[(e1.strip(), e2.strip())].append(r)
                database_2[e1.strip()].append(e2.strip())
                database_3[e2.strip()].append(e1.strip())

    return database_1, database_2, database_3


def load_acronyms(data):
    acronyms = defaultdict(list)
    for line in fileinput.input(data):
        parts = line.split('\t')
        acronym = parts[0].strip()
        if "/" in acronym:
            continue
        expanded = parts[-1].strip()
        if "/" in expanded:
            continue
        acronyms[acronym].append(expanded)
    fileinput.close()
    return acronyms


def load_dbpedia(data, database_1, database_2):
    for line in fileinput.input(data):
        e1, rel, e2, p = line.split()
        e1 = e1.split('<http://dbpedia.org/resource/')[1].replace(">", "")
        e2 = e2.split('<http://dbpedia.org/resource/')[1].replace(">", "")
        e1 = re.sub("_", " ", e1)
        e2 = re.sub("_", " ", e2)

        if "(" in e1 or "(" in e2:
            e1 = re.sub(r"\(.*\)", "", e1)
            e2 = re.sub(r"\(.*\)", "", e2)

            # store a tuple (entity1, entity2) in a dictionary
            database_1[(e1.strip(), e2.strip())].append(p)

            # store in a dictionary per relationship: dict['ent1'] = 'ent2'
            database_2[e1.strip()].append(e2.strip())

        else:
            e1 = e1.decode("utf8").strip()
            e2 = e2.decode("utf8").strip()

            # store a tuple (entity1, entity2) in a dictionary
            database_1[(e1, e2)].append(p)

            # store in a dictionary per relationship: dict['ent1'] = 'ent2'
            database_2[e1.strip()].append(e2.strip())

    fileinput.close()
    return database_1, database_2


def extract_bigrams(text):
    tokens = word_tokenize(text)
    return [gram[0] + ' ' + gram[1] for gram in bigrams(tokens)]


# ########################################
# Estimations of sets and intersections  #
# ########################################

@timecall
def calculate_a(not_in_database, e1_type, e2_type, index,
                rel_words_unigrams, rel_words_bigrams):
    m = multiprocessing.Manager()
    queue = m.Queue()
    num_cpus = multiprocessing.cpu_count()
    results = [m.list() for _ in range(num_cpus)]
    not_found = [m.list() for _ in range(num_cpus)]

    for r in not_in_database:
        queue.put(r)

    processes = [multiprocessing.Process(
        target=proximity_pmi_a,
        args=(e1_type, e2_type, queue, index, results[i], not_found[i],
              rel_words_unigrams, rel_words_bigrams))
        for i in range(num_cpus)]

    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    a = list()
    for l in results:
        a.extend(l)

    wrong = list()
    for l in not_found:
        wrong.extend(l)

    return a, wrong


@timecall
def calculate_b(output, database_1, database_2, database_3, e1_type, e2_type):
    # intersection between the system output and the database
    # it is assumed that every fact in this region is correct
    m = multiprocessing.Manager()
    queue = m.Queue()
    num_cpus = multiprocessing.cpu_count()
    results = [m.list() for _ in range(num_cpus)]
    no_matches = [m.list() for _ in range(num_cpus)]

    for r in output:
        queue.put(r)

    processes = [multiprocessing.Process(
        target=string_matching_parallel,
        args=(results[i], no_matches[i], database_1, database_2, database_3,
              queue, e1_type, e2_type))
        for i in range(num_cpus)]

    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    b = set()
    for l in results:
        b.update(l)

    not_found = set()
    for l in no_matches:
        not_found.update(l)

    return b, not_found


@timecall
def calculate_c(corpus, database_1, database_2, database_3, b, e1_type,
                e2_type, rel_type, rel_words_unigrams, rel_words_bigrams):
    # contains the database facts described in the corpus
    # but not extracted by the system
    #
    # G' = superset of G, cartesian product of all possible entities and
    # relations (i.e., G' = E x R x E)
    # for now, all relationships from a sentence
    print("Building G', a superset of G")
    m = multiprocessing.Manager()
    queue = m.Queue()
    g_dash = m.list()
    num_cpus = multiprocessing.cpu_count()

    # check if the superset G' for e1_type, e2_type already exists and
    # if G' minus KB for rel_type exists
    # if it exists, load it into g_dash_set
    if os.path.isfile("superset_" + e1_type + "_" + e2_type + ".pkl"):
        f = open("superset_" + e1_type + "_" + e2_type + ".pkl", "rb")
        print("\nLoading superset G'",
              "superset_" + e1_type + "_" + e2_type + ".pkl")
        g_dash_set = pickle.load(f)
        f.close()

    # else generate G' and G minus D
    else:
        with open(corpus) as f:
            data = f.readlines()
            count = 0
            print("Storing in shared Queue")
            for l in data:
                if count % 50000 == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()
                queue.put(l)
                count += 1
        print("\nQueue size:", queue.qsize())

        processes = [multiprocessing.Process(
            target=process_corpus,
            args=(queue, g_dash, e1_type, e2_type))
            for _ in range(num_cpus)]

        print("Extracting all possible " + e1_type + "," + e2_type +
              " relationships from the corpus")
        print("Running", len(processes), "threads")

        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()

        print(len(g_dash), "relationships built")
        g_dash_set = set(g_dash)
        print(len(g_dash_set), "unique relationships")
        print("Dumping into file",
              "superset_" + e1_type + "_" + e2_type + ".pkl")
        f = open("superset_" + e1_type + "_" + e2_type + ".pkl", "wb")
        pickle.dump(g_dash_set, f)
        f.close()

    # estimate G intersected with D: look for facts in G' that match a fact
    # in the database; check if it already exists for this relationship
    if os.path.isfile(rel_type + "_g_intersection_d.pkl") and \
            os.path.isfile(rel_type + "_g_minus_d.pkl"):
        f = open(rel_type + "_g_intersection_d.pkl", "rb")
        print("\nLoading G intersected with D",
              rel_type + "_g_intersection_d.pkl")
        g_intersect_d = pickle.load(f)
        f.close()

        f = open(rel_type + "_g_minus_d.pkl", "rb")
        print("\nLoading superset G' minus D", rel_type + "_g_minus_d.pkl")
        g_minus_d = pickle.load(f)
        f.close()

    else:
        print("Estimating G intersection with D")
        g_intersect_d = set()
        print("G':", len(g_dash_set))
        print("Database:", len(list(database_1.keys())))

        # facts not in the database, to use in estimating set d
        g_minus_d = set()

        queue = manager.Queue()
        results = [manager.list() for _ in range(num_cpus)]
        no_matches = [manager.list() for _ in range(num_cpus)]

        # load everything into a shared queue
        for r in g_dash_set:
            queue.put(r)

        processes = [multiprocessing.Process(
            target=string_matching_parallel,
            args=(results[i], no_matches[i], database_1, database_2,
                  database_3, queue, e1_type, e2_type))
            for i in range(num_cpus)]

        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()

        for l in results:
            g_intersect_d.update(l)
        for l in no_matches:
            g_minus_d.update(l)

        print("Extra filtering: from the intersection of G' with D, "
              "select only those based on keywords")
        print(len(g_intersect_d))
        filtered = set()
        for r in g_intersect_d:
            unigrams_bet = word_tokenize(r.between)
            unigrams_bef = word_tokenize(r.before)
            unigrams_aft = word_tokenize(r.after)
            bigrams_bet = extract_bigrams(r.between)
            if any(x in rel_words_unigrams for x in unigrams_bet):
                filtered.add(r)
                continue
            if any(x in rel_words_unigrams for x in unigrams_bef):
                filtered.add(r)
                continue
            if any(x in rel_words_unigrams for x in unigrams_aft):
                filtered.add(r)
                continue
            elif any(x in rel_words_bigrams for x in bigrams_bet):
                filtered.add(r)
                continue

        g_intersect_d = filtered
        print(len(g_intersect_d), "relationships in the corpus "
              "which are in the KB")
        if len(g_intersect_d) > 0:
            # dump G intersected with D to file
            f = open(rel_type + "_g_intersection_d.pkl", "wb")
            pickle.dump(g_intersect_d, f)
            f.close()

        print("Extra filtering: from the G' not in D, select only "
              "those based on keywords")
        filtered = set()
        for r in g_minus_d:
            unigrams_bet = word_tokenize(r.between)
            unigrams_bef = word_tokenize(r.before)
            unigrams_aft = word_tokenize(r.after)
            bigrams_bet = extract_bigrams(r.between)
            if any(x in rel_words_unigrams for x in unigrams_bet):
                filtered.add(r)
                continue
            if any(x in rel_words_unigrams for x in unigrams_bef):
                filtered.add(r)
                continue
            if any(x in rel_words_unigrams for x in unigrams_aft):
                filtered.add(r)
                continue
            elif any(x in rel_words_bigrams for x in bigrams_bet):
                filtered.add(r)
                continue

        g_minus_d = filtered
        print(len(g_minus_d), "relationships in the corpus not in the KB")
        if len(g_minus_d) > 0:
            # dump G - D to file: relationships in the corpus not in the KB
            f = open(rel_type + "_g_minus_d.pkl", "wb")
            pickle.dump(g_minus_d, f)
            f.close()

    # having B and G intersected with D => |c| = |G intersected with D| - |b|
    c = g_intersect_d.difference(set(b))
    assert len(g_minus_d) > 0
    return c, g_minus_d


@timecall
def calculate_d(g_minus_d, a, e1_type, e2_type, index, rel_type,
                rel_words_unigrams, rel_words_bigrams):
    # contains facts described in the corpus that are not
    # in the system output nor in the database
    #
    # by applying the PMI to the facts not in the database (i.e., G' \ D)
    # we determine |G \ D|, then we can estimate |d| = |G \ D| - |a|
    #
    # |G' \ D|
    # determine facts not in the database with high PMI, that is,
    # facts that are true and are not in the database

    # check if it was already calculated and stored on disk
    if os.path.isfile(rel_type + "_high_pmi_not_in_database.pkl"):
        f = open(rel_type + "_high_pmi_not_in_database.pkl", "rb")
        print("\nLoading high PMI facts not in the database",
              rel_type + "_high_pmi_not_in_database.pkl")
        g_minus_d = pickle.load(f)
        f.close()

    else:
        m = multiprocessing.Manager()
        queue = m.Queue()
        num_cpus = multiprocessing.cpu_count()
        results = [m.list() for _ in range(num_cpus)]

        for r in g_minus_d:
            queue.put(r)

        # calculate the PMI for each r not in the database
        processes = [multiprocessing.Process(
            target=proximity_pmi_rel_word,
            args=(e1_type, e2_type, queue, index, results[i],
                  rel_words_unigrams, rel_words_bigrams))
            for i in range(num_cpus)]

        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()

        g_minus_d = set()
        for l in results:
            g_minus_d.update(l)
        print("High PMI facts not in the database", len(g_minus_d))

        # dump high PMI facts not in the database
        if len(g_minus_d) > 0:
            f = open(rel_type + "_high_pmi_not_in_database.pkl", "wb")
            print("Dumping high PMI facts not in the database to",
                  rel_type + "_high_pmi_not_in_database.pkl")
            pickle.dump(g_minus_d, f)
            f.close()

    return g_minus_d.difference(a)


# ########################################################################
# Parallelized functions: each function will run as a different process  #
# ########################################################################

def proximity_pmi_rel_word(e1_type, e2_type, queue, index, results,
                           rel_words_unigrams, rel_words_bigrams):
    idx = open_dir(index)
    count = 0
    distance = MAX_TOKENS_AWAY
    q_limit = 500
    with idx.searcher() as searcher:
        while True:
            try:
                r = queue.get_nowait()
                if count % 50 == 0:
                    print("\n", multiprocessing.current_process(),
                          "In Queue", queue.qsize(),
                          "Total Matched: ", len(results))
                if (r.ent1, r.ent2) not in all_in_database:
                    # if it is not in the database calculate the PMI
                    entity1 = "<" + e1_type + ">" + r.ent1 + \
                              "</" + e1_type + ">"
                    entity2 = "<" + e2_type + ">" + r.ent2 + \
                              "</" + e2_type + ">"
                    t1 = query.Term('sentence', entity1)
                    t3 = query.Term('sentence', entity2)

                    # entities proximity query without relational words
                    q1 = spans.SpanNear2([t1, t3], slop=distance,
                                         ordered=True, mindist=1)
                    hits = searcher.search(q1, limit=q_limit)
                    # entities proximity considering relational words:
                    # from the results above count how many contain a
                    # valid relational word
                    hits_with_r = 0
                    hits_without_r = 0
                    for s in hits:
                        sentence = s.get("sentence")
                        s = Sentence(sentence, e1_type, e2_type,
                                     MAX_TOKENS_AWAY, MIN_TOKENS_AWAY,
                                     CONTEXT_WINDOW)
                        for s_r in s.relationships:
                            if r.ent1.decode("utf8") == s_r.ent1 and \
                                    r.ent2.decode("utf8") == s_r.ent2:
                                unigrams_rel_words = word_tokenize(s_r.between)
                                bigrams_rel_words = extract_bigrams(s_r.between)

                                if all(x in not_valid
                                       for x in unigrams_rel_words):
                                    hits_without_r += 1
                                    continue
                                elif any(x in rel_words_unigrams
                                         for x in unigrams_rel_words):
                                    hits_with_r += 1
                                elif any(x in rel_words_bigrams
                                         for x in bigrams_rel_words):
                                    hits_with_r += 1
                                else:
                                    hits_without_r += 1

                    if hits_with_r > 0 and hits_without_r > 0:
                        pmi = float(hits_with_r) / float(hits_without_r)
                        if pmi >= PMI:
                            if word_tokenize(s_r.between)[-1] == 'by':
                                tmp = s_r.ent2
                                s_r.ent2 = s_r.ent1
                                s_r.ent1 = tmp
                            results.append(r)
                count += 1
            except Empty:
                break


def string_matching_parallel(matches, no_matches, database_1, database_2,
                             database_3, queue, e1_type, e2_type):
    count = 0
    while True:
        try:
            r = queue.get_nowait()
            found = False
            count += 1
            if count % 500 == 0:
                print(multiprocessing.current_process(),
                      "In Queue", queue.qsize())

            # check the cache, i.e., if the tuple was already matched
            if (r.ent1, r.ent2) in all_in_database:
                matches.append(r)
                found = True

            # check for a relationship with a direct string match
            if found is False:
                if len(database_1[(r.ent1.decode("utf8"),
                                   r.ent2.decode("utf8"))]) > 0:
                    matches.append(r)
                    all_in_database[(r.ent1, r.ent2)] = "Found"
                    found = True

            if found is False:
                # database_2: arg_1 rel list(arg_2)
                # check for a direct string match with all possible arg2
                # FOUNDER: r.ent1:ORG  r.ent2:PER
                # DATABASE_1: (ORG, PER)
                # DATABASE_2: ORG list<PER>
                # DATABASE_3: PER list<ORG>
                ent2 = database_2[r.ent1.decode("utf8")]
                if len(ent2) > 0:
                    if r.ent2 in ent2:
                        matches.append(r)
                        all_in_database[(r.ent1, r.ent2)] = "Found"
                        found = True

            # if no direct string match was found, check for an approximate
            # string match with all possible arg_1 entities
            if found is False:
                arg1_list = database_3[r.ent2]
                if arg1_list is not None:
                    for arg1 in arg1_list:
                        if e1_type == 'ORG':
                            new_arg1 = re.sub(r" Corporation| Inc\.", "", arg1)
                        else:
                            new_arg1 = arg1

                        # Jaccard similarity
                        set_1 = set(new_arg1.split())
                        set_2 = set(r.ent1.split())
                        jaccard = \
                            float(len(set_1.intersection(set_2))) / \
                            float(len(set_1.union(set_2)))
                        if jaccard >= 0.5:
                            matches.append(r)
                            all_in_database[(r.ent1, r.ent2)] = "Found"
                            found = True
                        # Jaro-Winkler
                        elif jaccard <= 0.5:
                            score = jellyfish.jaro_winkler(
                                new_arg1.upper(), r.ent1.upper())
                            if score >= 0.9:
                                matches.append(r)
                                all_in_database[(r.ent1, r.ent2)] = "Found"
                                found = True

            # likewise, if no match was found yet, check for an approximate
            # string match with all possible arg_2 entities
            if found is False:
                arg2_list = database_2[r.ent1]
                if arg2_list is not None:
                    for arg2 in arg2_list:
                        if e1_type == 'ORG':
                            new_arg2 = re.sub(r" Corporation| Inc\.", "", arg2)
                        else:
                            new_arg2 = arg2

                        # Jaccard similarity
                        set_1 = set(new_arg2.split())
                        set_2 = set(r.ent2.split())
                        jaccard = \
                            float(len(set_1.intersection(set_2))) / \
                            float(len(set_1.union(set_2)))
                        if jaccard >= 0.5:
                            matches.append(r)
                            all_in_database[(r.ent1, r.ent2)] = "Found"
                            found = True
                        # Jaro-Winkler
                        elif jaccard <= 0.5:
                            score = jellyfish.jaro_winkler(
                                new_arg2.upper(), r.ent2.upper())
                            if score >= 0.9:
                                matches.append(r)
                                all_in_database[(r.ent1, r.ent2)] = "Found"
                                found = True
            if found is False:
                no_matches.append(r)
                if PRINT_NOT_FOUND is True:
                    print(r.ent1, '\t', r.ent2)

        except Empty:
            break


def proximity_pmi_a(e1_type, e2_type, queue, index, results, not_found,
                    rel_words_unigrams, rel_words_bigrams):
    idx = open_dir(index)
    count = 0
    q_limit = 500
    with idx.searcher() as searcher:
        while True:
            try:
                r = queue.get_nowait()
                count += 1
                if count % 50 == 0:
                    print(multiprocessing.current_process(),
                          "To Process", queue.qsize(),
                          "Correct found:", len(results))

                # if it is not in the database calculate the PMI
                entity1 = "<" + e1_type + ">" + r.ent1 + "</" + e1_type + ">"
                entity2 = "<" + e2_type + ">" + r.ent2 + "</" + e2_type + ">"
                t1 = query.Term('sentence', entity1)
                t3 = query.Term('sentence', entity2)

                # first count the proximity (MAX_TOKENS_AWAY) occurrences
                # of entities r.e1 and r.e2
                q1 = spans.SpanNear2([t1, t3], slop=MAX_TOKENS_AWAY,
                                     ordered=True, mindist=1)
                hits = searcher.search(q1, limit=q_limit)

                # entities proximity considering relational words:
                # from the results above count how many contain a
                # valid relational word
                hits_with_r = 0
                hits_without_r = 0
                fact_bet_words_tokens = word_tokenize(r.bet_words)
                for s in hits:
                    sentence = s.get("sentence")
                    s = Sentence(sentence, e1_type, e2_type, MAX_TOKENS_AWAY,
                                 MIN_TOKENS_AWAY, CONTEXT_WINDOW)
                    for s_r in s.relationships:
                        if r.ent1.decode("utf8") == s_r.ent1 and \
                                r.ent2.decode("utf8") == s_r.ent2:
                            unigrams_bef_words = word_tokenize(s_r.before)
                            unigrams_bet_words = word_tokenize(s_r.between)
                            unigrams_aft_words = word_tokenize(s_r.after)
                            bigrams_rel_words = extract_bigrams(s_r.between)

                            if fact_bet_words_tokens == unigrams_bet_words:
                                hits_with_r += 1
                            elif any(x in rel_words_unigrams
                                     for x in unigrams_bef_words):
                                hits_with_r += 1
                            elif any(x in rel_words_unigrams
                                     for x in unigrams_bet_words):
                                hits_with_r += 1
                            elif any(x in rel_words_unigrams
                                     for x in unigrams_aft_words):
                                hits_with_r += 1
                            elif rel_words_bigrams == bigrams_rel_words:
                                hits_with_r += 1
                            else:
                                hits_without_r += 1

                if hits_with_r > 0 and hits_without_r > 0:
                    pmi = float(hits_with_r) / float(hits_without_r)
                    if pmi >= PMI:
                        results.append(r)
                    else:
                        not_found.append(r)
                else:
                    not_found.append(r)
                count += 1

            except Empty:
                break


def main():
    # "Automatic Evaluation of Relation Extraction Systems on Large-scale"
    # https://akbcwekex2012.files.wordpress.com/2012/05/8_paper.pdf
    #
    # S  - system output
    # D  - database (Freebase)
    # G  - will be the resulting ground truth
    # G' - superset, contains true facts and wrong facts
    #
    # a - contains correct facts from the system output
    # b - intersection between the system output and the database
    #     (i.e., Freebase); it is assumed that every fact in this
    #     region is correct
    # c - contains the database facts described in the corpus
    #     but not extracted by the system
    # d - contains the facts described in the corpus that are not
    #     in the system output nor in the database
    #
    # Precision = (|a| + |b|) / |S|
    # Recall    = (|a| + |b|) / (|a| + |b| + |c| + |d|)
    # F1        = 2 * P * R / (P + R)

    if len(sys.argv) == 1:
        print("No arguments")
        print("Use: evaluation.py threshold system_output rel_type database")
        print("\n")
        sys.exit(0)

    threshold = float(sys.argv[1])
    rel_type = sys.argv[3]

    # load relationships extracted by the system
    system_output = process_output(sys.argv[2], threshold, rel_type)
    print("Relationships score threshold :", threshold)
    print("System output relationships   :", len(system_output))

    # load Freebase relationships as the database
    database_1, database_2, database_3 = process_freebase(sys.argv[4],
                                                          rel_type)
    print("Freebase relationships loaded :", len(list(database_1.keys())))
    # corpus from which the system extracted relationships
    corpus = "/home/dsbatista/gigaword/automatic-evaluation/" \
             "sentences_matched_freebase_added_tags.txt"

    # index used to estimate the proximity PMI
    index = "/home/dsbatista/gigaword/automatic-evaluation/index_full"

    # entities semantic types and relational keywords
    rel_words_unigrams = None
    rel_words_bigrams = None

    if rel_type == 'founder':
        e1_type = "ORG"
        e2_type = "PER"
        rel_words_unigrams = founded_unigrams
        rel_words_bigrams = founded_bigrams

    elif rel_type == 'acquired':
        e1_type = "ORG"
        e2_type = "ORG"
        rel_words_unigrams = acquired_unigrams
        rel_words_bigrams = acquired_bigrams

    elif rel_type == 'headquarters':
        # load extra DBpedia relationships
        print("Loading extra DBpedia relationships for", rel_type)
        load_dbpedia(sys.argv[5], database_1, database_2)
        e1_type = "ORG"
        e2_type = "LOC"
        rel_words_unigrams = headquarters_unigrams
        rel_words_bigrams = headquarters_bigrams

    elif rel_type == 'contained_by':
        e1_type = "LOC"
        e2_type = "LOC"

    elif rel_type == 'employer':
        e1_type = "ORG"
        e2_type = "PER"
        rel_words_unigrams = employment_unigrams
        rel_words_bigrams = employment_bigrams

    else:
        print("Invalid relationship type", rel_type)
        print("Use: founder, acquired, headquarters, employer")
        sys.exit(0)

    print("\nRelationship Type:", rel_type)
    print("Arg1 Type:", e1_type)
    print("Arg2 Type:", e2_type)

    print("\nCalculating set B: intersection between system output "
          "and database")
    b, not_in_database = calculate_b(system_output, database_1, database_2,
                                     database_3, e1_type, e2_type)

    print("System output      :", len(system_output))
    print("Found in database  :", len(b))
    print("Not found          :", len(not_in_database))
    assert len(system_output) == len(not_in_database) + len(b)

    print("\nCalculating set A: correct facts from system output not in "
          "the database (proximity PMI)")
    a, not_found = calculate_a(not_in_database, e1_type, e2_type, index,
                               rel_words_unigrams, rel_words_bigrams)

    print("System output      :", len(system_output))
    print("Found in database  :", len(b))
    print("Correct in corpus  :", len(a))
    print("Not found          :", len(not_found))
    print("\n")
    assert len(system_output) == len(a) + len(b) + len(not_found)

    # Estimate G intersected with D = |b| + |c|, looking for relationships
    # in G' that match a relationship in D; once we have G intersected with D
    # and |b|, |c| can be derived by: |c| = |G intersected with D| - |b|
    # G' = superset of G, cartesian product of all possible entities and
    # relations (i.e., G' = E x R x E)
    print("\nCalculating set C: database facts in the corpus but not "
          "extracted by the system")
    c, g_minus_d = calculate_c(corpus, database_1, database_2, database_3, b,
                               e1_type, e2_type, rel_type,
                               rel_words_unigrams, rel_words_bigrams)
    assert len(c) > 0

    uniq_c = set()
    for r in c:
        uniq_c.add((r.ent1, r.ent2))

    # By applying the PMI to the facts not in the database (i.e., G' \ D)
    # we determine |G \ D|, then we can estimate |d| = |G \ D| - |a|
    print("\nCalculating set D: facts described in the corpus not in "
          "the system output nor in the database")
    d = calculate_d(g_minus_d, a, e1_type, e2_type, index, rel_type,
                    rel_words_unigrams, rel_words_bigrams)

    print("System output      :", len(system_output))
    print("Found in database  :", len(b))
    print("Correct in corpus  :", len(a))
    print("Not found          :", len(not_found))
    print("\n")
    assert len(d) > 0

    uniq_d = set()
    for r in d:
        uniq_d.add((r.ent1, r.ent2))

    print("|a| =", len(a))
    print("|b| =", len(b))
    print("|c| =", len(c), "(", len(uniq_c), ")")
    print("|d| =", len(d), "(", len(uniq_d), ")")
    print("|S| =", len(system_output))
    print("|G| =", len(set(a).union(set(b).union(set(c).union(set(d))))))
    print("Relationships not found:", len(set(not_found)))

    # write relationships not found in the database nor with high PMI
    # relational words to disk
    f = open(rel_type + "_" + sys.argv[2][-11:][:-4] + "_negative.txt", "w")
    for r in sorted(set(not_found), reverse=True):
        f.write('instance :' + r.ent1 + '\t' + r.ent2 + '\t' +
                str(r.score) + '\n')
        f.write('sentence :' + r.sentence + '\n')
        f.write('bef_words:' + r.bef_words + '\n')
        f.write('bet_words:' + r.bet_words + '\n')
        f.write('aft_words:' + r.aft_words + '\n')
        f.write('\n')
    f.close()

    # write all correct relationships (sentence, entities and score) to file
    f = open(rel_type + "_" + sys.argv[2][-11:][:-4] + "_positive.txt", "w")
    for r in sorted(set(a).union(b), reverse=True):
        f.write('instance :' + r.ent1 + '\t' + r.ent2 + '\t' +
                str(r.score) + '\n')
        f.write('sentence :' + r.sentence + '\n')
        f.write('bef_words:' + r.bef_words + '\n')
        f.write('bet_words:' + r.bet_words + '\n')
        f.write('aft_words:' + r.aft_words + '\n')
        f.write('\n')
    f.close()

    a = set(a)
    b = set(b)
    output = set(system_output)

    if len(output) == 0 or float(len(a) + len(b)) == 0:
        print("\nPrecision : 0.0")
        print("Recall    : 0.0")
        print("F1        : 0.0")
        print("\n")
    else:
        precision = float(len(a) + len(b)) / float(len(output))
        recall = float(len(a) + len(b)) / float(len(a) + len(b) +
                                                len(uniq_c) + len(uniq_d))
        f1 = 2 * (precision * recall) / (precision + recall)
        print("\nPrecision : ", precision)
        print("Recall    : ", recall)
        print("F1        : ", f1)
        print("\n")


if __name__ == "__main__":
    main()
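
# ---------------------------------------------------------------------------
# Worked example of the final metric computation (illustrative only: the set
# sizes below are made up and are not produced by this script).
# With |a| = 10, |b| = 40, |c| = 20, |d| = 30 and |S| = 60:
#   Precision = (|a| + |b|) / |S|                     = 50 / 60  ~= 0.833
#   Recall    = (|a| + |b|) / (|a| + |b| + |c| + |d|) = 50 / 100 =  0.500
#   F1        = 2 * P * R / (P + R)                   =  0.625
#
# Example invocation (file names are illustrative, see the usage message in
# main(); 'headquarters' additionally expects a DBpedia file as 5th argument):
#   python evaluation.py 0.5 system_output.txt founder freebase_relations.txt
# ---------------------------------------------------------------------------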