'''
Created on Dec, 2016

@author: hugo

Plotting helpers for document/word representations: heat maps, t-SNE and
PCA scatter plots (2D and 3D) and precision curves for retrieval results.
'''
from __future__ import absolute_import

import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib as mpl
mpl.use('TkAgg')  # the backend has to be selected before pyplot is imported
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
from scipy import interpolate


# Marker sets shared by the 2D and the 3D class scatter plots.
_MARKERS_2D = ["o", "v", "8", "s", "p", "*", "h", "H", "+", "x", "D"]
_MARKERS_3D = ["D", "p", "*", "s", "d", "8", "^", "H", "v", ">", "<", "h", "|"]
_COLORS_3D = ['r', 'b', 'g', 'c', 'm', 'y', 'k']


class neural_net_visualizer(object):
    """Empty placeholder class; it defines no state and no behaviour."""

    def __init__(self):
        pass


def _cycle_to_length(styles, count):
    """Return a new list of `count` styles, repeating `styles` when it is too short."""
    return [styles[i % len(styles)] for i in range(count)]


def _unzip_pairs(pairs):
    """Split a list of (code, label) pairs into a (codes, labels) tuple pair.

    @raise ValueError: when `pairs` is empty, i.e. nothing was selected.
    """
    if not pairs:
        raise ValueError('no document matches the classes to visualize')
    codes, labels = zip(*pairs)
    return codes, labels


def _select_codes_and_labels(doc_codes, doc_labels, classes_to_visual):
    """Return the (codes, labels) to plot.

    When both inputs are dicts keyed by document, only the documents whose
    label is in `classes_to_visual` are kept; otherwise both inputs are
    returned untouched (they are then expected to be aligned sequences).
    """
    if isinstance(doc_codes, dict) and isinstance(doc_labels, dict):
        return _unzip_pairs([(code, doc_labels[doc])
                             for doc, code in doc_codes.items()
                             if doc_labels[doc] in classes_to_visual])
    return doc_codes, doc_labels


def _tsne_embed(codes, n_components):
    """Project `codes` with t-SNE (perplexity 30, PCA init, 5000 iterations)."""
    settings = dict(perplexity=30, n_components=n_components, init='pca')
    try:
        # Recent scikit-learn releases call the iteration budget `max_iter`.
        model = TSNE(max_iter=5000, **settings)
    except TypeError:
        # Older releases only know the original `n_iter` name.
        model = TSNE(n_iter=5000, **settings)
    np.set_printoptions(suppress=True)
    return model.fit_transform(np.asarray(list(codes)))


def _pca_embed(codes):
    """Project `codes` onto their first three principal components."""
    return PCA(n_components=3).fit_transform(np.asarray(list(codes)))


def _plot_points(X, idx, cols, marker, label, alpha, markersize):
    """Draw the rows of `X` selected by the boolean mask `idx` as unconnected markers."""
    plt.plot(X[idx, cols[0]], X[idx, cols[1]], linestyle='None', alpha=alpha,
             marker=marker, markersize=markersize, label=label)


def _scatter_classes_3d(ax, X, labels, classes_to_visual, cols, maker_size, opaque):
    """Scatter every class on the 3D axes `ax` and attach a legend.

    The legend cannot be built from what a 3D scatter returns, so a dummy
    2D line with the same colour and marker is created for each class and
    those proxies are handed to the legend instead.
    """
    count = len(classes_to_visual)
    markers = _cycle_to_length(_MARKERS_3D, count)
    colors = _cycle_to_length(_COLORS_3D, count)
    labels = np.array(labels)
    proxies = []
    for i, cls in enumerate(classes_to_visual):
        idx = labels == cls
        ax.scatter(X[idx, cols[0]], X[idx, cols[1]], X[idx, cols[2]],
                   c=colors[i],
                   alpha=opaque[i] if opaque else 1,
                   s=maker_size[i] if maker_size else 20,
                   marker=markers[i], label=cls)
        proxies.append(mpl.lines.Line2D([0], [0], linestyle="none", c=colors[i],
                                        marker=markers[i], label=cls))
    ax.legend(proxies, classes_to_visual, numpoints=1)


