""" Helpers - Mostly plotting functions =================================== """ from matplotlib import pyplot as plt from rdkit import Chem from rdkit.Chem import AllChem from rdkit.Chem.Draw import rdMolDraw2D from rdkit.Chem import rdDepictor import numpy as np import pandas as pd import seaborn as sns def _prepare_mol(mol, kekulize): """Prepare mol for SVG depiction (embed 2D coords) """ mc = Chem.Mol(mol.ToBinary()) if kekulize: try: Chem.Kekulize(mc) except: mc = Chem.Mol(mol.ToBinary()) if not mc.GetNumConformers(): rdDepictor.Compute2DCoords(mc) return mc def mol_to_svg(mol, molSize=(300, 300), kekulize=True, drawer=None, font_size=0.8, **kwargs): """Generates a SVG from mol structure. Inspired by: http://rdkit.blogspot.ch/2016/02/morgan-fingerprint-bit-statistics.html Parameters ---------- mol : rdkit.Chem.rdchem.Mol molSize : tuple kekulize : bool drawer : funct Specify which drawing function to use (default: rdMolDraw2D.MolDraw2DSVG) font_size : float Atom font size Returns ------- IPython.display.SVG """ from IPython.display import SVG mc = _prepare_mol(mol, kekulize) mol_atoms = [a.GetIdx() for a in mc.GetAtoms()] if drawer is None: drawer = rdMolDraw2D.MolDraw2DSVG(*molSize) drawer.SetFontSize(font_size) drawer.DrawMolecule(mc, highlightAtomRadii={x: 0.5 for x in mol_atoms}, **kwargs) drawer.FinishDrawing() svg = drawer.GetDrawingText() return SVG(svg.replace('svg:', '')) def depict_atoms(mol, atom_ids, radii, molSize=(300, 300), atm_color=(0, 1, 0), oth_color=(0.8, 1, 0)): """Get a depiction of molecular substructure. Useful for depicting bits in fingerprints. Inspired by: http://rdkit.blogspot.ch/2016/02/morgan-fingerprint-bit-statistics.html Parameters ---------- mol : rdkit.Chem.rdchem.Mol atom_ids : list List of atoms to depict radii : list List of radii - how many atoms around each atom with atom_id to highlight molSize : tuple atm_color, oth_color : tuple Colors of central atoms and surrounding atoms and bonds Returns ------- IPython.display.SVG """ atoms_to_use = [] bonds = [] for atom_id, radius in zip(atom_ids, radii): if radius > 0: env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, atom_id) bonds += [x for x in env if x not in bonds] for b in env: atoms_to_use.append(mol.GetBondWithIdx(b).GetBeginAtomIdx()) atoms_to_use.append(mol.GetBondWithIdx(b).GetEndAtomIdx()) atoms_to_use = list(set(atoms_to_use)) else: atoms_to_use.append(atom_id) env = None if sum(radii) == 0: return mol_to_svg(mol, molSize=molSize, highlightBonds=False, highlightAtoms=atoms_to_use, highlightAtomColors={x: atm_color for x in atom_ids}) else: colors = {x: atm_color for x in atom_ids} for x in atoms_to_use: if x not in atom_ids: colors[x] = oth_color bond_colors = {b: oth_color for b in bonds} return mol_to_svg(mol, molSize=molSize, highlightAtoms=atoms_to_use, highlightAtomColors=colors, highlightBonds=bonds, highlightBondColors=bond_colors) def depict_identifier(mol, identifier, radius, useFeatures=False, **kwargs): """Depict an identifier in Morgan fingerprint. Parameters ---------- mol : rdkit.Chem.rdchem.Mol RDKit molecule identifier : int or str Feature identifier from Morgan fingerprint radius : int Radius of Morgan FP useFeatures : bool Use feature-based Morgan FP Returns ------- IPython.display.SVG """ identifier = int(identifier) info = {} AllChem.GetMorganFingerprint(mol, radius, bitInfo=info, useFeatures=useFeatures) if identifier in info.keys(): atoms, radii = zip(*info[identifier]) return depict_atoms(mol, atoms, radii, **kwargs) else: return mol_to_svg(mol, **kwargs) def plot_class_distribution(df, x_col, y_col, c_col, ratio=0.1, n=1, marker='o', alpha=1, x_label='auto', y_label='auto', cmap=plt.cm.viridis, size=(8,8), share_axes=False): """Scatter + histogram plots of x and y, e.g. after t-SNE dimensionality reduction. Colors are wrong in scatter plot if len(class) < 4. Open issue in matplotlib. (See: https://github.com/pandas-dev/pandas/issues/9724) Parameters ---------- df : pd.DataFrame Dataframe with our data {x,y}_col : str Name of a column with {x,y} values c_col : str Name of a column with classes (basis for hue) ratio : float Ratio to determine empty space of limits of x/y-axis marker : str Marker in scatter plot n : float Number of columns of legend alpha : float Alpha for scatter plot x_label : str Label of x-axis, default auto: x_col name y_label : str Label of y-axis, default auto: y_col name cmap : matplotlib.colors.ListedColormap size : tuple """ if y_label is 'auto': y_label = y_col if x_label is 'auto': x_label = x_col f, ((h1, xx), (sc, h2)) = plt.subplots(2,2, squeeze=True, sharex=share_axes, sharey=share_axes, figsize=size, gridspec_kw={'width_ratios': [3, 1], 'height_ratios': [1, 3]}) f.subplots_adjust(hspace=0.1, wspace=0.1) xx.axis('off') ratio_xaxis = (max(df[x_col]) - min(df[x_col])) * ratio ratio_yaxis = (max(df[y_col]) - min(df[y_col])) * ratio x_max = max(df[x_col])+ratio_xaxis x_min = min(df[x_col])-ratio_xaxis y_max = max(df[y_col])+ratio_yaxis y_min = min(df[y_col])-ratio_yaxis h1.set_xlim(x_min, x_max) h1.xaxis.set_visible(False) sc.set_xlim(x_min, x_max) sc.set_xlabel(x_label) h2.set_ylim(y_min, y_max) h2.yaxis.set_visible(False) sc.set_ylim(y_min, y_max) sc.set_ylabel(y_label) c_unique = np.sort(df[c_col].unique(),) h, bins = np.histogram(range(len(cmap.colors)), bins=len(c_unique)) # get equally spaced colors from cmap colors = [cmap.colors[int(x)] for x in bins[1:]] for cl, color in zip(c_unique, colors): if len(df[df[c_col] == cl]) > 1: sns.kdeplot(df[df[c_col] == cl][x_col], ax=h1, c=color, label=cl, legend=False) # hist1 sns.kdeplot(df[df[c_col] == cl][y_col], ax=h2, c=color, vertical=True, label=cl, legend=False) # hist2 handles, labels = h1.get_legend_handles_labels() h1.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=n) sc.scatter(df[df[c_col] == cl][x_col], df[df[c_col] == cl][y_col], c=color, marker=marker, alpha=alpha) return f def plot_2D_vectors(vectors, sumup=True, min_max_x=None, min_max_y=None, cmap=plt.cm.viridis_r, colors=None, vector_labels=None, ax=None): """Plots 2d vectors by adding them in sequence and transposing them. Parameters ---------- vectors : list 2D vectors eg: [[0,1], [3,4]] sumup : bool Show a vector that represents a sum of vectors min_max_{x,y} : tuple min and max of {x,y} axis cmap : plt.cm Default: plt.cm.viridis_r colors : list List of matplotlib colors. Number of colors has to match number of vecors (including sum vector if sumup=True). Default=None selects colors from cmap vector_labels : list Has to match number of vecors (including sum vector if sumup=True) ax : plt.ax Name of axis to plot to Returns ------- plt.figure() """ # Transform the vectors soa = [] # vectors with x,y of start point and x,y of end point for x in vectors: if len(soa) == 0: soa.append([0, 0]+list(x)) else: last = soa[-1] soa.append([last[0]+last[2]]+[last[1]+last[3]]+list(x)) if sumup: soa.append([0, 0]+list(sum(vectors))) X, Y, U, V = zip(*soa) if not ax: f = plt.figure() ax = plt.gca() if not colors and sumup: colors = [[cmap.colors[120]]*(len(soa)-1)][0] + [cmap.colors[-1]] if not colors and not sumup: colors = [[cmap.colors[120]]*(len(soa))][0] if vector_labels: if (len(vector_labels) != len(vectors)) and sumup is False: raise Exception('Number of vectors does not match the number of labels') if (len(vector_labels) != len(vectors) + 1) and sumup is True: raise Exception('Number of vectors does not match the number of labels') for x, y, u, v, c, vl in zip(X, Y, U, V, colors, vector_labels): Q = ax.quiver(x, y, u, v, color=c, angles='xy', scale_units='xy', scale=1) ax.quiverkey(Q, x, y, u, vl, coordinates='data', color=[0, 0, 0, 0], labelpos='N') else: ax.quiver(X, Y, U, V, color=colors, angles='xy', scale_units='xy', scale=1) # set plot limits based on positions of vectors if not min_max_x: min_max_x = min([x[0] + x[2] for x in soa]), max([x[0] for x in soa]) if not min_max_y: min_max_y = min([x[1] + x[3] for x in soa]), max([x[1] for x in soa]) # margins on each side margin_x, margin_y = sum(min_max_x)/10., sum(min_max_y)/10. ax.set_xlim(min_max_x[0]+margin_x, min_max_x[1]-margin_x) ax.set_ylim(min_max_y[0]-margin_y, min_max_y[1]+margin_y) return ax class IdentifierTable(object): def _get_depictions(self): """Depicts an identifier on the first molecules that contains that identifier""" for idx in self.identifiers: for mol, sentence in zip(self.mols, self.sentences): if idx in sentence: self.depictions.append(depict_identifier(mol, idx, self.radius, molSize=self.size).data) break def __init__(self, identifiers, mols, sentences, cols, radius, size=(150, 150)): self.mols = mols self.sentences = sentences self.identifiers = identifiers self.cols = cols self.radius = radius self.depictions = [] self.size = size self._get_depictions() def _repr_html_(self): table = '<table style="width:100%">' c = 1 for depict, idx in zip(self.depictions, self.identifiers): if c == 1: table += '<tr>' table += '<td><div align="center">%s</div>\n<div align="center">%s</div></td>' % (depict, idx) if c == self.cols: table += '</tr>' c = 0 c += 1 table += '</table>' return table