"""Train resnet50 on ILSVRC2017 Data using homemade scripts.""" import tensorflow as tf from multiprocessing import Process, Queue import os import sys FILE_DIR = os.path.dirname(__file__) sys.path.append(FILE_DIR + '/../') import config as cfg from img_dataset.ilsvrc_cls_multithread_scipy import ilsvrc_cls from yolo2_nets import inception_resnet_v2 from yolo2_nets.net_utils import restore_inception_resnet_variables_from_weight from utils.timer import Timer from utils.helpers import add_contrast_on_batch from scipy import ndimage slim = tf.contrib.slim TRAIN_BATCH_SIZE = 18 ############################################################## # Use Inception v3 to generate adversarial examples to train # ############################################################## from cleverhans.attacks import FastGradientMethod import numpy as np from tensorflow.contrib.slim.nets import inception checkpoint_path = "/home/wenxi/Projects/tensorflow_yolo2/weights/inception_v3.ckpt" max_epsilon = 16.0 eps = 2.0 * max_epsilon / 255.0 batch_shape = [TRAIN_BATCH_SIZE, 299, 299, 3] tensorflow_master = "" class InceptionModel(object): def __init__(self, num_classes): self.num_classes = num_classes self.built = False def __call__(self, x_input): """Constructs model and return probabilities for given input.""" reuse = True if self.built else None with slim.arg_scope(inception.inception_v3_arg_scope()): _, end_points = inception.inception_v3( x_input, num_classes=self.num_classes, is_training=False, reuse=reuse) self.built = True output = end_points['Predictions'] probs = output.op.inputs[0] return probs g_inception_v3 = tf.Graph() with g_inception_v3.as_default(): x_input = tf.placeholder(tf.float32, shape=batch_shape) model = InceptionModel(1001) fgsm = FastGradientMethod(model) x_adv = fgsm.generate(x_input, eps=eps, clip_min=-1., clip_max=1.) 
# The Saver and session must be created with the inception-v3 graph as the
# default graph so slim.get_model_variables() resolves against its variables.
with g_inception_v3.as_default():
    saver = tf.train.Saver(slim.get_model_variables())
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=tf.train.Scaffold(saver=saver),
        checkpoint_filename_with_path=checkpoint_path,
        master=tensorflow_master)
    sess_inception_v3 = tf.train.MonitoredSession(
        session_creator=session_creator)

##############################
# inception resnet predictor #
##############################
# Build the synset -> index mapping used to translate dataset labels into the
# 1001-way (background + 1000 synsets) indexing of the Inception networks.
inception_imagenet_labels = ['background']
with open(os.path.join(cfg.SRC_DIR, 'img_dataset',
                       'imagenet_lsvrc_2015_synsets.txt'), 'r') as f:
    for line in f:
        if line.strip():
            inception_imagenet_labels.append(line.strip())
assert len(inception_imagenet_labels) == 1001
synset_to_ind = {synset: ind
                 for ind, synset in enumerate(inception_imagenet_labels)}


def get_validation_process(imdb, queue_in, queue_out):
    """Get validation dataset. Run in a child process.

    Blocks on queue_in for a request token, then pushes one
    [images, labels] batch from imdb onto queue_out.
    """
    while True:
        queue_in.get()
        images, labels = imdb.get()
        queue_out.put([images, labels])


# NOTE: check fix the data imread and label
imdb = ilsvrc_cls('train', multithread=cfg.MULTITHREAD,
                  batch_size=TRAIN_BATCH_SIZE, image_size=299,
                  random_noise=True)
val_imdb = ilsvrc_cls('val', batch_size=18, image_size=299, random_noise=True)

# set up child process for getting validation data
queue_in = Queue()
queue_out = Queue()
val_data_process = Process(target=get_validation_process,
                           args=(val_imdb, queue_in, queue_out))
val_data_process.start()
queue_in.put(True)  # start getting the first batch


