#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module is for visualizing the results
"""
import os
from pathlib import Path

import seaborn
import torch
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

from sklearn.manifold import TSNE
from matplotlib import colors as mcolors
from pykg2vec.utils.logger import Logger

seaborn.set_style("darkgrid")


class Visualization(object):
    """Class to aid in visualizing the results and embeddings.

    Args:
        model (object): Model object. It must expose an ``embed(h, r, t)``
            method. May be ``None`` when only the train/test result plots
            are needed; in that case no embeddings are collected.
        config (object): Configuration object. This class reads its
            ``knowledge_graph``, ``disp_triple_num``, ``device``,
            ``path_result``, ``path_figures``, ``dataset_name`` and ``hits``
            attributes.
        vis_opts (dict): Options for visualization, with the keys
            ``"ent_only_plot"``, ``"rel_only_plot"`` and
            ``"ent_and_rel_plot"``. When omitted (or empty) all three plots
            are disabled.

    Examples:
        >>> from pykg2vec.utils.visualization import Visualization
        >>> from pykg2vec.utils.trainer import Trainer
        >>> from pykg2vec.models.TransE import TransE
        >>> model = TransE()
        >>> trainer = Trainer(model=model)
        >>> trainer.build_model()
        >>> trainer.train_model()
        >>> # NOTE: ``config`` is a required argument of the constructor.
        >>> viz = Visualization(model=model, config=config)
        >>> viz.plot_train_result()
    """
    _logger = Logger().get_logger(__name__)

    def __init__(self, model, config, vis_opts=None):
        if vis_opts:
            self.ent_only_plot = vis_opts["ent_only_plot"]
            self.rel_only_plot = vis_opts["rel_only_plot"]
            self.ent_and_rel_plot = vis_opts["ent_and_rel_plot"]
        else:
            self.ent_only_plot = False
            self.rel_only_plot = False
            self.ent_and_rel_plot = False

        self.model = model
        self.config = config

        # Algorithm names whose result files are searched for by the
        # train/test result plots.
        self.algo_list = ['ANALOGY', 'Complex', 'ComplexN3', 'ConvE', 'CP', 'DistMult', 'DistMult2',
                          'HoLE', 'KG2E', 'NTN', 'ProjE_pointwise', 'Rescal', 'RotatE', 'SimplE_avg',
                          'SimplE_ignr', 'SLM', 'SME_Bilinear', 'SME_Linear', 'TransD', 'TransE',
                          'TransH', 'TransM', 'TransR', 'TuckER']

        # Names and embeddings of the sampled triples, filled by get_idx_n_emb().
        self.h_name = []
        self.r_name = []
        self.t_name = []

        self.h_emb = []
        self.r_emb = []
        self.t_emb = []

        self.h_proj_emb = []
        self.r_proj_emb = []
        self.t_proj_emb = []

        if self.model is not None:
            self.validation_triples_ids = self.config.knowledge_graph.read_cache_data('triplets_valid')
            self.idx2entity = self.config.knowledge_graph.read_cache_data('idx2entity')
            self.idx2relation = self.config.knowledge_graph.read_cache_data('idx2relation')
            self.get_idx_n_emb()

    def get_idx_n_emb(self):
        """Function to get the integer ids and the embedding.

        Randomly draws ``config.disp_triple_num`` validation triples (with
        replacement, as ``np.random.choice`` does by default) and stores
        their names and embeddings on the instance.
        """
        idx = np.random.choice(len(self.validation_triples_ids), self.config.disp_triple_num)
        triples = [self.validation_triples_ids[i] for i in idx]

        device = self.config.device
        for t in triples:
            self.h_name.append(self.idx2entity[t.h])
            self.r_name.append(self.idx2relation[t.r])
            self.t_name.append(self.idx2entity[t.t])

            # Build the single-element id tensors once; they are reused below.
            h_id = torch.tensor([t.h]).to(device)
            r_id = torch.tensor([t.r]).to(device)
            t_id = torch.tensor([t.t]).to(device)

            emb_h, emb_r, emb_t = self.model.embed(h_id, r_id, t_id)
            self.h_emb.append(emb_h)
            self.r_emb.append(emb_r)
            self.t_emb.append(emb_t)

            if self.ent_and_rel_plot:
                # Best effort: a failure here is logged and the triple is
                # skipped for the combined plot only.
                try:
                    emb_h, emb_r, emb_t = self.model.embed(h_id, r_id, t_id)
                    self.h_proj_emb.append(emb_h)
                    self.r_proj_emb.append(emb_r)
                    self.t_proj_emb.append(emb_t)
                except Exception as e:
                    self._logger.error(e.args)

    def _reduce_to_2d(self, embeddings):
        """Concatenate a list of embedding tensors and project them to 2-D with t-SNE."""
        stacked = torch.cat(embeddings, dim=0)
        self._logger.info("\t Reducing dimension using TSNE to 2!")
        return np.asarray(TSNE(n_components=2).fit_transform(stacked.detach().cpu().numpy()))

    def plot_embedding(self, resultpath=None, algos=None, show_label=False, disp_num_r_n_e=20):
        """Function to plot the embedding.

        Args:
            resultpath (str or pathlib.Path): Path where the result will be saved.
            show_label (bool): If True, will display the labels.
            algos (str): Name of the algorithms that generated the embedding.
            disp_num_r_n_e (int): Total number of entities to display for head, tail and relation.

        Raises:
            NotImplementedError: If no model was provided to the constructor.
        """
        if self.model is None:
            raise NotImplementedError('Please provide a model!')

        if self.ent_only_plot:
            x = self._reduce_to_2d(self.h_emb + self.t_emb)
            ent_names = np.asarray(np.concatenate((self.h_name, self.t_name), axis=0))
            self.draw_embedding(x, ent_names, resultpath, algos + '_entity_plot', show_label)

        if self.rel_only_plot:
            x = self._reduce_to_2d(self.r_emb)
            self.draw_embedding(x, self.r_name, resultpath, algos + '_rel_plot', show_label)

        if self.ent_and_rel_plot:
            # The projected rows are split back into thirds; this assumes
            # each stored embedding contributes exactly one row -- TODO confirm.
            length = len(self.h_proj_emb)
            x = self._reduce_to_2d(self.h_proj_emb + self.r_proj_emb + self.t_proj_emb)

            h_embs = x[:length, :]
            r_embs = x[length:2 * length, :]
            t_embs = x[2 * length:3 * length, :]

            self.draw_embedding_rel_space(h_embs[:disp_num_r_n_e],
                                          r_embs[:disp_num_r_n_e],
                                          t_embs[:disp_num_r_n_e],
                                          self.h_name[:disp_num_r_n_e],
                                          self.r_name[:disp_num_r_n_e],
                                          self.t_name[:disp_num_r_n_e],
                                          resultpath, algos + '_ent_n_rel_plot', show_label)

    def plot_train_result(self):
        """Function to plot the training result.

        Reads the most recent ``<algo>_Training_results_<n>.csv`` file of
        every known algorithm from ``config.path_result`` and saves a loss
        curve per dataset into ``config.path_figures``.
        """
        path = self.config.path_result
        result = self.config.path_figures
        data = [self.config.dataset_name]

        result_files = [f.lower() for f in os.listdir(str(path))]

        for d in data:
            frames = []
            for a in self.algo_list:
                # Count by exact filename prefix so that e.g. 'distmult' does
                # not also count 'distmult2_training_results_*' files.
                prefix = a.lower() + '_training_results_'
                file_no = sum(1 for c in result_files if c.startswith(prefix))
                if file_no < 1:
                    continue

                file_path = str(path / (a.lower() + '_Training_results_' + str(file_no - 1) + '.csv'))
                if not os.path.exists(file_path):
                    continue

                with open(file_path, 'r') as fh:
                    df_algo = pd.read_csv(fh)

                frame = df_algo[['Epochs', 'Loss']].copy()
                frame['Algorithm'] = [a] * len(df_algo)
                frames.append(frame)

            if not frames:
                # Nothing to plot; plotting an empty frame would fail.
                self._logger.error('No training result found in %s for dataset %s.' % (path, d))
                continue

            df = pd.concat(frames, ignore_index=True)

            plt.figure()
            seaborn.lineplot(x="Epochs", y="Loss", hue="Algorithm", markers=True, dashes=False, data=df)

            figure_files = [f.lower() for f in os.listdir(str(result))]
            file_no = len([c for c in figure_files if d.lower() in c if 'training' in c])
            plt.savefig(str(result / (d + '_training_loss_plot_' + str(file_no) + '.pdf')),
                        bbox_inches='tight', dpi=300)

    def _save_testing_barplot(self, scores, palette, result, figure_files, dataset, plot_tag):
        """Draw a grouped bar plot of ``scores`` and save it as ``<dataset>_testing_<plot_tag>_<n>.pdf``."""
        plt.figure()
        g = seaborn.barplot(x="Metrics", y='Score', hue="Algorithm", palette=palette, data=scores)
        g.legend(loc='upper center', bbox_to_anchor=(0.5, 1.14), ncol=6)
        g.tick_params(labelsize=6)

        file_no = len([c for c in figure_files
                       if dataset.lower() in c if 'testing' in c if plot_tag in c])
        plt.savefig(str(result / (dataset + '_testing_' + plot_tag + '_' + str(file_no + 1) + '.pdf')),
                    bbox_inches='tight', dpi=300)

    def plot_test_result(self):
        """Function to plot the testing result.

        For every dataset this writes a LaTeX table, a CSV table, a rank bar
        plot and a hits bar plot into ``config.path_figures``, built from the
        rows of the testing result files that belong to the highest epoch.

        Raises:
            NotImplementedError: If the result path, algorithm list or
                dataset list is ``None``.
        """
        algo = self.algo_list
        path = self.config.path_result
        result = self.config.path_figures
        data = [self.config.dataset_name]
        hits = self.config.hits

        if path is None or algo is None or data is None:
            raise NotImplementedError('Please provide valid path, algorithm and dataset!')

        files = os.listdir(str(path))

        for d in data:
            frames = []
            for a in algo:
                # NOTE(review): this is a substring match and the "last" file
                # depends on os.listdir order; the testing filename pattern
                # is not visible here, so the selection is left unchanged.
                file_algo = [c for c in files if a.lower() in c.lower() if 'testing' in c.lower()]
                if not file_algo:
                    continue

                with open(str(path / file_algo[-1]), 'r') as fh:
                    df_algo = pd.read_csv(fh)

                frame = pd.DataFrame()
                frame['Algorithm'] = [a] * len(df_algo)
                frame['Epochs'] = df_algo['Epoch']
                frame['Mean Rank'] = df_algo['Mean Rank']
                frame['Filt Mean Rank'] = df_algo['Filtered Mean Rank']
                for hit in hits:
                    frame['Hits' + str(hit)] = df_algo['Hit-%d Ratio' % hit]
                    frame['Filt Hits' + str(hit)] = df_algo['Filtered Hit-%d Ratio' % hit]
                frames.append(frame)

            if not frames:
                # Without any result there is no 'Epochs' column to select on.
                self._logger.error('No testing result found in %s for dataset %s.' % (path, d))
                continue

            df = pd.concat(frames, ignore_index=True)

            # Listed once, before anything is written, so every counter
            # below sees the same directory state.
            figure_files = os.listdir(str(result))
            figure_files_lwcase = [f.lower() for f in figure_files]

            # Keep only the rows recorded at the overall highest epoch.
            latest = df.loc[df['Epochs'] == df['Epochs'].max()]
            latest = latest.loc[:, latest.columns != 'Epochs']

            file_no = len([c for c in figure_files_lwcase
                           if d.lower() in c if 'testing' in c if 'latex' in c])
            with open(str(result / (d + '_testing_latex_table_' + str(file_no + 1) + '.txt')), 'w') as fh:
                fh.write(latest.to_latex(index=False))

            file_no = len([c for c in figure_files_lwcase
                           if d.lower() in c if 'testing' in c if 'table' in c if 'csv' in c])
            with open(str(result / (d + '_testing_table_' + str(file_no + 1) + '.csv')), 'w') as fh:
                latest.to_csv(fh, index=False)

            # Reshape to long form: one (Metrics, Algorithm, Score) row per
            # metric of every algorithm.
            metrics = [f for f in latest.columns if f != 'Algorithm']
            long_frames = []
            for i in range(len(latest)):
                row = latest.iloc[i]
                long_frames.append(pd.DataFrame({'Metrics': metrics,
                                                 'Algorithm': [row['Algorithm']] * len(metrics),
                                                 'Score': row[metrics].values}))
            if long_frames:
                scores = pd.concat(long_frames, ignore_index=True)
            else:
                scores = pd.DataFrame(columns=['Metrics', 'Algorithm', 'Score'])

            is_hits = scores['Metrics'].str.contains('Hits').astype(bool)

            rank_palette = ["#d46a7e", "#d5b60a", "#9b59b6", "#3498db",
                            "#95a5a6", "#34495e", "#2ecc71", "#e74c3c"]
            self._save_testing_barplot(scores[~is_hits], rank_palette, result,
                                       figure_files_lwcase, d, 'rank_plot')

            hits_palette = ["#3498db", "#95a5a6", "#34495e", "#2ecc71",
                            "#e74c3c", "#d46a7e", "#d5b60a", "#9b59b6"]
            self._save_testing_barplot(scores[is_hits], hits_palette, result,
                                       figure_files_lwcase, d, 'hits_plot')

    @staticmethod
    def _named_colors():
        """Return the names of matplotlib's base and CSS4 colors."""
        return list(dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS).keys())

    @staticmethod
    def _assign_colors(items, palette):
        """Map every item to a palette color, cycling when the palette runs out."""
        return {item: palette[i % len(palette)] for i, item in enumerate(items)}

    @staticmethod
    def _save_embedding_figure(resultpath, algos):
        """Save the current figure as ``<algos>_embedding_plot_<n>.png`` inside ``resultpath``."""
        # Accept plain strings as well as Path objects.
        resultpath = Path(resultpath)
        resultpath.mkdir(parents=True, exist_ok=True)

        files = os.listdir(str(resultpath))
        file_no = len([c for c in files if algos + '_embedding_plot' in c])
        filename = algos + '_embedding_plot_' + str(file_no) + '.png'
        plt.savefig(str(resultpath / filename), bbox_inches='tight', dpi=300)

    @staticmethod
    def draw_embedding(embs, names, resultpath, algos, show_label):
        """Function to draw the embedding.

        Args:
            embs (matrix): Two dimensional embeddings.
            names (list): List of string names, one per row of ``embs``.
            resultpath (str or pathlib.Path): Path where the result will be saved.
            algos (str): Name of the algorithms which generated the embedding.
            show_label (bool): If True, prints the string names of the entities and relations.
        """
        node_color_mp = Visualization._assign_colors(set(names), Visualization._named_colors())

        G = nx.Graph()
        pos = {}
        labels = {}
        for i, ent in enumerate(names):
            G.add_node(i)
            pos[i] = embs[i]
            labels[i] = ent

        node_colors = [node_color_mp[labels[n]] for n in G.nodes]

        plt.figure()
        nodes_draw = nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=50)
        nodes_draw.set_edgecolor('k')

        if show_label:
            # Nodes are integer positions; pass the names explicitly so the
            # labels show them instead of the node ids.
            nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)

        Visualization._save_embedding_figure(resultpath, algos)

    @staticmethod
    def draw_embedding_rel_space(h_emb, r_emb, t_emb,
                                 h_name, r_name, t_name,
                                 resultpath, algos, show_label):
        """Function to draw the embedding in relation space.

        Args:
            h_emb (matrix): Two dimensional embeddings of head.
            r_emb (matrix): Two dimensional embeddings of relation.
            t_emb (matrix): Two dimensional embeddings of tail.
            h_name (list): List of string name of the head.
            r_name (list): List of string name of the relation.
            t_name (list): List of string name of the tail.
            resultpath (str or pathlib.Path): Path where the result will be saved.
            algos (str): Name of the algorithms which generated the embedding.
            show_label (bool): If True, prints the string names of the entities and relations.
        """
        palette = Visualization._named_colors()
        node_color_mp_ent = Visualization._assign_colors(set(h_name) | set(t_name), palette)
        node_color_mp_rel = Visualization._assign_colors(set(r_name), palette)

        G = nx.DiGraph()
        pos = {}
        labels = {}

        head_nodes, rel_nodes, tail_nodes = [], [], []
        head_colors, rel_colors, tail_colors = [], [], []

        # Every triple becomes a chain of three nodes: head -> relation -> tail.
        for i in range(len(h_name)):
            head, rel, tail = 3 * i, 3 * i + 1, 3 * i + 2

            G.add_edge(head, rel)
            G.add_edge(rel, tail)

            head_nodes.append(head)
            rel_nodes.append(rel)
            tail_nodes.append(tail)

            head_colors.append(node_color_mp_ent[h_name[i]])
            rel_colors.append(node_color_mp_rel[r_name[i]])
            tail_colors.append(node_color_mp_ent[t_name[i]])

            pos[head] = h_emb[i]
            pos[rel] = r_emb[i]
            pos[tail] = t_emb[i]

            labels[head] = h_name[i]
            labels[rel] = r_name[i]
            labels[tail] = t_name[i]

        plt.figure()
        nodes_draw = nx.draw_networkx_nodes(G, pos,
                                            nodelist=head_nodes,
                                            node_color=head_colors,
                                            node_shape='o',
                                            node_size=50)
        nodes_draw.set_edgecolor('k')

        # 'with_labels' is not a parameter of draw_networkx_nodes; labels are
        # drawn separately below.
        nodes_draw = nx.draw_networkx_nodes(G, pos,
                                            nodelist=rel_nodes,
                                            node_color=rel_colors,
                                            node_size=50,
                                            node_shape='D')
        nodes_draw.set_edgecolor('k')

        nodes_draw = nx.draw_networkx_nodes(G, pos,
                                            nodelist=tail_nodes,
                                            node_color=tail_colors,
                                            node_shape='*',
                                            node_size=50)
        nodes_draw.set_edgecolor('k')

        if show_label:
            # Label nodes with their entity/relation names, not the node ids.
            nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)

        nx.draw_networkx_edges(G, pos, arrows=True, width=0.5, alpha=0.5)

        Visualization._save_embedding_figure(resultpath, algos)