#!/usr/bin/env python
"""Smatch table calculator.

Computes pairwise smatch F-scores between annotators over a set of AMR files
and prints the result as a table. Each annotator's AMRs are expected at
``<file directory>/<user name>/<AMR id>.txt``.
"""

import os
import sys
import time

import amr
import smatch

ERROR_LOG = sys.stderr
DEBUG_LOG = sys.stderr

verbose = False

# directory on isi machine
# change if needed
isi_dir_pre = "/nfs/web/isi.edu/cgi-bin/div3/mt/save-amr"


def _amr_path(dir_pre, user, amr_id):
    """Return the path of ``amr_id``'s annotation file for ``user``."""
    # os.path.join yields the same path as "dir/" + user + "/" + id + ".txt"
    # when dir_pre ends with "/", and still works when it does not.
    return os.path.join(dir_pre, user, amr_id + ".txt")


def get_names(file_dir, files):
    """
    Get the annotator name list based on a list of files

    Args:
        file_dir: AMR file folder
        files: a list of AMR names, e.g. nw_wsj_0001_1

    Returns:
        a list of user names who annotate all the files

    """
    # Only the immediate sub-directories of file_dir are treated as users;
    # the first item yielded by os.walk lists exactly those.
    try:
        _, user_dirs, _ = next(os.walk(file_dir))
    except StopIteration:
        # file_dir does not exist or cannot be listed
        user_dirs = []
    # keep the users that have every requested file available
    name_list = [
        user for user in user_dirs
        if all(os.path.exists(_amr_path(file_dir, user, f)) for f in files)
    ]
    if not name_list:
        print("********Error: Cannot find any user who completes the files*************", file=ERROR_LOG)
    return name_list


def compute_files(user1, user2, file_list, dir_pre, start_num):
    """
    Compute the smatch scores for a file list between two users

    Args:
        user1: user 1 name
        user2: user 2 name
        file_list: file list
        dir_pre: the file location prefix
        start_num: the number of restarts in smatch
            NOTE(review): this value is accepted but never passed on to
            smatch in this function -- confirm whether that is intended.

    Returns:
        smatch f score formatted as a "%.2f" string, or the float -1.00 when
        one of the files does not exist.

    """
    match_total = 0
    test_total = 0
    gold_total = 0
    for fi in file_list:
        file1 = _amr_path(dir_pre, user1, fi)
        file2 = _amr_path(dir_pre, user2, fi)
        if not os.path.exists(file1):
            print("*********Error: ", file1, "does not exist*********", file=ERROR_LOG)
            return -1.00
        if not os.path.exists(file2):
            print("*********Error: ", file2, "does not exist*********", file=ERROR_LOG)
            return -1.00
        try:
            file1_h = open(file1, "r")
            try:
                file2_h = open(file2, "r")
            except IOError:
                # do not leak the first handle when the second open fails
                file1_h.close()
                raise
        except IOError:
            print("Cannot open the files", file1, file2, file=ERROR_LOG)
            # stop processing; the score is computed from the files seen so far
            break
        # the handles are only needed to read one AMR from each file
        with file1_h, file2_h:
            cur_amr1 = amr.AMR.get_amr_line(file1_h)
            cur_amr2 = amr.AMR.get_amr_line(file2_h)
        if cur_amr1 == "":
            print("AMR 1 is empty", file=ERROR_LOG)
            continue
        if cur_amr2 == "":
            print("AMR 2 is empty", file=ERROR_LOG)
            continue
        amr1 = amr.AMR.parse_AMR_line(cur_amr1)
        amr2 = amr.AMR.parse_AMR_line(cur_amr2)
        # give the two AMRs distinct node-name prefixes before matching
        test_label = "a"
        gold_label = "b"
        amr1.rename_node(test_label)
        amr2.rename_node(gold_label)
        (test_inst, test_rel1, test_rel2) = amr1.get_triples()
        (gold_inst, gold_rel1, gold_rel2) = amr2.get_triples()
        if verbose:
            print("Instance triples of file 1:", len(test_inst), file=DEBUG_LOG)
            print(test_inst, file=DEBUG_LOG)
            print("Attribute triples of file 1:", len(test_rel1), file=DEBUG_LOG)
            print(test_rel1, file=DEBUG_LOG)
            print("Relation triples of file 1:", len(test_rel2), file=DEBUG_LOG)
            print(test_rel2, file=DEBUG_LOG)
            print("Instance triples of file 2:", len(gold_inst), file=DEBUG_LOG)
            print(gold_inst, file=DEBUG_LOG)
            print("Attribute triples of file 2:", len(gold_rel1), file=DEBUG_LOG)
            print(gold_rel1, file=DEBUG_LOG)
            print("Relation triples of file 2:", len(gold_rel2), file=DEBUG_LOG)
            print(gold_rel2, file=DEBUG_LOG)
        (best_match, best_match_num) = smatch.get_best_match(test_inst, test_rel1, test_rel2,
                                                             gold_inst, gold_rel1, gold_rel2,
                                                             test_label, gold_label)
        if verbose:
            print("best match number", best_match_num, file=DEBUG_LOG)
            print("Best Match:", smatch.print_alignment(best_match, test_inst, gold_inst), file=DEBUG_LOG)
        match_total += best_match_num
        test_total += (len(test_inst) + len(test_rel1) + len(test_rel2))
        gold_total += (len(gold_inst) + len(gold_rel1) + len(gold_rel2))
        # reset smatch's module-level match dictionary before the next pair
        smatch.match_triple_dict.clear()
    (precision, recall, f_score) = smatch.compute_f(match_total, test_total, gold_total)
    return "%.2f" % f_score


