from __future__ import print_function import tensorflow as tf import os from os import listdir from os.path import isfile, join from skimage import io import shutil import sys import math import time import json import logging import numpy as np from PIL import Image from datetime import datetime from tensorflow.core.framework import summary_pb2 import matplotlib.pyplot as plt def make_summary(name, val): return summary_pb2.Summary(value=[summary_pb2.Summary.Value(tag=name, simple_value=val)]) def summary_losses(sess,model,N=1000): step,loss_g,loss_d=sess.run([model.step,model.loss_g,model.loss_d],{model.data.N:N,model.gen.N:N}) lgsum=make_summary(model.data.name+'_gloss',loss_g) ldsum=make_summary(model.data.name+'_dloss',loss_d) return step,lgsum, ldsum def calc_tvd(sess,Generator,Data,N=50000,nbins=10): Xd=sess.run(Data.X,{Data.N:N}) step,Xg=sess.run([Generator.step,Generator.X],{Generator.N:N}) p_gen,_ = np.histogramdd(Xg,bins=nbins,range=[[0,1],[0,1],[0,1]],normed=True) p_dat,_ = np.histogramdd(Xd,bins=nbins,range=[[0,1],[0,1],[0,1]],normed=True) p_gen/=nbins**3 p_dat/=nbins**3 tvd=0.5*np.sum(np.abs( p_gen-p_dat )) mvd=np.max(np.abs( p_gen-p_dat )) return step,tvd, mvd s_tvd=make_summary(Data.name+'_tvd',tvd) s_mvd=make_summary(Data.name+'_mvd',mvd) return step,s_tvd,s_mvd #return make_summary('tvd/'+Generator.name,tvd) def summary_stats(name,tensor,hist=False): ave=tf.reduce_mean(tensor) std=tf.sqrt(tf.reduce_mean(tf.square(ave-tensor))) tf.summary.scalar(name+'_ave',ave) tf.summary.scalar(name+'_std',std) if hist: tf.summary.histogram(name+'_hist',tensor) def summary_scatterplots(X1,X2,X3): with tf.name_scope('scatter'): img1=summary_scatter2d(X1,X2,'X1X2',xlabel='X1',ylabel='X2') img2=summary_scatter2d(X1,X3,'X1X3',xlabel='X1',ylabel='X3') img3=summary_scatter2d(X2,X3,'X2X3',xlabel='X2',ylabel='X3') plt.close() return img1,img2,img3 def summary_scatter2d(x,y,title='2dscatterplot',xlabel=None,ylabel=None): fig=scatter2d(x,y,title,xlabel=xlabel,ylabel=ylabel) fig.canvas.draw() rgb=fig.canvas.tostring_rgb() buf=np.fromstring(rgb,dtype=np.uint8) w,h = fig.canvas.get_width_height() img=buf.reshape(1,h,w,3) #summary=tf.summary.image(title,img) plt.close(fig) #fig.clf() return img def scatter2d(x,y,title='2dscatterplot',xlabel=None,ylabel=None): fig=plt.figure() plt.scatter(x,y) plt.title(title) if xlabel: plt.xlabel(xlabel) if ylabel: plt.ylabel(ylabel) if not 0<=np.min(x)<=np.max(x)<=1: raise ValueError('summary_scatter2d title:',title,' input x exceeded [0,1] range.\ min:',np.min(x),' max:',np.max(x)) if not 0<=np.min(y)<=np.max(y)<=1: raise ValueError('summary_scatter2d title:',title,' input y exceeded [0,1] range.\ min:',np.min(y),' max:',np.max(y)) plt.xlim([0,1]) plt.ylim([0,1]) return fig def prepare_dirs_and_logger(config): formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") logger = logging.getLogger() for hdlr in logger.handlers: logger.removeHandler(hdlr) handler = logging.StreamHandler() handler.setFormatter(formatter) logger.addHandler(handler) if config.load_path: if config.load_path.startswith(config.log_dir): config.model_dir = config.load_path else: if config.load_path.startswith(config.dataset): config.model_name = config.load_path else: config.model_name = "{}_{}".format(config.dataset, config.load_path) else: config.model_name = "{}_{}".format(config.dataset, get_time()) if not hasattr(config, 'model_dir'): config.model_dir = os.path.join(config.log_dir, config.model_name) config.data_path = os.path.join(config.data_dir, config.dataset) if config.is_train: config.log_code_dir=os.path.join(config.model_dir,'code') for path in [config.log_dir, config.data_dir, config.model_dir, config.log_code_dir]: if not os.path.exists(path): os.makedirs(path) #Copy python code in directory into model_dir/code for future reference: code_dir=os.path.dirname(os.path.realpath(sys.argv[0])) model_files = [f for f in listdir(code_dir) if isfile(join(code_dir, f))] for f in model_files: if f.endswith('.py'): shutil.copy2(f,config.log_code_dir) def get_time(): return datetime.now().strftime("%m%d_%H%M%S") def save_config(config): param_path = os.path.join(config.model_dir, "params.json") print("[*] MODEL dir: %s" % config.model_dir) print("[*] PARAM path: %s" % param_path) with open(param_path, 'w') as fp: json.dump(config.__dict__, fp, indent=4, sort_keys=True) class Timer(object): def __init__(self): self.total_section_time=0. self.iter=0 def on(self): self.t0=time.time() def off(self): self.total_section_time+=time.time()-self.t0 self.iter+=1 def __str__(self): n_min=self.total_section_time/60. return '%.2fmin'%n_min