from __future__ import print_function import tensorflow as tf import numpy as np import logging class VGG(object): def __init__(self, name, include_top=False, weights='imagenet'): with tf.variable_scope(name, reuse=tf.AUTO_REUSE) as scope: if name.upper() == 'VGG19': self.vgg = tf.keras.applications.VGG19(include_top=include_top, weights=weights) elif name.upper() == 'VGG16': self.vgg = tf.keras.applications.VGG16(include_top=include_top, weights=weights) else: raise TypeError('Not supported model: VGG{}'.format(name)) self.model = tf.keras.Model(inputs=self.vgg.input, outputs = self.vgg.get_layer('block3_conv3').output) self.model.trainable=False print(" [*] ", name, " model was created") def get_pair_feature(self, gen_img, real_img): assert gen_img.shape.as_list() == real_img.shape.as_list() batch_num = gen_img.shape.as_list()[0] pair = tf.concat([gen_img, real_img], axis=0) output = self.model(pair) gen_feat, real_feat = output[:batch_num,:,:,:], output[batch_num:,:,:,:] return gen_feat, real_feat if __name__=='__main__': model = VGG('vgg19') vars = tf.trainable_variables() for i, var in enumerate(vars): print(i,"-th variable: ", var) print(model.get_feature(np.ones([1,256,256,3])))