def heatmap(data, save_file='heatmap.png'):
    """Draw `data` as a pseudocolor heat map and save it to `save_file`.

    @param data: 2D array-like of values to colour.
    @param save_file: path of the image to write.
    """
    ax = plt.figure().gca()
    # One y tick every 5 rows. (The original also set a MaxNLocator first,
    # but that locator was replaced by this one right away.)
    ax.yaxis.set_major_locator(MultipleLocator(5))
    plt.pcolor(data, cmap=plt.cm.jet)
    plt.savefig(save_file)
    # plt.show()


def word_cloud(word_embedding_matrix, vocab, s, save_file='scatter.png'):
    """Show an annotated 2D t-SNE scatter plot of the words in `s`.

    @param word_embedding_matrix: embedding matrix, one row per vocabulary entry.
    @param vocab: mapping from a word to its row in `word_embedding_matrix`.
    @param s: the words to display.
    @param save_file: currently unused; the figure is shown, not saved.
    """
    model = TSNE(n_components=2, random_state=0)
    # Note that the following line might use a good chunk of RAM
    tsne_embedding = model.fit_transform(word_embedding_matrix)
    words_vectors = tsne_embedding[np.array([vocab[word] for word in s])]

    plt.subplots_adjust(bottom=0.1)
    # No `cmap` here: without colour values matplotlib ignores it (and warns).
    plt.scatter(words_vectors[:, 0], words_vectors[:, 1], marker='o')

    for label, x, y in zip(s, words_vectors[:, 0], words_vectors[:, 1]):
        plt.annotate(
            label,
            xy=(x, y), xytext=(-20, 20),
            textcoords='offset points', ha='right', va='bottom',
            fontsize=20,
            # bbox=dict(boxstyle='round,pad=1.', fc='yellow', alpha=0.5),
            arrowprops=dict(arrowstyle='<-', connectionstyle='arc3,rad=0')
        )
    plt.show()
    # plt.savefig(save_file)


def plot_tsne(doc_codes, doc_labels, classes_to_visual, save_file):
    """Plot a 2D t-SNE projection of the codes, one marker per class.

    @param doc_codes: dict doc -> code, or a sequence of codes.
    @param doc_labels: dict doc -> label, or a sequence of labels aligned
        with `doc_codes`.
    @param classes_to_visual: the class labels to draw (duplicates are dropped).
    @param save_file: path of the EPS figure to write.
    """
    plt.rc('legend', **{'fontsize': 30})
    classes_to_visual = list(set(classes_to_visual))
    markers = _cycle_to_length(_MARKERS_2D, len(classes_to_visual))

    codes, labels = _select_codes_and_labels(doc_codes, doc_labels, classes_to_visual)
    X = _tsne_embed(codes, 2)
    labels = np.array(labels)

    plt.figure(figsize=(10, 10), facecolor='white')
    for marker, c in zip(markers, classes_to_visual):
        _plot_points(X, labels == c, (0, 1), marker, c, alpha=1, markersize=10)
    plt.legend(loc='upper right', shadow=True)
    # plt.title("tsne")
    # plt.savefig(save_file)
    plt.savefig(save_file, format='eps', dpi=2000)
    plt.show()


def plot_tsne_3d(doc_codes, doc_labels, classes_to_visual, save_file, maker_size=None, opaque=None):
    """Plot a 3D t-SNE projection of the codes, one colour/marker per class.

    @param doc_codes: dict doc -> code, or a sequence of codes.
    @param doc_labels: dict doc -> label, or a sequence of labels aligned
        with `doc_codes`.
    @param classes_to_visual: sequence of the class labels to draw.
    @param save_file: path of the figure to write.
    @param maker_size: optional per-class marker sizes (default 20).
    @param opaque: optional per-class alpha values (default 1).
    """
    plt.rc('legend', **{'fontsize': 20})
    codes, labels = _select_codes_and_labels(doc_codes, doc_labels, classes_to_visual)
    X = _tsne_embed(codes, 3)

    fig = plt.figure(figsize=(10, 10), facecolor='white')
    ax = fig.add_subplot(111, projection='3d')
    _scatter_classes_3d(ax, X, labels, classes_to_visual, (0, 1, 2), maker_size, opaque)
    plt.savefig(save_file)
    plt.show()


