#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Transfer-learn an InceptionV3 image classifier with Keras.

Created on Sat Aug 26 20:40:38 2017

@author: dhaval
"""

import os
import sys
import glob
import argparse

import matplotlib
matplotlib.use('agg')  # the backend must be selected before pyplot is imported
import matplotlib.pyplot as plt

from keras import backend as K
from keras import __version__
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.models import Model
from keras.layers import Dense, AveragePooling2D, GlobalAveragePooling2D, Input, Flatten, Dropout
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD


IM_WIDTH, IM_HEIGHT = 299, 299  # fixed size for InceptionV3
NB_EPOCHS = 3
BAT_SIZE = 32
FC_SIZE = 1024
#NB_IV3_LAYERS_TO_FREEZE = 172

# Fixed per-epoch sample budgets handed to fit_generator.
SAMPLES_PER_EPOCH = 320
NB_VAL_SAMPLES = 64


def get_nb_files(directory):
    """Count the entries found inside every sub-directory of *directory*.

    The tree is walked recursively and, for each sub-directory encountered,
    every entry matched by ``<subdir>/*`` is counted. Entries that sit
    directly in *directory* itself are not counted.

    Args:
      directory: root folder to scan
    Returns:
      the total count, or 0 when *directory* does not exist
    """
    if not os.path.exists(directory):
        return 0
    return sum(
        len(glob.glob(os.path.join(root, sub + "/*")))
        for root, subdirs, _ in os.walk(directory)
        for sub in subdirs
    )


def _count_class_dirs(directory):
    """Return how many immediate sub-directories *directory* contains.

    Only directories are counted because ``flow_from_directory`` derives one
    class per sub-directory; stray files next to the class folders must not
    widen the softmax layer. Returns 0 when *directory* is missing.
    """
    if not os.path.isdir(directory):
        return 0
    return sum(
        1 for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )


def _history_series(history, *keys):
    """Return the first series stored in ``history.history`` under any of *keys*.

    Raises:
      KeyError: when none of *keys* is present
    """
    for key in keys:
        if key in history.history:
            return history.history[key]
    raise KeyError(keys[0])


def setup_to_transfer_learn(model, base_model):
    """Freeze every layer of the base network and compile the full model.

    Args:
      model: keras model to compile (base network plus the new head)
      base_model: keras model whose layers are marked non-trainable
    """
    for layer in base_model.layers:
        layer.trainable = False
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])


def add_new_last_layer(base_model, nb_classes):
    """Add last layer to the convnet

    Args:
      base_model: keras model excluding top
      nb_classes: # of classes
    Returns:
      new keras model with last layer
    """
    x = base_model.output
    x = AveragePooling2D((8, 8), border_mode='valid', name='avg_pool')(x)
    x = Dropout(0.4)(x)
    x = Flatten()(x)
    # The width of the softmax layer follows nb_classes instead of a fixed 2.
    predictions = Dense(nb_classes, activation='softmax')(x)
    model = Model(input=base_model.input, output=predictions)
    return model


# Disabled fine-tuning step, kept for reference as an inert string.
"""
def setup_to_finetune(model):
  Freeze the bottom NB_IV3_LAYERS and retrain the remaining top layers.

  note: NB_IV3_LAYERS corresponds to the top 2 inception blocks in the inceptionv3 arch

  Args:
    model: keras model

  for layer in model.layers[:NB_IV3_LAYERS_TO_FREEZE]:
     layer.trainable = False
  for layer in model.layers[NB_IV3_LAYERS_TO_FREEZE:]:
     layer.trainable = True
  model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
"""


def train(args):
    """Train a new classification head on top of a frozen InceptionV3.

    Images are read from ``training_set/`` and ``test_set/`` relative to the
    working directory, one sub-directory per class.

    Args:
      args: namespace providing ``nb_epoch``, ``output_model_file`` and
        ``plot``; ``batch_size`` is optional and falls back to BAT_SIZE
    """
    train_img = 'training_set/'
    validation_img = 'test_set/'
    nb_epoch = int(args.nb_epoch)
    # getattr keeps callers that build a namespace without batch_size working.
    batch_size = int(getattr(args, 'batch_size', BAT_SIZE))
    nb_classes = _count_class_dirs(train_img)

    # data prep
    # NOTE(review): inputs are scaled with rescale=1./255 while preprocess_input
    # is imported but unused -- confirm which scaling the ImageNet weights expect.
    # NOTE(review): the validation generator applies the same random
    # augmentation as the training generator -- confirm this is intended.
    augmentation = dict(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    train_datagen = ImageDataGenerator(**augmentation)
    validation_datagen = ImageDataGenerator(**augmentation)

    train_generator = train_datagen.flow_from_directory(
        train_img,
        target_size=(IM_HEIGHT, IM_WIDTH),
        batch_size=batch_size,
        class_mode='categorical'
    )

    validation_generator = validation_datagen.flow_from_directory(
        validation_img,
        target_size=(IM_HEIGHT, IM_WIDTH),
        batch_size=batch_size,
        class_mode='categorical'
    )

    # Channel axis position depends on the backend's image dimension ordering.
    if K.image_dim_ordering() == 'th':
        input_tensor = Input(shape=(3, IM_HEIGHT, IM_WIDTH))
    else:
        input_tensor = Input(shape=(IM_HEIGHT, IM_WIDTH, 3))

    # setup model
    # include_top=False excludes final FC layer
    base_model = InceptionV3(input_tensor=input_tensor, weights='imagenet',
                             include_top=False)
    model = add_new_last_layer(base_model, nb_classes)

    # transfer learning
    setup_to_transfer_learn(model, base_model)

    history_tl = model.fit_generator(train_generator,
                                     samples_per_epoch=SAMPLES_PER_EPOCH,
                                     nb_epoch=nb_epoch,
                                     validation_data=validation_generator,
                                     nb_val_samples=NB_VAL_SAMPLES)

    model.save(args.output_model_file)

    if args.plot:
        plot_training(history_tl)


def plot_training(history):
    """Save accuracy and loss curves to ``accuracy.png`` and ``loss.png``.

    Args:
      history: object returned by ``fit_generator``; its ``history`` dict must
        hold the loss and accuracy series for training and validation
    Raises:
      KeyError: when an expected series is missing from ``history.history``
    """
    # The accuracy series is named 'acc' or 'accuracy' depending on the
    # Keras version, so both spellings are accepted.
    acc = _history_series(history, 'acc', 'accuracy')
    val_acc = _history_series(history, 'val_acc', 'val_accuracy')
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))

    plt.plot(epochs, acc, 'r.')
    plt.plot(epochs, val_acc, 'r')
    plt.title('Training and validation accuracy')
    plt.savefig('accuracy.png')

    # Start a new figure so the loss curves are not drawn over the accuracy plot.
    plt.figure()
    plt.plot(epochs, loss, 'r.')
    plt.plot(epochs, val_loss, 'r-')
    plt.title('Training and validation loss')
    plt.savefig('loss.png')


if __name__ == "__main__":
    a = argparse.ArgumentParser()
    a.add_argument("--nb_epoch", type=int, default=NB_EPOCHS)
    a.add_argument("--batch_size", type=int, default=BAT_SIZE)
    a.add_argument("--plot", action="store_true")
    a.add_argument("--output_model_file", default="inceptionv3-ft.model")

    args = a.parse_args()

    train(args)