def get_max_width(table, index):
    """Return the width of the widest cell (as text) in column ``index``."""
    return max(len(str(row[index])) for row in table)


def pprint_table(table):
    """
    Print a table in pretty format

    The first column is left-justified, the remaining columns are
    right-justified, and every row is followed by an empty line.
    An empty table prints nothing.

    """
    if not table:
        return
    col_paddings = [get_max_width(table, i) for i in range(len(table[0]))]
    for row in table:
        cells = [str(row[0]).ljust(col_paddings[0] + 1)]
        cells.extend(str(cell).rjust(col_paddings[i] + 2)
                     for i, cell in enumerate(row[1:], start=1))
        # the trailing "\n" plus print's own newline leaves a blank line
        # between rows
        print("".join(cells) + "\n")


def cb(option, value, parser):
    """
    Callback function to handle variable number of arguments in optparse

    Collects ``value`` plus the following entries of ``parser.rargs`` up to
    the first one starting with "-", appends whatever is already stored under
    ``option.dest`` and stores the combined list back on ``parser.values``.

    NOTE(review): nothing in this file calls this function (the entry point
    uses argparse) -- confirm whether it is still needed.
    """
    arguments = [value]
    for arg in parser.rargs:
        if not arg.startswith("-"):
            arguments.append(arg)
        else:
            # NOTE(review): len(arguments) also counts `value`; confirm that
            # removing this many entries from rargs is intended.
            del parser.rargs[:len(arguments)]
            break
    if getattr(parser.values, option.dest):
        arguments.extend(getattr(parser.values, option.dest))
    setattr(parser.values, option.dest, arguments)