def visualize_pca_2d(doc_codes, doc_labels, classes_to_visual, save_file):
    """Plot the codes on two principal components, one marker per class.

    A 3-component PCA is fitted and the components at index 1 and 2 are
    the ones drawn.

    @param doc_codes: dict doc -> code, or a sequence of codes.
    @param doc_labels: dict doc -> label, or a sequence of labels aligned
        with `doc_codes`.
    @param classes_to_visual: the class labels to draw (duplicates are dropped).
    @param save_file: path of the EPS figure to write.
    """
    plt.rc('legend', **{'fontsize': 28})
    classes_to_visual = list(set(classes_to_visual))
    markers = _cycle_to_length(_MARKERS_2D, len(classes_to_visual))

    codes, labels = _select_codes_and_labels(doc_codes, doc_labels, classes_to_visual)
    X = _pca_embed(codes)
    labels = np.array(labels)

    plt.figure(figsize=(10, 10), facecolor='white')
    x_pc, y_pc = 1, 2
    for marker, c in zip(markers, classes_to_visual):
        _plot_points(X, labels == c, (x_pc, y_pc), marker, c, alpha=1, markersize=10)
    # plt.title('Projected on the PCA components')
    # plt.xlabel('PC %s' % x_pc)
    # plt.ylabel('PC %s' % y_pc)
    plt.legend(loc='upper right', shadow=True)
    # plt.savefig(save_file)
    plt.savefig(save_file, format='eps', dpi=2000)
    plt.show()


def visualize_pca_3d(doc_codes, doc_labels, classes_to_visual, save_file, maker_size=None, opaque=None):
    """Plot the codes on their first three principal components in 3D.

    @param doc_codes: dict doc -> code, or a sequence of codes.
    @param doc_labels: dict doc -> label, or a sequence of labels aligned
        with `doc_codes`.
    @param classes_to_visual: sequence of the class labels to draw.
    @param save_file: path of the figure to write.
    @param maker_size: optional per-class marker sizes (default 20).
    @param opaque: optional per-class alpha values (default 1).
    """
    plt.rc('legend', **{'fontsize': 20})
    codes, labels = _select_codes_and_labels(doc_codes, doc_labels, classes_to_visual)
    X = _pca_embed(codes)

    fig = plt.figure(figsize=(10, 10), facecolor='white')
    ax = fig.add_subplot(111, projection='3d')
    x_pc, y_pc, z_pc = 0, 1, 2
    _scatter_classes_3d(ax, X, labels, classes_to_visual, (x_pc, y_pc, z_pc),
                        maker_size, opaque)
    # plt.title('Projected on the PCA components')
    ax.set_xlabel('%sst component' % (x_pc + 1), fontsize=14)
    ax.set_ylabel('%snd component' % (y_pc + 1), fontsize=14)
    ax.set_zlabel('%srd component' % (z_pc + 1), fontsize=14)
    plt.savefig(save_file)
    plt.show()


def DBN_plot_tsne(doc_codes, doc_labels, classes_to_visual, save_file):
    """Plot a 2D t-SNE projection of the codes with named classes.

    @param doc_codes: sequence of codes.
    @param doc_labels: sequence of labels aligned with `doc_codes`.
    @param classes_to_visual: dict class label -> name shown in the legend.
    @param save_file: path of the figure to write.
    """
    markers = _cycle_to_length(_MARKERS_2D, len(classes_to_visual))
    X = _tsne_embed(doc_codes, 2)
    labels = np.array(doc_labels)

    plt.figure(figsize=(10, 10), facecolor='white')
    for marker, c in zip(markers, classes_to_visual):
        _plot_points(X, labels == c, (0, 1), marker, classes_to_visual[c],
                     alpha=0.6, markersize=6)
    plt.legend(loc='upper center', shadow=True)
    plt.title("tsne")
    plt.savefig(save_file)
    plt.show()


