import keras import keras.backend as K from keras.layers import Dense, Conv2D, Input, MaxPooling2D, Flatten from keras.models import Model from keras.datasets import fashion_mnist from keras.callbacks import ModelCheckpoint # setup parameters batch_sz = 128 nb_class = 10 nb_epochs = 10 img_h, img_w = 28, 28 print( K.image_data_format()) # input image dimensions img_rows, img_cols = 28, 28 def get_dataset(): """ Return processed and reshaped dataset for training In this cases Fashion-mnist dataset. """ # load mnist dataset (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() # test and train datasets print("Nb Train:", x_train.shape[0], "Nb test:",x_test.shape[0]) x_train = x_train.reshape(x_train.shape[0], img_h, img_w, 1) x_test = x_test.reshape(x_test.shape[0], img_h, img_w, 1) in_shape = (img_h, img_w, 1) # normalize inputs x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255.0 x_test /= 255.0 # convert to one hot vectors y_train = keras.utils.to_categorical(y_train, nb_class) y_test = keras.utils.to_categorical(y_test, nb_class) return x_train, x_test, y_train, y_test x_train, x_test, y_train, y_test = get_dataset() def create_model(img_h=28, img_w=28): inputs = Input(shape=(img_h, img_w, 1)) x = Conv2D(32, kernel_size=(3,3), activation='relu')(inputs) # 32C 3K 1S VP RELU x = Conv2D(32, kernel_size=(3,3), activation='relu')(x) # 64C 3K 1S VP RELU x = MaxPooling2D(pool_size=(2,2))(x) # pool2 x = Conv2D(64, kernel_size=(3,3), activation='relu')(x) # 32C 3K 1S VP RELU x = Conv2D(64, kernel_size=(3,3), activation='relu')(x) # 64C 3K 1S VP RELU x = MaxPooling2D(pool_size=(2,2))(x) # pool2 x = Flatten()(x) preds = Dense(nb_class, activation='softmax')(x) model = Model(inputs=inputs, outputs=preds) print(model.summary()) return model model = create_model() model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.SGD(lr=0.001), metrics=['accuracy']) callback = ModelCheckpoint() # start training model.fit(x_train, y_train, batch_size=batch_sz, epochs=nb_epochs, verbose=1, validation_data=(x_test, y_test), callbacks=[callback]) # Evaluate score = model.evaluate(x_test, y_test, verbose=0) print('Test loss:', score[0]) print('Test accuracy:', score[1])