#Modified from https://github.com/pfnet-research/chainer-gan-lib/blob/master/common/evaluation.py import os import sys import math import numpy as np import scipy.linalg import pickle import chainer import chainer.cuda from chainer import Variable from chainer import serializers import chainer.functions as F import tqdm from .utils import get_classifer def calc_FID(m0, c0, m1, c1): ret = 0 ret += np.sum((m0-m1)**2) ret += np.trace(c0 + c1 - 2.0 * scipy.linalg.sqrtm(np.dot(c0, c1))) return np.real(ret) def fid_extension(fidapi, generate_func, seed=None, report_key='FID', verbose=True): @chainer.training.make_extension() def calc(trainer): if verbose: print('Running FID...') fidapi.calc_fake(generate_func, seed) fid = fidapi.calc_FID() chainer.report({report_key: fid}) if verbose: print(report_key + ' Value: ', fid) return calc class API: def __init__(self, clsf_type, clsf_path, gpu, load_real_stat=None, n_batches=4000, batch_size=5): self.cnn, self.input_args = get_classifer(clsf_type, clsf_path) if gpu >= 0: self.cnn.to_gpu(gpu) print("Send to GPU ", gpu) self.n_batches = n_batches self.batch_size = batch_size self.features = {} if load_real_stat is not None: self.load_real_statistics(load_real_stat) def get_mean_cov(self, seed, get_image_func=None, n_batches=None): xp = self.cnn.xp batch_size = self.batch_size n_batches = self.n_batches if n_batches is None else n_batches if seed is not None: np.random.seed(seed) result = [] print("Calculating FID Features...") for i in tqdm.tqdm(range(n_batches)): # should return numpy array on CPU imgs = get_image_func(batch_size) imgs = xp.asarray(imgs) imgs = Variable(imgs) # Feed images to the inception module to get the features with chainer.using_config('train', False): with chainer.using_config('enable_backprop', False): y = self.cnn(imgs, **self.input_args) result.append(y.data.get()) if seed is not None: np.random.seed() result = np.asarray(result) result = result.reshape(batch_size*n_batches, result.shape[2]) mean = np.mean(result, axis=0) cov = np.cov(result.T) return mean, cov def init_real(self, generate_func, seed=None, n_batches=None): mean, cov = self.get_mean_cov(seed, generate_func, n_batches=n_batches) real_feature = { 'mean': mean, 'cov': cov } self.features['real'] = real_feature def calc_fake(self, generate_func, seed=None, n_batches=None): mean, cov = self.get_mean_cov(seed, generate_func, n_batches=n_batches) fake_feature = { 'mean': mean, 'cov': cov } self.features['fake'] = fake_feature def calc_FID(self): assert 'fake' in self.features and 'real' in self.features return calc_FID(self.features['fake']['mean'], self.features['fake']['cov'], self.features['real']['mean'], self.features['real']['cov']) def save_real_statistics(self, path): with open(path, 'wb') as f: pickle.dump(self.features['real'], f) def load_real_statistics(self, path): with open(path, 'rb') as f: self.features['real'] = pickle.load(f)