def DBN_visualize_pca_2d(doc_codes, doc_labels, classes_to_visual, save_file):
    """Plot the codes on two principal components with named classes.

    A 3-component PCA is fitted and the components at index 1 and 2 are
    the ones drawn. No legend is drawn.

    @param doc_codes: sequence of codes.
    @param doc_labels: sequence of labels aligned with `doc_codes`.
    @param classes_to_visual: dict class label -> display name.
    @param save_file: path of the figure to write.
    """
    markers = _cycle_to_length(_MARKERS_2D, len(classes_to_visual))
    X = _pca_embed(doc_codes)
    labels = np.array(doc_labels)

    plt.figure(figsize=(10, 10), facecolor='white')
    # NOTE(review): the title says "first 2 PCs" but columns 1 and 2 are the
    # ones plotted -- confirm which of the two is intended.
    x_pc, y_pc = 1, 2
    for marker, c in zip(markers, classes_to_visual):
        _plot_points(X, labels == c, (x_pc, y_pc), marker, classes_to_visual[c],
                     alpha=0.6, markersize=6)
    plt.title('Projected on the first 2 PCs')
    plt.xlabel('PC %s' % x_pc)
    plt.ylabel('PC %s' % y_pc)
    # legend = plt.legend(loc='upper center', shadow=True)
    plt.savefig(save_file)
    plt.show()


def reuters_visualize_tsne(doc_codes, doc_labels, classes_to_visual, save_file):
    """Plot a 2D t-SNE projection of multi-label documents.

    Every document carrying at least one of the requested classes is
    projected; it is drawn once for each requested class it carries.

    @param doc_codes: dict doc -> code.
    @param doc_labels: dict doc -> collection of class labels.
    @param classes_to_visual: dict class label -> name shown in the legend.
    @param save_file: path of the figure to write.
    """
    markers = _cycle_to_length(_MARKERS_2D, len(classes_to_visual))
    wanted = set(classes_to_visual)

    codes, labels = _unzip_pairs([(code, doc_labels[doc])
                                  for doc, code in doc_codes.items()
                                  if wanted.intersection(doc_labels[doc])])
    X = _tsne_embed(codes, 2)

    plt.figure(figsize=(10, 10), facecolor='white')
    for marker, c in zip(markers, classes_to_visual):
        _plot_points(X, get_indices(labels, c), (0, 1), marker, classes_to_visual[c],
                     alpha=0.6, markersize=6)
    plt.legend(loc='upper center', shadow=True)
    plt.title("tsne")
    plt.savefig(save_file)
    plt.show()


def reuters_visualize_pca_2d(doc_codes, doc_labels, classes_to_visual, save_file):
    """Plot multi-label documents on their first two principal components.

    Only the documents carrying exactly one of the requested classes are
    kept, so each plotted document belongs to a single class.

    @param doc_codes: dict doc -> code.
    @param doc_labels: dict doc -> collection of class labels.
    @param classes_to_visual: dict class label -> name shown in the legend.
    @param save_file: path of the figure to write.
    """
    markers = _cycle_to_length(_MARKERS_2D, len(classes_to_visual))
    marker_of = dict(zip(classes_to_visual, markers))
    class_names = set(classes_to_visual)

    pairs = []
    for doc, code in doc_codes.items():
        matched = class_names.intersection(doc_labels[doc])
        if len(matched) == 1:
            pairs.append((code, matched))
    codes, labels = _unzip_pairs(pairs)
    X = _pca_embed(codes)

    plt.figure(figsize=(10, 10), facecolor='white')
    x_pc, y_pc = 0, 1
    for c in class_names:
        _plot_points(X, get_indices(labels, c), (x_pc, y_pc), marker_of[c],
                     classes_to_visual[c], alpha=0.6, markersize=6)
    plt.title('Projected on the first 2 PCs')
    plt.xlabel('PC %s' % x_pc)
    plt.ylabel('PC %s' % y_pc)
    plt.legend(loc='upper center', shadow=True)
    plt.savefig(save_file)
    plt.show()


