# Noel C. F. Codella
# Example Triplet Loss Code for Keras / TensorFlow

# Implementing Improved Triplet Loss from:
# Zhang et al. "Tracking Persons-of-Interest via Adaptive Discriminative Features" ECCV 2016

# Got help from multiple web sources, including:
# 1) https://stackoverflow.com/questions/47727679/triplet-model-for-image-retrieval-from-the-keras-pretrained-network
# 2) https://ksaluja15.github.io/Learning-Rate-Multipliers-in-Keras/
# 3) https://keras.io/preprocessing/image/
# 4) https://github.com/keras-team/keras/issues/3386
# 5) https://github.com/keras-team/keras/issues/8130

# GLOBAL DEFINES
T_G_WIDTH = 224
T_G_HEIGHT = 224
T_G_NUMCHANNELS = 3
T_G_SEED = 1337

# Misc. Necessities
import sys
import ssl

# these two lines solved issues loading the pretrained model
ssl._create_default_https_context = ssl._create_unverified_context

import numpy as np
import cv2

np.random.seed(T_G_SEED)

# TensorFlow Includes
import tensorflow as tf
#from tensorflow.contrib.losses import metric_learning
tf.set_random_seed(T_G_SEED)

# Keras Imports & Defines
import keras
import keras.applications
from keras import backend as K
from keras.models import Model
import keras.layers as kl
from keras.preprocessing.image import ImageDataGenerator

# Generator object for data augmentation.
# Can change values here to affect augmentation style.
datagen = ImageDataGenerator(rotation_range=90,
                             width_shift_range=0.05,
                             height_shift_range=0.05,
                             zoom_range=0.1,
                             horizontal_flip=True,
                             vertical_flip=True)

# Local Imports
from LR_SGD import LR_SGD

# generator function for data augmentation
def createDataGen(X1, X2, X3, Y, b):

    # All three streams share the same seed and disable shuffling,
    # so anchor, positive, and negative batches stay aligned.
    local_seed = T_G_SEED
    genX1 = datagen.flow(X1, Y, batch_size=b, seed=local_seed, shuffle=False)
    genX2 = datagen.flow(X2, Y, batch_size=b, seed=local_seed, shuffle=False)
    genX3 = datagen.flow(X3, Y, batch_size=b, seed=local_seed, shuffle=False)

    while True:
        X1i = next(genX1)
        X2i = next(genX2)
        X3i = next(genX3)
        yield [X1i[0], X2i[0], X3i[0]], X1i[1]
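
# --- Illustrative sketch (not part of the original script) ---
# A minimal sanity check of createDataGen, assuming small random arrays in
# place of real images: one draw from the generator should yield three
# aligned image batches plus labels. The helper name and array sizes below
# are hypothetical.
def _demo_datagen_shapes():
    n, b = 8, 4  # hypothetical triplet count and batch size
    xa = np.random.rand(n, T_G_HEIGHT, T_G_WIDTH, T_G_NUMCHANNELS)
    xp = np.random.rand(n, T_G_HEIGHT, T_G_WIDTH, T_G_NUMCHANNELS)
    xn = np.random.rand(n, T_G_HEIGHT, T_G_WIDTH, T_G_NUMCHANNELS)
    y = np.zeros((n, 2, 1))  # dummy labels, same layout as in learn() below
    gen = createDataGen(xa, xp, xn, y, b)
    batch_x, batch_y = next(gen)
    # expect three tensors of shape (4, 224, 224, 3) and labels of shape (4, 2, 1)
    print(str([x.shape for x in batch_x]) + ' ' + str(batch_y.shape))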
def createModel(emb_size):

    # Initialize a ResNet50_ImageNet Model
    resnet_input = kl.Input(shape=(T_G_WIDTH, T_G_HEIGHT, T_G_NUMCHANNELS))
    resnet_model = keras.applications.resnet50.ResNet50(weights='imagenet', include_top=False, input_tensor=resnet_input)

    # New Layers over ResNet50
    net = resnet_model.output
    #net = kl.Flatten(name='flatten')(net)
    net = kl.GlobalAveragePooling2D(name='gap')(net)
    #net = kl.Dropout(0.5)(net)
    net = kl.Dense(emb_size, activation='relu', name='t_emb_1')(net)
    net = kl.Lambda(lambda x: K.l2_normalize(x, axis=1), name='t_emb_1_l2norm')(net)

    # model creation
    base_model = Model(resnet_model.input, net, name="base_model")

    # triplet framework, shared weights
    input_shape = (T_G_WIDTH, T_G_HEIGHT, T_G_NUMCHANNELS)
    input_anchor = kl.Input(shape=input_shape, name='input_anchor')
    input_positive = kl.Input(shape=input_shape, name='input_pos')
    input_negative = kl.Input(shape=input_shape, name='input_neg')

    net_anchor = base_model(input_anchor)
    net_positive = base_model(input_positive)
    net_negative = base_model(input_negative)

    # A Lambda layer produces output using the given function; here it is Euclidean distance.
    positive_dist = kl.Lambda(euclidean_distance, name='pos_dist')([net_anchor, net_positive])
    negative_dist = kl.Lambda(euclidean_distance, name='neg_dist')([net_anchor, net_negative])
    tertiary_dist = kl.Lambda(euclidean_distance, name='ter_dist')([net_positive, net_negative])

    # This Lambda layer simply stacks outputs so all three distances are available to the objective
    stacked_dists = kl.Lambda(lambda vects: K.stack(vects, axis=1), name='stacked_dists')([positive_dist, negative_dist, tertiary_dist])

    model = Model([input_anchor, input_positive, input_negative], stacked_dists, name='triple_siamese')

    # Setting up an optimizer designed for variable learning rates

    # Variable Learning Rate per Layer
    lr_mult_dict = {}
    last_layer = ''
    for layer in resnet_model.layers:
        # comment this out to refine earlier layers
        # layer.trainable = False
        # print(layer.name)
        lr_mult_dict[layer.name] = 1
        # last_layer = layer.name
    lr_mult_dict['t_emb_1'] = 100

    base_lr = 0.0001
    momentum = 0.9
    v_optimizer = LR_SGD(lr=base_lr, momentum=momentum, decay=0.0, nesterov=False, multipliers=lr_mult_dict)

    model.compile(optimizer=v_optimizer, loss=triplet_loss, metrics=[accuracy])

    return model


# Improved triplet loss (Zhang et al., ECCV 2016): the mean over the batch of
# max(0, d(a,p)^2 - 0.5*(d(a,n)^2 + d(p,n)^2) + margin), where the three
# distances arrive stacked along axis 1 of y_pred. y_true is ignored.
def triplet_loss(y_true, y_pred):
    margin = K.constant(1)
    return K.mean(K.maximum(K.constant(0), K.square(y_pred[:, 0, 0]) - 0.5 * (K.square(y_pred[:, 1, 0]) + K.square(y_pred[:, 2, 0])) + margin))


# fraction of triplets where the anchor-positive distance is smaller than the anchor-negative distance
def accuracy(y_true, y_pred):
    return K.mean(y_pred[:, 0, 0] < y_pred[:, 1, 0])


def l2Norm(x):
    return K.l2_normalize(x, axis=-1)


def euclidean_distance(vects):
    x, y = vects
    return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))
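
# --- Illustrative sketch (not part of the original script) ---
# A quick NumPy check of the loss semantics, assuming a hand-built distance
# stack shaped (batch, 3, 1) as produced by the 'stacked_dists' layer:
# index 0 holds d(a,p), index 1 holds d(a,n), index 2 holds d(p,n).
# With d(a,p)=0 and d(a,n)=d(p,n)=2, the hinge term is
# 0 - 0.5*(4 + 4) + 1 = -3, so an easy triplet contributes zero loss.
# The helper name is hypothetical.
def _demo_triplet_loss_value():
    dists = np.array([[[0.0], [2.0], [2.0]]])  # shape (1, 3, 1)
    y_pred = K.constant(dists)
    loss = K.eval(triplet_loss(K.zeros_like(y_pred), y_pred))
    print('loss = ' + str(loss))  # expect 0.0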
# loads an image and preprocesses it
def t_read_image(loc):
    t_image = cv2.imread(loc)
    # cv2.resize expects the target size as (width, height)
    t_image = cv2.resize(t_image, (T_G_WIDTH, T_G_HEIGHT))
    t_image = t_image.astype("float32")
    t_image = keras.applications.resnet50.preprocess_input(t_image, data_format='channels_last')
    return t_image


# loads a set of images from a text index file
def t_read_image_list(flist, start, length):

    with open(flist) as f:
        content = f.readlines()
    content = [x.strip().split()[0] for x in content]

    datalen = length
    if (datalen < 0):
        datalen = len(content)
    if (start + datalen > len(content)):
        datalen = len(content) - start

    imgset = np.zeros((datalen, T_G_HEIGHT, T_G_WIDTH, T_G_NUMCHANNELS))

    for i in range(start, start + datalen):
        if ((i - start) < len(content)):
            imgset[i - start] = t_read_image(content[i])

    return imgset


def file_numlines(fn):
    with open(fn) as f:
        return sum(1 for _ in f)


def main(argv):
    if len(argv) < 2:
        print('Usage: \n\t -learn <Train Anchors (TXT)> <Train Positives (TXT)> <Train Negatives (TXT)> <Val Anchors (TXT)> <Val Positives (TXT)> <Val Negatives (TXT)> <embedding size> <batch size> <num epochs> <output model prefix> \n\t -extract <Model Prefix> <Input Image List (TXT)> <Output File (TXT)> \n\t\tBuilds and scores a triplet-loss model ')
        return

    if 'learn' in argv[0]:
        learn(argv[1:])
    elif 'extract' in argv[0]:
        extract(argv[1:])

    return


def extract(argv):

    if len(argv) < 3:
        print('Usage: \n\t <Model Prefix> <Input Image List (TXT)> <Output File (TXT)> \n\t\tExtracts embeddings with a trained triplet-loss model')
        return

    modelpref = argv[0]
    imglist = argv[1]
    outfile = argv[2]

    with open(modelpref + '.json', "r") as json_file:
        model_json = json_file.read()

    loaded_model = keras.models.model_from_json(model_json)
    loaded_model.load_weights(modelpref + '.h5')

    base_model = loaded_model.get_layer('base_model')

    # create a new single input
    input_shape = (T_G_WIDTH, T_G_HEIGHT, T_G_NUMCHANNELS)
    input_single = kl.Input(shape=input_shape, name='input_single')

    # create a new model without the triplet loss
    net_single = base_model(input_single)
    model = Model(input_single, net_single, name='embedding_net')

    chunksize = 1000
    total_img = file_numlines(imglist)
    total_img_ch = int(np.ceil(total_img / float(chunksize)))

    with open(outfile, 'w') as f_handle:
        for i in range(0, total_img_ch):
            imgs = t_read_image_list(imglist, i * chunksize, chunksize)
            vals = model.predict(imgs)
            np.savetxt(f_handle, vals)

    return
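
# --- Illustrative sketch (not part of the original script) ---
# One possible way to consume the embeddings written by extract(): a
# brute-force nearest-neighbor query in NumPy. The file name and helper name
# are hypothetical. np.savetxt above writes one L2-normalized vector per
# row, so Euclidean and cosine rankings agree.
def _demo_nearest_neighbors(emb_file='embeddings.txt', query_idx=0, k=5):
    emb = np.loadtxt(emb_file)
    dists = np.linalg.norm(emb - emb[query_idx], axis=1)
    order = np.argsort(dists)
    # the first hit is the query itself (distance 0); return the next k
    return order[1:k + 1]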
def learn(argv):

    if len(argv) < 10:
        print('Usage: \n\t <Train Anchors (TXT)> <Train Positives (TXT)> <Train Negatives (TXT)> <Val Anchors (TXT)> <Val Positives (TXT)> <Val Negatives (TXT)> <embedding size> <batch size> <num epochs> <output model prefix> \n\t\tLearns a triplet-loss model')
        return

    in_t_a = argv[0]
    in_t_b = argv[1]
    in_t_c = argv[2]

    in_v_a = argv[3]
    in_v_b = argv[4]
    in_v_c = argv[5]

    emb_size = int(argv[6])
    batch = int(argv[7])
    numepochs = int(argv[8])
    outpath = argv[9]

    # chunksize is the number of images we load from disk at a time
    chunksize = batch * 100
    total_t = file_numlines(in_t_a)
    total_v = file_numlines(in_v_a)
    total_t_ch = int(np.ceil(total_t / float(chunksize)))
    total_v_ch = int(np.ceil(total_v / float(chunksize)))

    print('Dataset has ' + str(total_t) + ' training triplets, and ' + str(total_v) + ' validation triplets.')

    print('Creating a model ...')
    model = createModel(emb_size)

    print('Training loop ...')

    # manual loop over epochs to support very large sets of triplets
    for e in range(0, numepochs):

        for t in range(0, total_t_ch):

            print('Epoch ' + str(e) + ': train chunk ' + str(t + 1) + '/' + str(total_t_ch) + ' ...')

            print('Reading image lists ...')
            anchors_t = t_read_image_list(in_t_a, t * chunksize, chunksize)
            positives_t = t_read_image_list(in_t_b, t * chunksize, chunksize)
            negatives_t = t_read_image_list(in_t_c, t * chunksize, chunksize)

            # dummy labels: the triplet loss ignores y_true
            Y_train = np.random.randint(2, size=(1, 2, anchors_t.shape[0])).T

            print('Starting to fit ...')

            # This method does NOT use data augmentation
            # model.fit([anchors_t, positives_t, negatives_t], Y_train, epochs=numepochs, batch_size=batch)

            # This method uses data augmentation
            model.fit_generator(generator=createDataGen(anchors_t, positives_t, negatives_t, Y_train, batch),
                                steps_per_epoch=len(Y_train) // batch, epochs=1, shuffle=False, use_multiprocessing=True)

        # In case the validation images don't fit in memory, we load chunks from disk again.
        val_res = [0.0, 0.0]
        total_w = 0.0
        for v in range(0, total_v_ch):

            print('Loading validation image lists ...')
            print('Epoch ' + str(e) + ': val chunk ' + str(v + 1) + '/' + str(total_v_ch) + ' ...')
            anchors_v = t_read_image_list(in_v_a, v * chunksize, chunksize)
            positives_v = t_read_image_list(in_v_b, v * chunksize, chunksize)
            negatives_v = t_read_image_list(in_v_c, v * chunksize, chunksize)

            Y_val = np.random.randint(2, size=(1, 2, anchors_v.shape[0])).T

            # Weight of the current validation measurement:
            # 1.0 if a full chunk was loaded, otherwise between 0.0 and 1.0.
            w = float(anchors_v.shape[0]) / float(chunksize)
            total_w = total_w + w

            curval = model.evaluate([anchors_v, positives_v, negatives_v], Y_val, batch_size=batch)
            val_res[0] = val_res[0] + w * curval[0]
            val_res[1] = val_res[1] + w * curval[1]

        val_res = [x / total_w for x in val_res]

        print('Validation Results: ' + str(val_res))

    print('Saving model ...')

    # Save the model and weights
    model.save(outpath + '.h5')

    # Due to some remaining Keras bugs around loading custom optimizers
    # and objectives, we save the model architecture as well
    model_json = model.to_json()
    with open(outpath + '.json', "w") as json_file:
        json_file.write(model_json)

    return


# Main Driver
if __name__ == "__main__":
    main(sys.argv[1:])
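
# --- Illustrative usage (script and file names below are hypothetical) ---
#
# Train a model with 128-D embeddings, batches of 10, for 5 epochs:
#   python triplet.py -learn train_a.txt train_p.txt train_n.txt \
#       val_a.txt val_p.txt val_n.txt 128 10 5 my_model
#
# Extract embeddings for a list of images with the trained model:
#   python triplet.py -extract my_model images.txt embeddings.txt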