import tensorflow as tf

class CNN:
    """A small convolutional neural network for MNIST digit classification.

    The intended call sequence is ``build_and_compile_model()``, then
    ``train_and_evaluate_model()``, then ``save_model()``. Each later step
    checks the progress flags set by the earlier ones and raises
    ``RuntimeError`` if called out of order.
    """

    def __init__(self) -> None:
        """Create an empty Keras ``Sequential`` model (no layers added yet)."""
        self.model = tf.keras.models.Sequential()
        # Progress flags guarding the build -> train -> save sequence.
        self.modeltrained = False
        self.modelbuilt = False

    def build_and_compile_model(self) -> None:
        """Add the CNN layers to the model and compile it.

        Calling this again after a successful build is a no-op, so layers are
        never appended to the model twice.
        """
        if self.modelbuilt:
            return
        layers = tf.keras.layers
        # Convolutional feature extractor: 32 filters of size 3x3 over a
        # 28x28 single-channel image, followed by max pooling.
        self.model.add(
            layers.Conv2D(32, (3, 3), input_shape=(28, 28, 1), activation='relu')
        )
        self.model.add(layers.MaxPool2D())
        # Dense classifier head on the flattened feature maps.
        self.model.add(layers.Flatten())
        self.model.add(layers.Dense(512, activation='relu'))
        self.model.add(layers.Dropout(0.2))
        # 10 output units with softmax -> one probability per class.
        self.model.add(layers.Dense(10, activation='softmax'))
        # The sparse loss expects integer class labels rather than one-hot vectors.
        self.model.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        self.modelbuilt = True

    def train_and_evaluate_model(self, epochs: int = 5) -> None:
        """Load MNIST, train the model on the training split, then evaluate it.

        Prints the accuracy reached on the test split. Calling this again after
        a successful run is a no-op.

        Args:
            epochs: Number of passes over the training data (default 5).

        Raises:
            RuntimeError: If ``build_and_compile_model()`` has not been called.
        """
        if not self.modelbuilt:
            raise RuntimeError("Build and compile the model first!")
        if self.modeltrained:
            return
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        # Add a trailing channel axis so each sample has shape (28, 28, 1),
        # matching the Conv2D input_shape.
        x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
        x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
        # Scale pixel values from [0, 255] to [0, 1]. Casting to float32 first
        # avoids materialising float64 copies of the whole dataset; float32 is
        # Keras' default float type.
        x_train = x_train.astype("float32") / 255.0
        x_test = x_test.astype("float32") / 255.0
        self.model.fit(x=x_train, y=y_train, epochs=epochs)
        # evaluate() returns [loss, accuracy] given the metrics set in compile().
        _, test_acc = self.model.evaluate(x=x_test, y=y_test)
        print('\nTest accuracy:', test_acc)
        self.modeltrained = True

    def save_model(self, filepath: str = "cnn.hdf5") -> None:
        """Save the trained model to *filepath*, overwriting any existing file.

        Args:
            filepath: Destination path (default ``"cnn.hdf5"``); Keras picks
                the on-disk format from the file extension.

        Raises:
            RuntimeError: If the model has not been built or not yet trained.
        """
        if not self.modelbuilt:
            raise RuntimeError("Build and compile the model first!")
        if not self.modeltrained:
            raise RuntimeError("Train and evaluate the model first!")
        self.model.save(filepath, overwrite=True)