#!/usr/bin/env python # coding=utf-8 from __future__ import (print_function, division, absolute_import, unicode_literals) import os import numpy as np import tensorflow as tf import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.colors import hsv_to_rgb from sklearn.metrics import adjusted_mutual_info_score def save_image(filename, image_array): import scipy.misc if image_array.shape[2] == 1: if np.min(image_array) >= 0.: scipy.misc.toimage(image_array[:, :, 0], cmin=0.0, cmax=1.0).save(filename) else: scipy.misc.toimage(image_array[:, :, 0], cmin=-1.0, cmax=1.0).save(filename) else: scipy.misc.toimage(255*image_array).save(filename) def delete_files(folder, recursive=False): for the_file in os.listdir(folder): file_path = os.path.join(folder, the_file) try: if os.path.isfile(file_path): os.unlink(file_path) elif recursive and os.path.isdir(file_path): delete_files(file_path, recursive) os.unlink(file_path) except Exception as e: print(e) def create_directory(directory): if not os.path.exists(directory): os.makedirs(directory) def print_vars(_vars): total_n_vars = 0 for var in _vars: sh = var.get_shape().as_list() total_n_vars += np.prod(sh) print(var.name, sh) print(total_n_vars, 'total variables') return total_n_vars ACTIVATION_FUNCTIONS = { 'sigmoid': tf.nn.sigmoid, 'tanh': tf.nn.tanh, 'relu': tf.nn.relu, 'elu': tf.nn.elu, 'linear': lambda x: x, 'exp': lambda x: tf.exp(x), 'softplus': tf.nn.softplus, 'clip': lambda x: tf.clip_by_value(x, -1., 1.), 'clip_low': lambda x: tf.clip_by_value(x, -1., 1e6) } def parse_activation_function(name_list): return [ACTIVATION_FUNCTIONS[name] for name in name_list] def evaluate_groups_seq(true_groups, predicted, weights): """ Compute the weighted AMI score and corresponding mean confidence for given gammas. :param true_groups: (T, B, 1, W, H, 1) :param predicted: (T, B, K, W, H, 1) :param weights: (T) :return: scores, confidences (B,) """ w_scores, w_confidences = 0., 0. assert true_groups.ndim == predicted.ndim == 6, true_groups.shape for t in range(true_groups.shape[0]): scores, confidences = evaluate_groups(true_groups[t], predicted[t]) w_scores += weights[t] * np.array(scores) w_confidences += weights[t] * np.array(confidences) norm = np.sum(weights) return w_scores/norm, w_confidences/norm def evaluate_groups(true_groups, predicted): """ Compute the AMI score and corresponding mean confidence for given gammas. :param true_groups: (B, 1, W, H, 1) :param predicted: (B, K, W, H, 1) :return: scores, confidences (B,) """ scores, confidences = [], [] assert true_groups.ndim == predicted.ndim == 5, true_groups.shape batch_size, K = predicted.shape[:2] true_groups = true_groups.reshape(batch_size, -1) predicted = predicted.reshape(batch_size, K, -1) predicted_groups = predicted.argmax(1) predicted_conf = predicted.max(1) for i in range(batch_size): true_group = true_groups[i] idxs = np.where(true_group != 0.0)[0] scores.append(adjusted_mutual_info_score(true_group[idxs], predicted_groups[i, idxs])) confidences.append(np.mean(predicted_conf[i, idxs])) return scores, confidences def tf_adjusted_rand_index(groups, gammas, iter_weights): """ Inputs: groups: shape=(T, B, 1, W, H, 1) These are the masks as stored in the hdf5 files gammas: shape=(T, B, K, W, H, 1) These are the gammas as predicted by the network """ with tf.name_scope('ARI'): # ignore first iteration groups = groups[1:] gammas = gammas[1:] # reshape gammas and convert to one-hot yshape = tf.shape(gammas) gammas = tf.reshape(gammas, shape=tf.stack([yshape[0] * yshape[1], yshape[2], yshape[3] * yshape[4] * yshape[5]])) Y = tf.one_hot(tf.argmax(gammas, axis=1), depth=yshape[2], axis=1) # reshape masks gshape = tf.shape(groups) groups = tf.reshape(groups, shape=tf.stack([gshape[0] * gshape[1], 1, gshape[3] * gshape[4] * gshape[5]])) G = tf.one_hot(tf.cast(groups[:, 0], tf.int32), depth=tf.cast(tf.reduce_max(groups) + 1, tf.int32), axis=1) # now Y and G both have dim (B*T, K, N) where N=W*H*C # mask entries with group 0 M = tf.cast(tf.greater(groups, 0.5), tf.float32) n = tf.cast(tf.reduce_sum(M, axis=[1, 2]), tf.float32) DM = G * M YM = Y * M # contingency table for overlap between G and Y nij = tf.einsum('bij,bkj->bki', YM, DM) a = tf.reduce_sum(nij, axis=1) b = tf.reduce_sum(nij, axis=2) # rand index rindex = tf.cast(tf.reduce_sum(nij * (nij-1), axis=[1, 2]), tf.float32) aindex = tf.cast(tf.reduce_sum(a * (a-1), axis=1), tf.float32) bindex = tf.cast(tf.reduce_sum(b * (b-1), axis=1), tf.float32) expected_rindex = aindex * bindex / (n*(n-1) + 1e-6) max_rindex = (aindex + bindex) / 2 ARI = (rindex - expected_rindex)/tf.clip_by_value(max_rindex - expected_rindex, 1e-6, 1e6) ARI = tf.reshape(ARI, shape=(yshape[0], yshape[1])) iter_weigths= tf.constant(np.array(iter_weights)[:, None], dtype=tf.float32) sum_iter_weights = tf.constant(np.sum(iter_weights), dtype=tf.float32) seq_ARI = tf.reduce_mean(tf.reduce_sum(ARI * iter_weigths, axis=0) / sum_iter_weights) last_ARI = tf.reduce_mean(ARI[-1]) confidences = tf.reduce_sum(tf.reduce_max(gammas, axis=1, keep_dims=True) * M, axis=[1, 2]) / n confidences = tf.reshape(confidences, shape=(yshape[0], yshape[1])) seq_conf = tf.reduce_mean(tf.reduce_sum(confidences * iter_weigths, axis=0) / sum_iter_weights) last_conf = tf.reduce_mean(confidences[-1]) return seq_ARI, last_ARI, seq_conf, last_conf def color_spines(ax, color, lw=2): for sn in ['top', 'bottom', 'left', 'right']: ax.spines[sn].set_linewidth(lw) ax.spines[sn].set_color(color) ax.spines[sn].set_visible(True) def color_half_spines(ax, color1, color2, lw=2): for sn in ['top', 'left']: ax.spines[sn].set_linewidth(lw) ax.spines[sn].set_color(color1) ax.spines[sn].set_visible(True) for sn in ['bottom', 'right']: ax.spines[sn].set_linewidth(lw) ax.spines[sn].set_color(color2) ax.spines[sn].set_visible(True) def get_gamma_colors(nr_colors): hsv_colors = np.ones((nr_colors, 3)) hsv_colors[:, 0] = (np.linspace(0, 1, nr_colors, endpoint=False) + 2/3) % 1.0 color_conv = hsv_to_rgb(hsv_colors) return color_conv def overview_plot(i, gammas, preds, inputs, corrupted=None, **kwargs): attentions = np.array(kwargs['attentions']) if 'attentions' in kwargs else None T, B, K, W, H, C = gammas.shape T -= 1 # the initialization doesn't count as iteration corrupted = corrupted if corrupted is not None else inputs gamma_colors = get_gamma_colors(K) # restrict to sample i and get rid of useless dims inputs = inputs[:, i, 0] gammas = gammas[:, i, :, :, :, 0] if preds.shape[1] != B: preds = preds[:, 0] preds = preds[:, i] corrupted = corrupted[:, i, 0] inputs = np.clip(inputs, 0., 1.) preds = np.clip(preds, 0., 1.) corrupted = np.clip(corrupted, 0., 1.) def plot_img(ax, data, cmap='Greys_r', xlabel=None, ylabel=None, border_color=None): if data.shape[-1] == 1: ax.matshow(data[:, :, 0], cmap=cmap, vmin=0., vmax=1., interpolation='nearest') else: ax.imshow(data, interpolation='nearest') ax.set_xticks([]); ax.set_yticks([]) ax.set_xlabel(xlabel, color=border_color or 'k') if xlabel else None ax.set_ylabel(ylabel, color=border_color or 'k') if ylabel else None if border_color: color_spines(ax, color=border_color) def plot_attention_summary_img(ax, attention, k_excluded, preds, cmap='Greys_r', xlabel=None, ylabel=None, border_color=None): # copy so we don't mutate preds = np.copy(preds) attention = np.copy(attention) # get focus object as rgb version of black-white focus_pred = np.tile(np.copy(preds[k_excluded]), [1, 1, 3]) # we are safe to do what we want to do to the k_excluded row preds[k_excluded] = 0 attention = np.insert(attention, k_excluded, 0) # zero out the focus object # mask preds by attention preds = np.transpose(preds[:, :, :, 0], [1, 2, 0]) # (28, 28, K) preds *= attention # color the preds preds = preds.reshape(-1, preds.shape[-1]).dot(gamma_colors).reshape(preds.shape[:-1] + (3,)) # add in the focus object preds += focus_pred # plot ax.imshow(preds, interpolation='nearest') ax.set_xticks([]); ax.set_yticks([]) ax.set_xlabel(xlabel, color=border_color or 'k') if xlabel else None ax.set_ylabel(ylabel, color=border_color or 'k') if ylabel else None if border_color: color_spines(ax, color=border_color) def plot_gamma(ax, gamma, xlabel=None, ylabel=None): gamma = np.transpose(gamma, [1, 2, 0]) gamma = gamma.reshape(-1, gamma.shape[-1]).dot(gamma_colors).reshape(gamma.shape[:-1] + (3,)) ax.imshow(gamma, interpolation='nearest') ax.set_xticks([]) ax.set_yticks([]) ax.set_xlabel(xlabel) if xlabel else None ax.set_ylabel(ylabel) if ylabel else None nrows, ncols = (K + 4, T + 1) fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(2 * ncols, 2 * nrows)) axes[0, 0].set_visible(False) axes[1, 0].set_visible(False) plot_gamma(axes[2, 0], gammas[0], ylabel='Gammas') for k in range(K + 1): axes[k + 3, 0].set_visible(False) for t in range(1, T + 1): g = gammas[t] p = preds[t] reconst = np.sum(g[:, :, :, None] * p, axis=0) plot_img(axes[0, t], inputs[t]) plot_img(axes[1, t], reconst) plot_gamma(axes[2, t], g) for k in range(K): if attentions is not None: plot_attention_summary_img(axes[k + 3, t], attentions[t-1, i, k], k, p, border_color=tuple(gamma_colors[k]), ylabel=('contexts {} for {}'.format(k-1, k) if t == 1 else None)) else: plot_img(axes[k + 3, t], p[k], border_color=tuple(gamma_colors[k]), ylabel=('mu_{}'.format(k) if t == 1 else None)) plot_img(axes[K + 3, t], corrupted[t - 1]) plt.subplots_adjust(hspace=0.1, wspace=0.1) return fig def curve_plot(values_dict, coarse_range, fine_range): if fine_range is not None: fig, ax = plt.subplots(1, 2, figsize=(40, 10)) else: fig, ax = plt.subplots(1, 1, figsize=(20, 10)) ax = [ax] for key, values in values_dict.items(): # coarse ax[0].plot(values, label=key) ax[0].set_xlabel('epochs') ax[0].axis([0, len(values), coarse_range[0], coarse_range[1]]) ax[0].set_title("coarse range") ax[0].legend() # fine if fine_range is not None: ax[1].plot(values, label=key) ax[1].set_xlabel('epochs') ax[1].axis([0, len(values), fine_range[0], fine_range[1]]) ax[1].set_title("fine range") ax[1].legend() return fig