def _ensure_dir(path):
    """Create `path` (and any parents) if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)


CKPTS_DIR = cfg.get_ckpts_dir('inception_resnet', imdb.name)
TENSORBOARD_TRAIN_DIR, TENSORBOARD_VAL_DIR = cfg.get_output_tb_dir(
    'inception_resnet', imdb.name)
TENSORBOARD_TRAIN_ADV_DIR = os.path.abspath(os.path.join(
    cfg.ROOT_DIR, 'tensorboard', 'inception_resnet', imdb.name, 'train_adv'))
_ensure_dir(TENSORBOARD_TRAIN_ADV_DIR)
TENSORBOARD_VAL_ADV_DIR = os.path.abspath(os.path.join(
    cfg.ROOT_DIR, 'tensorboard', 'inception_resnet', imdb.name, 'val_adv'))
_ensure_dir(TENSORBOARD_VAL_ADV_DIR)

# Build the inception-resnet-v2 training graph in its own tf.Graph.
g_inception_resnet = tf.Graph()
with g_inception_resnet.as_default():
    # 15-channel input: add_contrast_on_batch appears to stack extra contrast
    # channels onto the RGB batch -- TODO confirm against utils.helpers.
    input_data = tf.placeholder(tf.float32, [None, 299, 299, 15])
    label_data = tf.placeholder(tf.int32, None)
    is_training = tf.placeholder(tf.bool)
    with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
        # NOTE: check fix the number of classes
        logits, end_points = inception_resnet_v2.inception_resnet_v2(
            input_data, num_classes=1001, is_training=is_training)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=label_data, logits=logits)
    loss = tf.reduce_mean(loss)

    # NOTE: check fix the variable to train
    # Only fine-tune the transfer conv layer and the first two stem convs;
    # the rest of the network keeps its restored weights.
    vars_to_train_trf = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionResnetV2/Conv2d_tr_3x3')
    vars_to_train_1 = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionResnetV2/Conv2d_1a_3x3')
    vars_to_train_2 = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionResnetV2/Conv2d_2a_3x3')

    # Run batch-norm moving-average updates together with each train step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Stem convs get a small fixed LR; the transfer conv uses Adam default.
        train_op_1 = tf.train.AdamOptimizer(0.00001).minimize(
            loss, var_list=(vars_to_train_1 + vars_to_train_2))
        train_op_2 = tf.train.AdamOptimizer().minimize(
            loss, var_list=vars_to_train_trf)
    train_op = tf.group(train_op_1, train_op_2)

    correct_pred = tf.equal(
        tf.cast(tf.argmax(logits, 1), tf.int32), label_data)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)

    ######################
    # Initialize Session #
    ######################
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    sess_inception_resent = tf.Session(config=tfconfig)

    # TODO: fix restore (previously restored from
    # ens_adv_inception_resnet_v2.ckpt via
    # restore_inception_resnet_variables_from_weight).
    # TODO: not sure why adam needs to be reinitialized
    adam_vars = [var for var in tf.global_variables()
                 if 'Adam' in var.name
                 or 'beta1_power' in var.name
                 or 'beta2_power' in var.name]
    # Freshly initialize optimizer slots plus the transfer conv layer; restore
    # everything else from the snapshot below.
    uninit_vars = adam_vars \
        + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                            scope='InceptionResnetV2/Conv2d_tr_3x3')
    init_op = tf.variables_initializer(uninit_vars)
    variables_to_restore = slim.get_variables_to_restore()
    for var in uninit_vars:
        if var in variables_to_restore:
            variables_to_restore.remove(var)

    ckpt_file = os.path.join(CKPTS_DIR, "train_iter_47500.ckpt")
    print('Restoring model snapshots from {:s}'.format(ckpt_file))
    saver = tf.train.Saver(variables_to_restore)
    sess_inception_resent.run(init_op)
    saver.restore(sess_inception_resent, ckpt_file)
    print('Restored.')
# merge_all() collects summary ops from the current default graph; the
# summaries were registered on g_inception_resnet, so re-enter that graph.
with g_inception_resnet.as_default():
    merged = tf.summary.merge_all()

train_writer = tf.summary.FileWriter(TENSORBOARD_TRAIN_DIR)
val_writer = tf.summary.FileWriter(TENSORBOARD_VAL_DIR)
train_adv_writer = tf.summary.FileWriter(TENSORBOARD_TRAIN_ADV_DIR)
val_adv_writer = tf.summary.FileWriter(TENSORBOARD_VAL_ADV_DIR)

# simple model saver
cur_saver = tf.train.Saver()

T = Timer()
# Resume training at iteration 47500 (matches the restored snapshot).
for i in range(47500, 200000):
    # --- train on clean (contrast-augmented) images ---
    T.tic()
    images, labels = imdb.get()
    contrast_images = add_contrast_on_batch(images)
    # Map dataset label indices to the 1001-way synset indexing of the net.
    labels = [synset_to_ind[imdb.classes[int(item)]] for item in labels]
    _, loss_value, acc_value, train_summary = sess_inception_resent.run(
        [train_op, loss, accuracy, merged],
        {input_data: contrast_images, label_data: labels, is_training: True})
    _time = T.toc(average=False)
    print('iter {:d}, training loss: {:.3}, training acc: {:.3}, take {:.2}s'
          .format(i + 1, loss_value, acc_value, _time))

    # --- train on FGSM adversarial examples generated by Inception v3 ---
    T.tic()
    nontargeted_images = sess_inception_v3.run(
        x_adv, feed_dict={x_input: images})
    adv_contrast_images = add_contrast_on_batch(nontargeted_images)
    _, loss_adv_value, acc_adv_value, train_adv_summary = sess_inception_resent.run(
        [train_op, loss, accuracy, merged],
        {input_data: adv_contrast_images, label_data: labels,
         is_training: True})
    _time = T.toc(average=False)
    print('iter {:d}, adv training loss: {:.3}, adv training acc: {:.3}, take {:.2}s'
          .format(i + 1, loss_adv_value, acc_adv_value, _time))

    if (i + 1) % 25 == 0:
        # --- periodic validation on clean and adversarial batches ---
        T.tic()
        val_images, val_labels = queue_out.get()
        val_contrast_images = add_contrast_on_batch(val_images)
        val_labels = [
            synset_to_ind[imdb.classes[int(item)]] for item in val_labels]
        val_loss_value, val_acc_value, val_summary = sess_inception_resent.run(
            [loss, accuracy, merged],
            {input_data: val_contrast_images, label_data: val_labels,
             is_training: False})
        _val_time = T.toc(average=False)
        print('###validation loss: {:.3}, validation acc: {:.3}, take {:.2}s'
              .format(val_loss_value, val_acc_value, _val_time))

        nontargeted_val_images = sess_inception_v3.run(
            x_adv, feed_dict={x_input: val_images})
        val_adv_contrast_images = add_contrast_on_batch(nontargeted_val_images)
        val_adv_loss_value, val_adv_acc_value, val_adv_summary = sess_inception_resent.run(
            [loss, accuracy, merged],
            {input_data: val_adv_contrast_images, label_data: val_labels,
             is_training: False})
        _val_time = T.toc(average=False)
        print('###adv validation loss: {:.3}, adv validation acc: {:.3}, take {:.2}s'
              .format(val_adv_loss_value, val_adv_acc_value, _val_time))

        # Ask the child process to prefetch the next validation batch.
        queue_in.put(True)

        # Summaries are only written inside this branch so that every
        # referenced val/adv summary is guaranteed to exist.
        global_step = i + 1
        train_writer.add_summary(train_summary, global_step)
        train_adv_writer.add_summary(train_adv_summary, global_step)
        val_writer.add_summary(val_summary, global_step)
        val_adv_writer.add_summary(val_adv_summary, global_step)

    if (i + 1) % 2500 == 0:
        save_path = cur_saver.save(sess_inception_resent, os.path.join(
            CKPTS_DIR,
            cfg.TRAIN_SNAPSHOT_PREFIX + '_iter_' + str(i + 1) + '.ckpt'))
        print("Model saved in file: %s" % save_path)

# terminate child processes
if cfg.MULTITHREAD:
    imdb.close_all_processes()
queue_in.cancel_join_thread()
queue_out.cancel_join_thread()
val_data_process.terminate()