def get_indices(labels, c):
    """Return a boolean mask telling which entries of `labels` contain `c`.

    An entry that is a list or a set matches when `c` is one of its
    members; any other entry matches when it equals `c`.

    @param labels: sequence of labels (single values, lists or sets).
    @param c: the class to look for.
    @return: boolean numpy array with one element per entry of `labels`.
    """
    return np.array([c in (label if isinstance(label, (list, set)) else [label])
                     for label in labels], dtype=bool)


def _named_curves(precisions):
    """Return `precisions` as a list of (model_name, points) pairs.

    Both a list of pairs and a dict model_name -> points are accepted.
    """
    if isinstance(precisions, dict):
        return list(precisions.items())
    return list(precisions)


def _plot_precision_curves(precisions, markers, alpha, markersize, x_label, y_label, save_file):
    """Draw one precision curve per model on evenly spaced x positions.

    The x tick labels are taken from the points of the first model; every
    curve is drawn against the positions 0..n-1 so the ticks are evenly
    spaced whatever their values. Markers are reused cyclically when there
    are more models than markers.
    """
    curves = _named_curves(precisions)
    ticks = [x for x, _ in curves[0][1]]
    positions = np.arange(len(ticks))
    plt.xticks(positions, ticks)

    for i, (model_name, points) in enumerate(curves):
        _, pr = zip(*points)
        plt.plot(positions, pr, linestyle='-', alpha=alpha,
                 marker=markers[i % len(markers)],
                 markersize=markersize, label=model_name)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend(loc='upper right', shadow=True)
    plt.savefig(save_file)
    plt.show()


def plot_info_retrieval(precisions, save_file):
    """Plot precision against the fraction of retrieved documents.

    @param precisions: list of (model_name, [(fraction, precision), ...])
        pairs, or a dict model_name -> [(fraction, precision), ...].
    @param save_file: path of the figure to write.
    """
    # markers = ["|", "D", "8", "v", "^", ">", "h", "H", "s", "*", "p", "d", "<"]
    markers = ["D", "p", 's', "*", "d", "8", "^", "H", "v", ">", "<", "h", "|"]
    _plot_precision_curves(precisions, markers, alpha=0.7, markersize=8,
                           x_label='Fraction of Retrieved Documents',
                           y_label='Precision', save_file=save_file)


def plot_info_retrieval_by_length(precisions, save_file):
    """Plot precision against document length.

    @param precisions: list of (model_name, [(length, precision), ...])
        pairs, or a dict model_name -> [(length, precision), ...].
    @param save_file: path of the figure to write.
    """
    markers = ["o", "v", "8", "s", "p", "*", "h", "H", "^", "x", "D"]
    _plot_precision_curves(precisions, markers, alpha=0.6, markersize=6,
                           x_label='Document Sorted by Length',
                           y_label='Precision (%)', save_file=save_file)


def plot(x, y, x_label, y_label, save_file):
    """Plot `y` against evenly spaced positions labelled with the values of `x`.

    @param x: values used as x tick labels.
    @param y: values to plot, one per entry of `x`.
    @param x_label: label of the x axis.
    @param y_label: label of the y axis.
    @param save_file: path of the figure to write.
    """
    positions = np.arange(len(x))
    plt.xticks(positions, x, fontsize=15)
    plt.yticks(fontsize=15)
    plt.plot(positions, y, linestyle='-', alpha=1.0, markersize=12, marker='p', color='b')
    plt.xlabel(x_label, fontsize=24)
    plt.ylabel(y_label, fontsize=20)
    plt.savefig(save_file)
    plt.show()