def check_args(args):
    """
    Parse arguments and check if the arguments are valid

    Args:
        args: parsed command line arguments; the attributes fd (AMR file
            directory), fl (open AMR ID list file or None), f (AMR ID list
            or None) and p (user list or None) are read.

    Returns:
        (amr_ids, names, True) when the arguments are usable, otherwise
        ([], [], False) after printing the reason to ERROR_LOG.

    """
    if not os.path.exists(args.fd):
        print("Not a valid path", args.fd, file=ERROR_LOG)
        return [], [], False
    if args.fl is not None:
        # we already ensure the file can be opened and opened the file
        # only the first line is read: whitespace-separated AMR IDs
        amr_ids = args.fl.readline().strip().split()
    elif args.f is None:
        print("No AMR ID was given", file=ERROR_LOG)
        return [], [], False
    else:
        amr_ids = args.f
    if not amr_ids:
        # e.g. the ID list file starts with a blank line
        print("No AMR ID was given", file=ERROR_LOG)
        return [], [], False
    check_name = True
    if args.p is None:
        names = get_names(args.fd, amr_ids)
        # no need to check names: get_names only returns users with all files
        check_name = False
        if len(names) == 0:
            print("Cannot find any user who tagged these AMR", file=ERROR_LOG)
            return [], [], False
    else:
        # work on a copy so the caller's list is not reordered or shrunk
        names = list(args.p)
    if len(names) == 0:
        print("No user was given", file=ERROR_LOG)
        return [], [], False
    if len(names) == 1:
        print("Only one user is given. Smatch calculation requires at least two users.", file=ERROR_LOG)
        return [], [], False
    if "consensus" in names:
        # move "consensus" to the end of the list
        names.remove("consensus")
        names.append("consensus")
    # check if all the AMR_id and user combinations are valid
    if check_name:
        missing = []
        for name in names:
            for amr_id in amr_ids:
                if not os.path.exists(_amr_path(args.fd, name, amr_id)):
                    print("User", name, "fails to tag AMR", amr_id, file=ERROR_LOG)
                    missing.append(name)
                    break
        for name in missing:
            print("Deleting user", name, "from the name list", file=ERROR_LOG)
        if missing:
            # every name in `missing` came from `names`, so filtering by
            # membership drops exactly the users reported above
            names = [name for name in names if name not in missing]
        if len(names) < 2:
            print("Not enough users to evaluate. Smatch requires >2 users who tag all the AMRs", file=ERROR_LOG)
            return [], [], False
    return amr_ids, names, True


def main(arguments):
    """
    Build and print the pairwise smatch table for the given arguments

    Args:
        arguments: parsed command line arguments (see check_args); the
            attributes v (verbose flag) and r (restart number) are read
            here as well.

    Returns:
        the accumulated time spent in compute_files (perf_counter seconds),
        or 0 when the arguments are not valid.

    """
    global verbose
    (ids, names, result) = check_args(arguments)
    if arguments.v:
        verbose = True
    if not result:
        return 0
    acc_time = 0
    len_name = len(names)
    # row 0 is the header; row i + 1 holds the scores of names[i]
    table = [[""] + list(names)]
    for i in range(len_name):
        row = [names[i]]
        for j in range(len_name):
            if i == j:
                # a user is not scored against itself
                row.append("")
                continue
            start = time.perf_counter()
            score = compute_files(names[i], names[j], ids, arguments.fd, arguments.r)
            end = time.perf_counter()
            row.append(score)
            # failed computations (-1.0) do not count towards the time
            if score != -1.0:
                acc_time += end - start
        table.append(row)
    # check table: make it symmetric by keeping the larger of the two scores
    # of each pair. Cells hold "%.2f" strings or the float -1.0, so compare
    # them numerically. Header cells are already equal and need no check.
    for i in range(1, len_name + 1):
        for j in range(i + 1, len_name + 1):
            upper, lower = table[i][j], table[j][i]
            if upper != lower:
                best = upper if float(upper) > float(lower) else lower
                table[i][j] = best
                table[j][i] = best
    pprint_table(table)
    return acc_time


if __name__ == "__main__":
    whole_start = time.perf_counter()
    import argparse

    parser = argparse.ArgumentParser(description="Smatch table calculator")
    parser.add_argument(
        "--fl",
        type=argparse.FileType('r'),
        help='AMR ID list file')
    parser.add_argument(
        '-f',
        nargs='+',
        help='AMR IDs (at least one)')
    parser.add_argument(
        "-p",
        nargs='*',
        help="User list (can be none)")
    parser.add_argument(
        "--fd",
        default=isi_dir_pre,
        help="AMR File directory. Default=location on isi machine")
    parser.add_argument(
        '-r',
        type=int,
        default=4,
        help='Restart number (Default:4)')
    parser.add_argument(
        '-v',
        action='store_true',
        help='Verbose output (Default:False)')
    args = parser.parse_args()
    # Regularize fd, add "/" at the end if needed
    # (an empty value is left alone and rejected later as an invalid path)
    if args.fd and not args.fd.endswith("/"):
        args.fd += "/"
    # acc_time is the smatch calculation time
    acc_time = main(args)
    whole_end = time.perf_counter()
    # time of the whole running process
    whole_time = whole_end - whole_start
    # print if needed
    # print("Accumulated computation time", acc_time, file=ERROR_LOG)
    # print("Total time", whole_time, file=ERROR_LOG)
    # print("Percentage", float(acc_time)/float(whole_time), file=ERROR_LOG)