# -*- coding: utf-8 -*- import os import json from gensim.models import word2vec from gensim import models class Rule(object): """ Store the concept terms of a rule, and calculate the rule similarity. """ def __init__(self, domain, rule_terms, children, response, word2vec_model): self.id_term = domain self.terms = rule_terms self.model = word2vec_model self.response = response self.children = children def __str__(self): res = 'Domain:' + self.id_term if self.has_child(): res += ' with children: ' for child in self.children: res += ' ' + str(child) return res def serialize(self): """ Convert the instance to json format. """ ch_list = [] for child in self.children: ch_list.append(child.id_term) cp_list = [] for t in self.terms: cp_list.append(t) response = [] data = { "domain": str(self.id_term), "concepts": cp_list, "children": ch_list, "response": response } return data def add_child(self,child_rule): """ Add child rule into children list , e.g: Purchase(Parent) -> Drinks(Child). """ self.children.append(child_rule) def has_child(self): return len(self.children) def has_response(self): return len(self.response) def match(self, sentence, threshold=0): """ Calculate the similarity between the input and concept term. Args: threshold: a threshold to ignore the low similarity. sentence : a list of words. Returns: a struct : [similarity, domain_name, matchee in the sentence] """ max_sim = 0.0 matchee = "" for word in sentence: for term in self.terms: try: sim = self.model.similarity(term,word) if sim > max_sim and sim > threshold: max_sim = sim matchee = word except Exception as e: if term == word: max_sim = 1 matchee = word return [max_sim, self.id_term, matchee] class RuleBase(object): """ to store rules, and load the trained word2vec model. """ def __init__(self, domain="general"): self.rules = {} self.domain = domain self.model = None self.forest_base_roots = [] def __str__(self): res = "There are " + str(self.rule_amount()) + " rules in the rulebase:" res+= "\n-------\n" for key,rulebody in self.rules.items(): res += str(rulebody) + '\n' return res def rule_amount(self): return len(self.rules) def output_as_json(self, path='rule.json'): rule_list = [] for rule in self.rules.values(): rule_list.append(rule.serialize()) with open(path,'w',encoding='utf-8') as op: op.write(json.dumps(rule_list, indent=4)) def load_rules_old_format(self,path): """ Deprecated. Build the rulebase by loading the rules terms from the given file. The data format is: child term, parent term(optional) Args: the path of file. """ assert self.model is not None, "Please load the model before loading rules." self.rules.clear() with open(path, 'r', encoding='utf-8') as input: for line in input: rule_terms = line.strip('\n').split(' ') new_rule = Rule(self.rule_amount(), rule_terms[0].split(','), self.model) if new_rule.id_term not in self.rules: self.rules[new_rule.id_term] = new_rule #else # self.rules[new_rule.id_term].terms = rule_terms if len(rule_terms) > 1: # this rule has parents. for parent in rule_terms[1:]: #if parent not in self.rules: self.rules[parent].children.append(new_rule) else: # is the root of classification tree. self.forest_base_roots.append(new_rule) def load_rules(self, path, reload=False, is_root=False): """ Build the rulebase by loading the rules terms from the given file. Args: the path of file. """ assert self.model is not None, "Please load the model before loading rules." if reload: self.rules.clear() with open(path, 'r', encoding='utf-8') as input: json_data = json.load(input) # load rule and build an instance for data in json_data: domain = data["domain"] concepts_list = data["concepts"] children_list = data["children"] response = data["response"] if domain not in self.rules: rule = Rule(domain, concepts_list, children_list, response, self.model) self.rules[domain] = rule if is_root: self.forest_base_roots.append(rule) else: print("[Rules]: Detect a duplicate domain name '%s'." % domain) def load_rules_from_dic(self,path): """ load all rule_files in given path """ for file_name in os.listdir(path): if not file_name.startswith('.'): # escape .DS_Store on OSX. if file_name == "rule.json": # roots of forest self.load_rules(path + file_name, is_root=True) else: self.load_rules(path + file_name) def load_model(self,path): """ Load a trained word2vec model(binary format only). Args: path: the path of the model. """ try: self.model = models.Word2Vec.load(path) # current loading method except FileNotFoundError as file_not_found_err: print("[Gensim] FileNotFoundError", file_not_found_err) exit() except UnicodeDecodeError as unicode_decode_err: print("[Gensim] UnicodeDecodeError", unicode_decode_err) self.model = models.KeyedVectors.load_word2vec_format(path, binary=True) # old loading method except Exception as ex: print("[Gensim] Exception", ex) exit() def match(self, sentence, topk=1, threshold=0, root=None): """ match the sentence with rules then order by similarity. Args: sentence: a list of words threshold: a threshold to ignore the low similarity. Return: a list holds the top k-th rules and the classification tree travel path. """ log = open("matching_log.txt",'w',encoding='utf-8') assert self.model is not None, "Please load the model before any match." result_list = [] at_leaf_node = False term_trans = "" if root is None: # then search from roots of forest. focused_rule = self.forest_base_roots[:] else: focused_rule = [self.rules[root]] while not at_leaf_node: at_leaf_node = True for rule in focused_rule: result_list.append(rule.match(sentence, threshold)) result_list = sorted(result_list, reverse=True , key=lambda k: k[0]) top_domain = result_list[0][1] # get the best matcher's term. # Output matching_log. log.write("---") for result in result_list: s,d,m = result log.write("Sim: %f, Domain: %s, Matchee: %s\n" % (s,d,m)) log.write("---") if self.rules[top_domain].has_child(): result_list = [] term_trans += top_domain+'>' at_leaf_node = False # travel to the best node's children. focused_rule = [] for rule_id in self.rules[top_domain].children: focused_rule.append(self.rules[rule_id]) else: term_trans += top_domain return [result_list,term_trans]