# -*- coding: utf-8 -*-
# MIT License
#
# Copyright (c) 2018 ZhicongYan
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================

import os
import sys
import numpy as np

sys.path.append('.')
sys.path.append('../')

import tensorflow as tf
import tensorflow.contrib.layers as tcl
from tensorflow.contrib.tensorboard.plugins import projector

from .base_validator import BaseValidator


class TensorboardEmbedding(BaseValidator):
    """Plot the model output or intermediate features into the tensorboard
    embedding panel.

    On every validation step the validator stores ``nb_samples`` flattened
    training images followed by ``nb_samples`` flattened generated images in a
    single variable named ``test`` and checkpoints that variable into the log
    directory together with a projector config and a ``metadata.tsv`` that
    labels real rows ``0`` and generated rows ``1``.

    Config keys read by the constructor:
        'assets dir' : root directory for outputs (required)
        'log dir'    : sub-directory of the assets dir (default 'embedding')
        'z shape'    : shape of one latent vector (required)
        'x shape'    : shape of one sample (required)
        'nb samples' : requested number of samples per class (default 1000),
                       rounded down to a multiple of the batch size
        'batch_size' : generator batch size (default 100)
    """

    def __init__(self, config):
        """Create the log directory, metadata, projector config and the
        graph nodes used to store the embedding.

        Args:
            config: validator configuration dict, see the class docstring.
        """
        super(TensorboardEmbedding, self).__init__(config)

        self.assets_dir = self.config['assets dir']
        self.log_dir = self.config.get('log dir', 'embedding')
        self.log_dir = os.path.join(self.assets_dir, self.log_dir)

        self.z_shape = list(self.config['z shape'])
        self.x_shape = list(self.config['x shape'])

        self.nb_samples = self.config.get('nb samples', 1000)
        # NOTE(review): this key uses an underscore while every other key uses
        # a space; kept as-is because existing configs may depend on it.
        self.batch_size = self.config.get('batch_size', 100)
        # Round down so the generated half is made of whole batches.
        self.nb_samples = self.nb_samples // self.batch_size * self.batch_size

        # makedirs (not mkdir) so a missing assets directory is created too.
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # Rows [0, nb_samples) are real samples, rows [nb_samples, 2*nb_samples)
        # are generated ones; this order must match the one used in validate().
        with open(os.path.join(self.log_dir, 'metadata.tsv'), 'w') as f:
            f.write("Index\tLabel\n")
            for i in range(self.nb_samples):
                f.write("%d\t%d\n" % (i, 0))
            for i in range(self.nb_samples):
                f.write("%d\t%d\n" % (i + self.nb_samples, 1))

        summary_writer = tf.summary.FileWriter(self.log_dir)
        projector_config = projector.ProjectorConfig()
        embedding = projector_config.embeddings.add()
        embedding.tensor_name = "test"
        embedding.metadata_path = "metadata.tsv"
        projector.visualize_embeddings(summary_writer, projector_config)
        # The writer is only needed to tell the projector where the log
        # directory is; close it so its event file handle is not leaked.
        summary_writer.close()

        embedding_shape = [self.nb_samples * 2, int(np.prod(self.x_shape))]
        self.plot_array_var = tf.get_variable('test', shape=embedding_shape)
        # Build the assignment once and feed the data through a placeholder.
        # Calling `var.assign(ndarray)` on every validation step would add a
        # new op and a full-size constant to the graph each time.
        self.plot_array_ph = tf.placeholder(
            self.plot_array_var.dtype.base_dtype,
            shape=embedding_shape,
            name='test_value')
        self.plot_array_assign = tf.assign(self.plot_array_var, self.plot_array_ph)
        self.saver = tf.train.Saver([self.plot_array_var])

    def _collect_real_samples(self, dataset):
        """Return exactly ``nb_samples`` flattened training samples as a list."""
        indices = dataset.get_image_indices(phase='train', method='unsupervised')
        indices = np.random.choice(indices, size=self.nb_samples)

        samples = []
        for ind in indices:
            if len(samples) >= self.nb_samples:
                break
            test_x = dataset.read_image_by_index(
                ind, phase='train', method='unsupervised')
            if test_x is None:
                continue
            # One index may yield either a single array or a list of arrays.
            if not isinstance(test_x, list):
                test_x = [test_x]
            for x in test_x:
                samples.append(x.reshape([-1]))

        # A list result can overshoot the requested amount.
        samples = samples[:self.nb_samples]
        if len(samples) < self.nb_samples:
            raise ValueError(
                "TensorboardEmbedding: only %d of %d real samples could be read"
                % (len(samples), self.nb_samples))
        return samples

    def _collect_generated_samples(self, model, sess):
        """Return ``nb_samples`` flattened samples generated by ``model``."""
        samples = []
        for _ in range(self.nb_samples // self.batch_size):
            # One latent vector per generated sample of this batch.
            batch_z = np.random.randn(*([self.batch_size] + self.z_shape))
            batch_x = model.generate(sess, batch_z)
            for k in range(self.batch_size):
                samples.append(batch_x[k].reshape([-1]))
        return samples

    def validate(self, model, dataset, sess, step):
        """Store real and generated samples in the embedding variable and
        checkpoint it.

        Args:
            model: object exposing ``generate(sess, batch_z)``.
            dataset: object exposing ``get_image_indices`` and
                ``read_image_by_index``.
            sess: session used to run the assignment and the saver.
            step: global step attached to the checkpoint file name.

        Raises:
            ValueError: if fewer than ``nb_samples`` real samples can be read.
        """
        plot_array_list = self._collect_real_samples(dataset)
        plot_array_list.extend(self._collect_generated_samples(model, sess))
        plot_array = np.array(plot_array_list)

        sess.run(self.plot_array_assign,
                 feed_dict={self.plot_array_ph: plot_array})
        self.saver.save(sess,
                        os.path.join(self.log_dir, 'model.ckpt'),
                        global_step=step,
                        write_meta_graph=False,
                        strip_default_attrs=True)