# Arda Mavi
import os
import numpy
from get_dataset import get_dataset
from get_model import get_model, save_model
from keras.callbacks import ModelCheckpoint, TensorBoard

# Training hyper-parameters used by train_model().
epochs = 15
batch_size = 6


def _ensure_dir(path):
    """Create directory *path* (including parents) unless it already exists.

    Raises:
        OSError: if the directory cannot be created and does not already
            exist as a directory.
    """
    try:
        os.makedirs(path)
    except OSError:
        # Creating first and inspecting afterwards avoids the window between
        # an existence check and the creation in which another process could
        # create the directory. Only a genuine failure is re-raised.
        if not os.path.isdir(path):
            raise


def train_model(model, X, X_test, Y, Y_test):
    """Fit *model* on (X, Y) while validating on (X_test, Y_test).

    Training runs for the module-level ``epochs`` with the module-level
    ``batch_size``. Two callbacks are attached:

    * ``ModelCheckpoint`` writes weights only to
      'Data/Checkpoints/best_weights.h5', keeping the best result as
      measured by ``val_loss``.
    * ``TensorBoard`` writes its logs to 'Data/Checkpoints/./logs'.

    Args:
        model: object exposing a Keras-style ``fit`` method.
        X: training inputs.
        X_test: validation inputs.
        Y: training targets.
        Y_test: validation targets.

    Returns:
        The same *model* object that was passed in, after ``fit`` has run.
    """
    _ensure_dir('Data/Checkpoints/')

    checkpoints = [
        ModelCheckpoint(
            'Data/Checkpoints/best_weights.h5',
            monitor='val_loss',
            verbose=0,
            save_best_only=True,
            save_weights_only=True,
            mode='auto',
            period=1,
        ),
        TensorBoard(
            log_dir='Data/Checkpoints/./logs',
            histogram_freq=1,
            write_graph=True,
            write_images=True,
            embeddings_freq=0,
            embeddings_layer_names=None,
            embeddings_metadata=None,
        ),
    ]

    # Optional alternative, disabled: train on live augmented data.
    # For better yield. The duration of the training is extended.
    #
    #   from keras.preprocessing.image import ImageDataGenerator
    #   generated_data = ImageDataGenerator(rotation_range=30,
    #                                       width_shift_range=0.1,
    #                                       height_shift_range=0.1,
    #                                       zoom_range=0.2)
    #   # For include left hand data add: 'horizontal_flip = True'
    #   generated_data.fit(X)
    #   model.fit_generator(generated_data.flow(X, Y, batch_size=batch_size),
    #                       steps_per_epoch=X.shape[0]/batch_size,
    #                       epochs=epochs,
    #                       validation_data=(X_test, Y_test),
    #                       callbacks=checkpoints)

    model.fit(
        X,
        Y,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(X_test, Y_test),
        shuffle=True,
        callbacks=checkpoints,
    )
    return model


def main():
    """Load the dataset, build and train the model, save it, and return it.

    The dataset comes from ``get_dataset()``, the untrained model from
    ``get_model()``, and the trained model is passed to ``save_model()``.
    """
    X, X_test, Y, Y_test = get_dataset()
    model = get_model()
    model = train_model(model, X, X_test, Y, Y_test)
    save_model(model)
    return model


if __name__ == '__main__':
    main()