# -*- coding: utf-8 -*- # Copyright (c) 2015-2016 MIT Probabilistic Computing Project # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from math import log import matplotlib.pyplot as plt import numpy as np from cgpm.utils import general as gu from cgpm.utils.config import colors _plot_layout = {1: (1,1), 2: (2,1), 3: (3,1), 4: (2,2), 5: (3,2), 6: (3,2), 7: (4,2), 8: (4,2), 9: (3,3), 10: (5,2), 11: (4,3), 12: (4,3), 13: (5,3), 14: (5,3), 15: (5,3), 16: (4,4), 17: (6,3), 18: (6,3), 19: (5,4), 20: (5,4), 21: (7,3), 22: (6,4), 23: (6,4), 24: (6,4), } def get_state_plot_layout(n_cols): layout = dict() layout['plots_x'] = _plot_layout[n_cols][0] layout['plots_y'] = _plot_layout[n_cols][1] layout['plot_inches_x'] = 13/6. * layout['plots_x'] layout['plot_inches_y'] = 6. * layout['plots_y'] layout['border_color'] = colors() return layout def plot_dist_continuous(X, output, clusters, ax=None, Y=None, hist=True): # Create a new axis? if ax is None: _, ax = plt.subplots() # Set up x axis. x_min = min(X) x_max = max(X) if Y is None: Y = np.linspace(x_min, x_max, 200) # Compute weighted pdfs. pdf = np.zeros((len(clusters), len(Y))) W = [log(clusters[k].N) - log(float(len(X))) for k in clusters] for i, k in enumerate(clusters): pdf[i,:] = np.exp( [W[i] + clusters[k].logpdf(None, {output:y}) for y in Y]) color, alpha = gu.curve_color(i) ax.plot(Y, pdf[i,:], color=color, linewidth=5, alpha=alpha) # Plot the sum of pdfs. ax.plot(Y, np.sum(pdf, axis=0), color='black', linewidth=3) # Plot the samples. if hist: nbins = min([len(X), 50]) ax.hist( X, nbins, normed=True, color='black', alpha=.5, edgecolor='none') else: y_max = ax.get_ylim()[1] for x in X: ax.vlines(x, 0, y_max/10., linewidth=1) # Title. ax.set_title(clusters.values()[0].name()) return ax def plot_dist_discrete(X, output, clusters, ax=None, Y=None, hist=True): # Create a new axis? if ax is None: _, ax = plt.subplots() # Set up x axis. X = np.asarray(X, dtype=int) x_max = max(X) Y = range(int(x_max)+1) X_hist = np.bincount(X) / float(len(X)) ax.bar(Y, X_hist, color='gray', edgecolor='none') # Compute weighted pdfs pdf = np.zeros((len(clusters), len(Y))) W = [log(clusters[k].N) - log(float(len(X))) for k in clusters] for i, k in enumerate(clusters): pdf[i,:] = np.exp( [W[i] + clusters[k].logpdf(None, {output:y}) for y in Y]) color, alpha = gu.curve_color(i) ax.bar(Y, pdf[i,:], color=color, edgecolor='none', alpha=alpha) # Plot the sum of pdfs. ax.bar( Y, np.sum(pdf, axis=0), color='none', edgecolor='black', linewidth=3) ax.set_xlim([0, x_max+1]) # Title. ax.set_title(clusters.values()[0].name()) return ax def plot_clustermap(D, xticklabels=None, yticklabels=None): import seaborn as sns if xticklabels is None: xticklabels = range(D.shape[0]) if yticklabels is None: yticklabels = range(D.shape[1]) zmat = sns.clustermap( D, yticklabels=yticklabels, xticklabels=xticklabels, linewidths=0.2, cmap='BuGn') plt.setp(zmat.ax_heatmap.get_yticklabels(), rotation=0) plt.setp(zmat.ax_heatmap.get_xticklabels(), rotation=90) return zmat def clustermap_ordering(D): zmat = plot_clustermap(D) plt.close(zmat.fig) return zmat.dendrogram_row.reordered_ind def plot_heatmap( D, xordering=None, yordering=None, xticklabels=None, yticklabels=None, vmin=None, vmax=None, ax=None): import seaborn as sns D = np.copy(D) if ax is None: _, ax = plt.subplots() if xticklabels is None: xticklabels = np.arange(D.shape[0]) if yticklabels is None: yticklabels = np.arange(D.shape[1]) if xordering is not None: xticklabels = xticklabels[xordering] D = D[:,xordering] if yordering is not None: yticklabels = yticklabels[yordering] D = D[yordering,:] sns.heatmap( D, yticklabels=yticklabels, xticklabels=xticklabels, linewidths=0.2, cmap='BuGn', ax=ax, vmin=vmin, vmax=vmax) ax.set_xticklabels(xticklabels, rotation=90) ax.set_yticklabels(yticklabels, rotation=0) return ax def plot_samples(X, ax=None): if ax is None: _, ax = plt.subplots() ax.set_ylim([0, 10]) for x in X: ax.vlines(x, 0, 1., linewidth=1) return ax def partition_to_zmatrix(Zv, ordering=None): """Convert a cgpm.crosscat.State view partition Zv into a binary zmatrix.""" # Default ordering of columns by increasing index. if ordering is None: ordering = sorted(Zv) # Converts a column index to its 0-based index in ordering. column_to_index = {col: ordering.index(col) for col in Zv} # block_vectors[i] is a binary vector, with 1 for columns in that view. views = set(Zv.values()) block_vectors = {view: np.zeros(len(Zv)) for view in views} for view in views: cols = [column_to_index[c] for c, v in Zv.iteritems() if v == view] block_vectors[view][cols] = 1 D = np.zeros((len(Zv), len(Zv))) for col in Zv: D[column_to_index[col]] = block_vectors[Zv[col]] return D def partitions_to_zmatrix(Zvs, ordering=None): """Converts a collection cgpm.crosscat.State view partitions Zvs into a real-valued zmatrix, which is the mean of all Zv.""" Ds = [partition_to_zmatrix(Zv, ordering=ordering) for Zv in Zvs] return np.mean(Ds, axis=0) def plot_logscore(logscores, ax=None): assert all(len(l) == len(logscores[0]) for l in logscores) if ax is None: fig, ax = plt.subplots() for logscore in logscores: ax.plot(range(len(logscores[0])), logscore) ax.set_xlabel('Number of Full Gibbs Sweeps') ax.set_ylabel('Log Score') ax.grid() return ax def engine_to_zmatrix_history(engine, ordering=None): num_transitions = len(engine.states[0].diagnostics['column_partition']) Zvs = [[dict(state.diagnostics['column_partition'][i]) for state in engine.states] for i in xrange(num_transitions)] # Find the ordering at the final step. if ordering is None: D = partitions_to_zmatrix(Zvs[-1]) ordering = clustermap_ordering(D) # Return the history of zmatrices. return [partitions_to_zmatrix(Z, ordering=ordering) for Z in Zvs]