#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training / inference driver that wraps a VAE computational graph.

Created on Tue Sep 11 10:24:14 2018

@author: psanch
"""
from base.base_model import BaseModel
import tensorflow as tf
import numpy as np

from VAE_graph import VAEGraph
from VAECNN_graph import VAECNNGraph

from utils.logger import Logger
from utils.early_stopping import EarlyStopping
from tqdm import tqdm
import sys

import utils.utils as utils
import utils.constants as const


class VAEModel(BaseModel):
    """Build a VAE graph and run training, validation and inference on it.

    The graph object (``VAEGraph`` or ``VAECNNGraph``) is created inside a
    private ``tf.Graph``; every public method opens its own ``tf.Session``
    on that graph.  ``self.load`` / ``self.save`` are not defined here and
    are presumably provided by ``BaseModel`` -- confirm against that class.
    """

    def __init__(self, network_params, sigma=0.001, sigma_act=tf.nn.softplus,
                 transfer_fct=tf.nn.relu, learning_rate=0.002,
                 kinit=tf.contrib.layers.xavier_initializer(), batch_size=32,
                 drop_rate=0., epochs=200, checkpoint_dir='',
                 summary_dir='', result_dir='', restore=0, model_type=0):
        """Store the run configuration and build the computational graph.

        Args:
            network_params, sigma, sigma_act, transfer_fct, learning_rate,
            kinit: forwarded unchanged to the graph constructor.
            batch_size: forwarded to the graph constructor and used for
                every batch drawn from the data objects.
            drop_rate: passed to ``partial_fit`` on every training step.
            epochs: upper bound of the epoch loop in :meth:`train`.
            checkpoint_dir, summary_dir, result_dir: forwarded to
                ``BaseModel.__init__``; ``result_dir`` also prefixes the
                latent-matrix dump file.
            restore: when equal to 1, :meth:`train` first tries to load an
                existing checkpoint.
            model_type: ``const.VAE`` or ``const.VAECNN``.

        Raises:
            ValueError: if ``model_type`` matches neither constant.
        """
        super().__init__(checkpoint_dir, summary_dir, result_dir)

        self.batch_size = batch_size
        self.drop_rate = drop_rate
        self.epochs = epochs
        # NOTE(review): with the default result_dir='' this resolves to
        # '/z_file' (filesystem root) -- confirm callers always pass a dir.
        self.z_file = result_dir + '/z_file'
        self.restore = restore

        # Resolve the graph class up front so that an unsupported
        # model_type fails with a clear message instead of an
        # AttributeError on the missing self.vae_graph further down.
        if model_type == const.VAE:
            graph_cls = VAEGraph
        elif model_type == const.VAECNN:
            graph_cls = VAECNNGraph
        else:
            raise ValueError(
                'Unknown model_type {!r}; expected const.VAE or '
                'const.VAECNN'.format(model_type))

        # Creating computational graph for train and test
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.vae_graph = graph_cls(network_params, sigma, sigma_act,
                                       transfer_fct, learning_rate, kinit,
                                       batch_size, reuse=False)
            self.vae_graph.build_graph()
            # Total number of scalar parameters across trainable variables.
            self.trainable_count = np.sum(
                [np.prod(v.get_shape().as_list())
                 for v in tf.trainable_variables()])

    def _restore_trained_weights(self, session, saver):
        """Try to load a checkpoint; report the stored epoch on success.

        Returns the truthiness of ``self.load(session, saver)``.
        """
        if not self.load(session, saver):
            return False
        num_epochs_trained = self.vae_graph.cur_epoch_tensor.eval(session)
        print('EPOCHS trained: ', num_epochs_trained)
        return True

    def _checkpoint_with_latents(self, session, saver, data_valid):
        """Save a checkpoint and dump the z-matrix of a random valid batch."""
        self.save(session, saver,
                  self.vae_graph.global_step_tensor.eval(session))
        z_matrix = self.vae_graph.get_z_matrix(
            session, data_valid.random_batch(self.batch_size))
        np.savez(self.z_file, z_matrix)

    def _reconstruct_random_batch(self, data):
        """Restore the model and reconstruct one random labelled batch.

        Returns ``(x_batch, x_labels, x_recons, z_recons)``, or ``None``
        when no checkpoint could be loaded.
        """
        with tf.Session(graph=self.graph) as session:
            saver = tf.train.Saver()
            if not self._restore_trained_weights(session, saver):
                return

            x_batch, x_labels = data.random_batch_with_labels(self.batch_size)
            x_recons, z_recons = self.vae_graph.reconstruct_input(
                session, x_batch, beta=1)
            return x_batch, x_labels, x_recons, z_recons

    def train_epoch(self, session, logger, data_train, beta=1):
        """Run one training epoch and log the averaged losses.

        Calls ``partial_fit`` once per batch, for
        ``data_train.num_batches(self.batch_size)`` batches.

        Returns:
            Tuple ``(loss, recons, cond_prior, L2_loss)`` holding the mean
            of each value returned by ``partial_fit`` over the epoch.
        """
        loop = tqdm(range(data_train.num_batches(self.batch_size)))
        losses = []
        recons = []
        cond_prior = []
        L2_loss = []
        for _ in loop:
            batch_x = next(data_train.next_batch(self.batch_size))
            loss, recon, cond, L2_loss_curr = self.vae_graph.partial_fit(
                session, batch_x, beta, self.drop_rate)
            losses.append(loss)
            recons.append(recon)
            cond_prior.append(cond)
            L2_loss.append(L2_loss_curr)

        loss_tr = np.mean(losses)
        recons_tr = np.mean(recons)
        cond_prior_tr = np.mean(cond_prior)
        L2_loss = np.mean(L2_loss)

        cur_it = self.vae_graph.global_step_tensor.eval(session)
        summaries_dict = {
            'loss': loss_tr,
            'recons_loss': recons_tr,
            'KL_loss': cond_prior_tr,
            'L2_loss': L2_loss,
        }
        logger.summarize(cur_it, summaries_dict=summaries_dict)

        return loss_tr, recons_tr, cond_prior_tr, L2_loss

    def valid_epoch(self, session, logger, data_valid, beta=1):
        """Evaluate one pass over the validation data and log the means.

        Returns:
            Tuple ``(loss, recons, cond_prior)`` of per-epoch means; the
            fourth value returned by ``evaluate`` is discarded.
        """
        # COMPUTE VALID LOSS
        loop = tqdm(range(data_valid.num_batches(self.batch_size)))
        losses_val = []
        recons_val = []
        cond_prior_val = []
        for _ in loop:
            batch_x = next(data_valid.next_batch(self.batch_size))
            loss, recon, cond, _ = self.vae_graph.evaluate(
                session, batch_x, beta)
            losses_val.append(loss)
            recons_val.append(recon)
            cond_prior_val.append(cond)

        loss_val = np.mean(losses_val)
        recons_val = np.mean(recons_val)
        cond_prior_val = np.mean(cond_prior_val)

        cur_it = self.vae_graph.global_step_tensor.eval(session)
        summaries_dict = {
            'loss': loss_val,
            'recons_loss': recons_val,
            'KL_loss': cond_prior_val,
        }
        # Validation metrics go to the "test" summarizer of the logger.
        logger.summarize(cur_it, summarizer="test",
                         summaries_dict=summaries_dict)

        return loss_val, recons_val, cond_prior_val

    def train(self, data_train, data_valid, enable_es=1):
        """Train the model, checkpointing periodically and at the end.

        Args:
            data_train: data object used by :meth:`train_epoch`.
            data_valid: data object used by :meth:`valid_epoch` and for the
                latent-matrix dumps.
            enable_es: when equal to 1, stop as soon as
                ``EarlyStopping.stop(loss_val)`` is truthy.

        Terminates the process via ``sys.exit()`` if the training loss
        becomes NaN.
        """
        with tf.Session(graph=self.graph) as session:
            # NOTE(review): the graph was already built in __init__, so this
            # graph-level seed may not affect ops created there -- confirm
            # whether seeding should happen before build_graph().
            tf.set_random_seed(1234)

            logger = Logger(session, self.summary_dir)
            # here you initialize the tensorflow saver that will be used in
            # saving the checkpoints.
            # max_to_keep: defaults to keeping the 5 most recent checkpoints
            # of your model
            saver = tf.train.Saver()
            early_stopping = EarlyStopping()

            restored = (self.restore == 1
                        and self._restore_trained_weights(session, saver))
            if not restored:
                print('Initizalizing Variables ...')
                tf.global_variables_initializer().run()

            start_epoch = self.vae_graph.cur_epoch_tensor.eval(session)
            if start_epoch == self.epochs:
                return

            # The upper bound is inclusive: epochs run from start_epoch up
            # to and including self.epochs.
            for cur_epoch in range(start_epoch, self.epochs + 1, 1):
                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch

                # beta=utils.sigmoid(cur_epoch- 50)
                beta = 1.

                loss_tr, recons_tr, cond_prior_tr, L2_loss = self.train_epoch(
                    session, logger, data_train, beta=beta)
                if np.isnan(loss_tr):
                    print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                    print('Recons: ', recons_tr)
                    print('KL: ', cond_prior_tr)
                    # NOTE(review): exits with status 0 even though training
                    # failed -- confirm whether a non-zero code is wanted.
                    sys.exit()

                loss_val, recons_val, cond_prior_val = self.valid_epoch(
                    session, logger, data_valid, beta=beta)

                print('TRAIN | Loss: ', loss_tr, ' | Recons: ', recons_tr,
                      ' | KL: ', cond_prior_tr, ' | L2_loss: ', L2_loss)
                print('VALID | Loss: ', loss_val, ' | Recons: ', recons_val,
                      ' | KL: ', cond_prior_val)

                # Periodic checkpoint every 10th epoch (epoch 0 excluded).
                if cur_epoch > 0 and cur_epoch % 10 == 0:
                    self._checkpoint_with_latents(session, saver, data_valid)

                session.run(self.vae_graph.increment_cur_epoch_tensor)

                # Early stopping
                if enable_es == 1 and early_stopping.stop(loss_val):
                    print('Early Stopping!')
                    break

            # Final checkpoint, reached on normal completion and on early
            # stopping alike.
            self._checkpoint_with_latents(session, saver, data_valid)

        return

    def generate_samples(self, data):
        """Restore the model and generate samples from one random batch.

        Returns ``(x_samples, z_samples)``, or ``None`` when no checkpoint
        could be loaded.
        """
        with tf.Session(graph=self.graph) as session:
            saver = tf.train.Saver()
            if not self._restore_trained_weights(session, saver):
                return

            x_batch = data.random_batch(self.batch_size)
            x_samples, z_samples = self.vae_graph.generate_samples(
                session, x_batch, beta=1)
            return x_samples, z_samples

    def reconstruct_input(self, data):
        """Restore the model and reconstruct one random labelled batch.

        Returns ``(x_batch, x_labels, x_recons, z_recons)``, or ``None``
        when no checkpoint could be loaded.
        """
        return self._reconstruct_random_batch(data)

    def generate_embedding(self, data):
        """Restore the model and embed one random labelled batch.

        Performs exactly the same steps as :meth:`reconstruct_input` and
        returns the same ``(x_batch, x_labels, x_recons, z_recons)`` tuple,
        or ``None`` when no checkpoint could be loaded.
        """
        return self._reconstruct_random_batch(data)

    # --------------------------------------------------------------------
    #                              DISTRIBUTIONS
    # --------------------------------------------------------------------

    @staticmethod
    def print_parameters():
        """Print an empty line (placeholder, no parameters are printed)."""
        print('')


# NOTE(review): also exposed at module level so that both
# ``print_parameters()`` and ``VAEModel.print_parameters()`` are callable --
# confirm which form callers rely on and drop the other.
print_parameters = VAEModel.print_parameters