""" Used for creating a graph of attention over a fixed number of logits over a sequence. E.g., attention over an input sequence while generating an output sequence. """ import matplotlib.pyplot as plt import numpy as np class AttentionGraph(): """Creates a graph showing attention distributions for inputs and outputs. Attributes: keys (list of str): keys over which attention is done during generation. generated_values (list of str): keeps track of the generated values. attentions (list of list of float): keeps track of the probability distributions. """ def __init__(self, keys): """ Initializes the attention graph. Args: keys (list of string): a list of keys over which attention is done during generation. """ if not keys: raise ValueError("Expected nonempty keys for attention graph.") self.keys = keys self.generated_values = [] self.attentions = [] def add_attention(self, gen_value, probabilities): """ Adds attention scores for all item in `self.keys`. Args: gen_value (string): a generated value for this timestep. probabilities (np.array): probability distribution over the keys. Assumes the order of probabilities corresponds to the order of the keys. Raises: ValueError if `len(probabilities)` is not the same as `len(self.keys)` ValueError if `sum(probabilities)` is not 1 """ if len(probabilities) != len(self.keys): raise ValueError("Length of attention keys is " + str(len(self.keys)) + " but got probabilities of length " + str(len(probabilities))) self.generated_values.append(gen_value) self.attentions.append(probabilities) def render(self, filename): """ Renders the attention graph over timesteps. Args: filename (string): filename to save the figure to. """ figure, axes = plt.subplots() graph = np.stack(self.attentions) axes.imshow(graph, cmap=plt.cm.Blues, interpolation="nearest") axes.xaxis.tick_top() axes.set_xticks(range(len(self.keys))) axes.set_xticklabels(self.keys) plt.setp(axes.get_xticklabels(), rotation=90) axes.set_yticks(range(len(self.generated_values))) axes.set_yticklabels(self.generated_values) axes.set_aspect(1, adjustable='box') plt.tick_params(axis='x', which='both', bottom='off', top='off') plt.tick_params(axis='y', which='both', left='off', right='off') figure.savefig(filename) def render_as_latex(self, filename): """Renders the attention graph as a LaTeX plot Input: filename (str): Name of the file to write to. """ ofile = open(filename, "w") ofile.write( "\\documentclass{article}\\usepackage[margin=0.5in]{geometry}\\usepackage{tikz}" \ + "\\begin{document}\\begin{tikzpicture}[scale=0.25]\\begin{tiny}\\begin{scope}<+->;\n") xstart = 0 ystart = 0 xend = len(self.keys) yend = len(self.generated_values) ofile.write( "\\draw[step=1cm,gray,very thin] (" + str(xstart) + "," + str(ystart) + ") grid (" + str(xend) + ", " + str(yend) + ");\n") for i, tok in enumerate(self.keys): tok = tok.replace("_", r"\_") tok = tok.replace("#", r"\#") ofile.write( "\\draw[gray, xshift=" + str(i) + ".5 cm] (0,0.3) -- (0,0) node[below,rotate=90,anchor=east] {" + tok + "};\n") for i, tok in enumerate(self.generated_values[::-1]): tok = tok.replace("_", r"\_") tok = tok.replace("#", r"\#") ofile.write( "\\draw[gray, yshift=" + str(i) + ".5 cm] (0.3,0) -- (0,0) node[left] {" + tok + "};\n") for i, gentok_atts in enumerate(self.attentions[::-1]): for j, val in enumerate(gentok_atts): if val < 0.001: val = 0 ofile.write("\\filldraw[thin,red,opacity=" + "%.2f" % val + "] (" + str(j) + ", " + str(i) + ") rectangle (" + str(j + 1) + "," + str(i + 1) + ");\n") ofile.write("\\end{scope}\\end{tiny}\\end{tikzpicture}{\\end{document}")