if __name__ == '__main__':
    import sys

    if len(sys.argv) < 2:
        sys.exit('usage: %s <output_figure_file>' % sys.argv[0])

    # 20news_retrieval_128D
    precisions = [
        ('VAE', [
            (0.001, 0.587348525080869), (0.002, 0.5651402500844888),
            (0.005, 0.5327151771489245), (0.01, 0.5014839340348453),
            (0.02, 0.4584269359288251), (0.05, 0.3658133556412997),
            (0.1, 0.2687164883998648), (0.2, 0.17739560251738207),
            (0.5, 0.09136516909151776), (1.0, 0.05050301672031405)]),
        ('DocNADE', [
            (0.001, 0.5718148022980761), (0.002, 0.5435414956790445),
            (0.005, 0.5074230900538642), (0.01, 0.4746133312027964),
            (0.02, 0.43102761550716634), (0.05, 0.3383512940656766),
            (0.1, 0.25088957318799715), (0.2, 0.16893256617330504),
            (0.5, 0.0898931631614369), (1.0, 0.05050301672031405)]),
        ('KATE', [
            (0.001, 0.5543982040264583), (0.002, 0.5213392555399969),
            (0.005, 0.4739445034519384), (0.01, 0.4347574243698827),
            (0.02, 0.3869114198299623), (0.05, 0.30403564261511706),
            (0.1, 0.2277761656366975), (0.2, 0.15699569840064684),
            (0.5, 0.08683891514289452), (1.0, 0.05050301672031405)]),
        ('DBN', [
            (0.001, 0.535038381692656), (0.002, 0.5077608265340592),
            (0.005, 0.465912108337758), (0.01, 0.4264154357337848),
            (0.02, 0.37657322856108594), (0.05, 0.29198182151435376),
            (0.1, 0.2197600288870639), (0.2, 0.15325609847145583),
            (0.5, 0.08605947016611057), (1.0, 0.05050301672031403)]),
        ('LDA', [
            (0.001, 0.4867957321488915), (0.002, 0.46359774054941044),
            (0.005, 0.42999155982095444), (0.01, 0.40179481997753264),
            (0.02, 0.36320959775165296), (0.05, 0.2823678558504475),
            (0.1, 0.20784423242441585), (0.2, 0.14168207983103592),
            (0.5, 0.0800605531419018), (1.0, 0.05050301672031405)]),
        ('Word2Vec_pre', [
            (0.001, 0.4619200502100128), (0.002, 0.4201226283010617),
            (0.005, 0.363601016614824), (0.01, 0.3199258385461033),
            (0.02, 0.27425227583548745), (0.05, 0.2083981501934101),
            (0.1, 0.15890115524777892), (0.2, 0.1170970848576276),
            (0.5, 0.0746015280886044), (1.0, 0.05050301672031405)]),
        ('Word2Vec', [
            (0.001, 0.3815960990682146), (0.002, 0.33990126973397794),
            (0.005, 0.28322016538957506), (0.01, 0.23994026666166862),
            (0.02, 0.1996232005978027), (0.05, 0.15086380704863955),
            (0.1, 0.11886672273161164), (0.2, 0.09250234660438485),
            (0.5, 0.06552700581695815), (1.0, 0.05050301672031405)]),
        ('CAE', [
            (0.001, 0.25605899676530824), (0.002, 0.2178643846859439),
            (0.005, 0.17500331917153453), (0.01, 0.14827356083073284),
            (0.02, 0.12567969583465557), (0.05, 0.1013255537435644),
            (0.1, 0.08658477146491544), (0.2, 0.07424743141317958),
            (0.5, 0.059710587487142405), (1.0, 0.05050301672031405)]),
        ('KSAE', [
            (0.001, 0.23964418481146235), (0.002, 0.20264447448462075),
            (0.005, 0.16342178135194532), (0.01, 0.1395696943777457),
            (0.02, 0.12070563824438621), (0.05, 0.09967618984957147),
            (0.1, 0.08657995851945419), (0.2, 0.07516400405132624),
            (0.5, 0.061121338068411066), (1.0, 0.05050301672031405)]),
        ('AE', [
            (0.001, 0.22827451359049142), (0.002, 0.18935571863080838),
            (0.005, 0.14794495865260598), (0.01, 0.12336861250406352),
            (0.02, 0.10404868431566056), (0.05, 0.08489066120247557),
            (0.1, 0.07413015988839661), (0.2, 0.06568426232571814),
            (0.5, 0.0563362391994616), (1.0, 0.05050301672031405)]),
        ('DAE', [
            (0.001, 0.2095785255636476), (0.002, 0.17031574373581437),
            (0.005, 0.1300285448751998), (0.01, 0.10855864535504914),
            (0.02, 0.09279581161675608), (0.05, 0.07767683840981233),
            (0.1, 0.06946348101328252), (0.2, 0.06284691358720304),
            (0.5, 0.05536974244871757), (1.0, 0.05050301672031405)]),
        ('Doc2Vec', [
            (0.001, 0.16486023270409583), (0.002, 0.1494834162120367),
            (0.005, 0.12679472346559542), (0.01, 0.11052195000447207),
            (0.02, 0.0953665540302444), (0.05, 0.07877845088096704),
            (0.1, 0.06914265711214816), (0.2, 0.06190158066520009),
            (0.5, 0.05536713733618202), (1.0, 0.05050301672031404)]),
        ('NVDM', [
            (0.001, 0.05129628735576586), (0.002, 0.0513143919277744),
            (0.005, 0.0513784045216623), (0.01, 0.05057947447821584),
            (0.02, 0.04999729766565648), (0.05, 0.05015274063700316),
            (0.1, 0.050297158296132634), (0.2, 0.05053768818029844),
            (0.5, 0.050492220758456205), (1.0, 0.05050301672031404)]),
    ]

    # 20news_retrieval_512D
    # precisions = {
    # 'LDA': [(0.001, 0.4058682952734931), (0.002, 0.37058851928739733), (0.005, 0.3309972687959938), (0.01, 0.3016110612419513), (0.02, 0.2693117036925588), (0.05, 0.2161839279252353), (0.1, 0.1702690976502034), (0.2, 0.12488871530981528), (0.5, 0.07553462776603063), (1.0, 0.05050301672031405)],
    # 'DBN': [(0.001, 0.5553034326268522), (0.002, 0.5285147009124683), (0.005, 0.488347337076094), (0.01, 0.45024297510562405), (0.02, 0.4033891972422051), (0.05, 0.31771321417997606), (0.1, 0.23800250085341837), (0.2, 0.1643866217959289), (0.5, 0.08997211449990615), (1.0, 0.050503016720313966)],
    # 'DocNADE': [(0.001, 0.5771737556124188), (0.002, 0.5443682711340714), (0.005, 0.5036226386465347), (0.01, 0.47000408874935945), (0.02, 0.42806973432528334), (0.05, 0.3391154672218619), (0.1, 0.2523257091581672), (0.2, 0.16856842576301645), (0.5, 0.08886419064880063), (1.0, 0.05050301672031405)],
    # 'NVDM': [(0.001, 0.051827354801331486), (0.002, 0.05166441365326164), (0.005, 0.05022143615810757), (0.01, 0.050504279087694), (0.02, 0.04984220717270456), (0.05, 0.050210312107871934), (0.1, 0.05031805352276878), (0.2, 0.050525479733275175), (0.5, 0.05048740951458409), (1.0, 0.05050301672031404)],
    # 'Word2Vec_pre': [(0.001, 0.4619200502100128), (0.002, 0.4201226283010617), (0.005, 0.363601016614824), (0.01, 0.3199258385461033), (0.02, 0.27425227583548745), (0.05, 0.2083981501934101), (0.1, 0.15890115524777892), (0.2, 0.1170970848576276), (0.5, 0.0746015280886044), (1.0, 0.05050301672031405)],
    # 'Word2Vec': [(0.001, 0.3842755757253872), (0.002, 0.34131946120793055), (0.005, 0.28388162885972246), (0.01, 0.24101532576054588), (0.02, 0.20010492106833672), (0.05, 0.1509728403648974), (0.1, 0.118880457234514), (0.2, 0.09250264007666884), (0.5, 0.0655253160142321), (1.0, 0.05050301672031405)],
    # 'Doc2Vec': [(0.001, 0.2199705498961938), (0.002, 0.19429826678896794), (0.005, 0.15887688718610055), (0.01, 0.13127352793274671), (0.02, 0.1067768670780567), (0.05, 0.08213522011101435), (0.1, 0.06901493797404573), (0.2, 0.05975776562880724), (0.5, 0.05340344575184013), (1.0, 0.050503016720314126)],
    # 'AE': [(0.001, 0.25062762516293363), (0.002, 0.2055713802925667), (0.005, 0.15574975343297148), (0.01, 0.12843020222861343), (0.02, 0.10797588107849927), (0.05, 0.08867698410088286), (0.1, 0.07749710871105607), (0.2, 0.06824739056183722), (0.5, 0.057956525318736705), (1.0, 0.05050301672031405)],
    # 'DAE': [(0.001, 0.2441703278134424), (0.002, 0.19592767826968308), (0.005, 0.14740440785979805), (0.01, 0.12183063178228161), (0.02, 0.1024848551783867), (0.05, 0.08429144793424902), (0.1, 0.07442198872784735), (0.2, 0.06632051023795682), (0.5, 0.05714302612312982), (1.0, 0.05050301672031405)],
    # 'CAE': [(0.001, 0.2564210882054672), (0.002, 0.20924660841017487), (0.005, 0.16076881496092835), (0.01, 0.1325765230591466), (0.02, 0.11132618820467256), (0.05, 0.09089712800606133), (0.1, 0.07922084751978385), (0.2, 0.06933241629113954), (0.5, 0.058349474860945535), (1.0, 0.05050301672031405)],
    # 'VAE': [(0.001, 0.2864384685945949), (0.002, 0.22214913339448517), (0.005, 0.15730739321751025), (0.01, 0.12415463932062157), (0.02, 0.10121123325141194), (0.05, 0.08010235972535741), (0.1, 0.06876865603310822), (0.2, 0.06069940080002707), (0.5, 0.05329987492643488), (1.0, 0.05050301672031405)],
    # 'KSAE': [(0.001, 0.2766257905663044), (0.002, 0.23146091826389079), (0.005, 0.18327753964039104), (0.01, 0.15379102261032626), (0.02, 0.13065375342492558), (0.05, 0.10653729926356474), (0.1, 0.09147214149778002), (0.2, 0.07821235936221227), (0.5, 0.0621660351341905), (1.0, 0.05050301672031405)],
    # 'KATE': [(0.001, 0.5370057451841842), (0.002, 0.49623424902234925), (0.005, 0.4398091950534882), (0.01, 0.39517410082761617), (0.02, 0.3440230238886337), (0.05, 0.26452986431933806), (0.1, 0.19919513465212774), (0.2, 0.14045712651660616), (0.5, 0.08193839335997657), (1.0, 0.05050301672031405)]}

    # precisions = {
    # 'DocNADE': [(100, 0.5620457973399164), (120, 0.6721578198088268), (150, 0.6984651711924437), (200, 0.6809496236247824), (300, 0.518887505188875), (1000, 0.3119956966110817), (1500, 0.1818181818181818), (2000, 0.13636363636363635), (4000, 0.03305785123966942)],
    # 'KCAE': [(100, 0.517573929338634), (120, 0.6815131177547284), (150, 0.7079102715466347), (200, 0.7348002316155173), (300, 0.6832710668327107), (1000, 0.6503496503496502), (1500, 0.6969696969696969), (2000, 0.8522727272727273), (4000, 0.42975206611570244)],
    # }

    # plot_info_retrieval_by_length(precisions, sys.argv[1])
    plot_info_retrieval(precisions, sys.argv[1])

    # # Effect of number of topics
    # x = [20, 32, 64, 128, 256, 512, 1024, 1500]
    # y = [0.546, 0.694, 0.719, 0.744, 0.747, 0.761, 0.767, 0.713]
    # plot(x, y, 'Number of topics', 'Classification accuracy', sys.argv[1])

    # Effect of alpha
    # x = [0.0625, 0.3, 1, 3, 6, 9, 12]
    # y = [0.711, 0.706, 0.739, 0.738, 0.743, 0.746, 0.743]
    # plot(x, y, r'$\alpha$', 'Classification accuracy', sys.argv[1])
    #
    # # Effect of k
    # x = [2, 4, 6, 8, 16, 32, 64, 96, 128]
    # y = [0.729, 0.728, 0.720, 0.738, 0.737, 0.744, 0.739, 0.733, 0.714]
    # plot(x, y, r'$k$', 'Classification accuracy', sys.argv[1])

    # # scalability
    # x = [0.5, 1, 1.5, 2]
    # y = [6 , 12.3, 15.8, 49.5]
    # plot(x, y, r'training set size', 'runtime (h)', sys.argv[1])