import tensorflow as tf
import numpy as np
import matplotlib
# OSX fix
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns

from layers import conv_layer
from config import *
from utils import init_weights, gen_data


class CryptoNet(object):
    """Adversarial neural cryptography setup with Alice, Bob and Eve networks.

    Alice maps (message, key) to an output that Bob receives together with
    the same key; Eve receives only Alice's output. Alice and Bob are trained
    jointly to let Bob recover the message while keeping Eve's error high, and
    Eve is trained separately to recover the message.
    """

    def __init__(self, sess, msg_len=MSG_LEN, batch_size=BATCH_SIZE,
                 epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE):
        """
        Args:
            sess: TensorFlow session used to initialize and run the graph.
            msg_len: The length of the input message to encrypt. The key
                length is always equal to msg_len: Alice's and Bob's first
                weight matrices have shape [2N, 2N] and are applied to the
                concatenation of a length-N input with the key.
            batch_size: Minibatch size for each adversarial training step.
            epochs: Number of epochs in the adversarial training.
            learning_rate: Learning rate for the Adam optimizers.
        """
        self.sess = sess
        self.msg_len = msg_len
        # The key length is tied to the message length (see the docstring).
        self.key_len = self.msg_len
        self.N = self.msg_len
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate

        self.build_model()

    def build_model(self):
        """Create the weights, input placeholders and the three networks."""
        # Weights for fully connected layers. Alice and Bob take a 2N-wide
        # input (data concatenated with the key); Eve's first layer takes
        # only Alice's N-wide output.
        self.w_alice = init_weights("alice_w", [2 * self.N, 2 * self.N])
        self.w_bob = init_weights("bob_w", [2 * self.N, 2 * self.N])
        self.w_eve1 = init_weights("eve_w1", [self.N, 2 * self.N])
        self.w_eve2 = init_weights("eve_w2", [2 * self.N, 2 * self.N])

        # Placeholder variables for Message and Key
        self.msg = tf.placeholder("float", [None, self.msg_len])
        self.key = tf.placeholder("float", [None, self.key_len])

        # NOTE(review): tf.squeeze without an axis below also drops the batch
        # dimension when a batch holds a single example; consider squeezing
        # only the channel axis once conv_layer's output shape is confirmed.

        # Alice's network
        # FC layer -> Conv Layer (4 1-D convolutions)
        self.alice_input = tf.concat([self.msg, self.key], 1)
        self.alice_hidden = tf.nn.sigmoid(tf.matmul(self.alice_input, self.w_alice))
        # Add a trailing channel axis so the 1-D convolutions can be applied.
        self.alice_hidden = tf.expand_dims(self.alice_hidden, 2)
        self.alice_output = tf.squeeze(conv_layer(self.alice_hidden, "alice"))

        # Bob's network
        # FC layer -> Conv Layer (4 1-D convolutions)
        self.bob_input = tf.concat([self.alice_output, self.key], 1)
        self.bob_hidden = tf.nn.sigmoid(tf.matmul(self.bob_input, self.w_bob))
        self.bob_hidden = tf.expand_dims(self.bob_hidden, 2)
        self.bob_output = tf.squeeze(conv_layer(self.bob_hidden, "bob"))

        # Eve's network
        # FC layer -> FC layer -> Conv Layer (4 1-D convolutions)
        # Eve sees only Alice's output, never the key.
        self.eve_input = self.alice_output
        self.eve_hidden1 = tf.nn.sigmoid(tf.matmul(self.eve_input, self.w_eve1))
        self.eve_hidden2 = tf.nn.sigmoid(tf.matmul(self.eve_hidden1, self.w_eve2))
        self.eve_hidden2 = tf.expand_dims(self.eve_hidden2, 2)
        self.eve_output = tf.squeeze(conv_layer(self.eve_hidden2, "eve"))

    def train(self, iterations=2000):
        """Build losses/optimizers and run the alternating adversarial training.

        Each epoch first trains Alice and Bob for `iterations` steps, then
        trains Eve for `iterations` steps, recording the lowest minibatch
        decryption error each side reached. The per-epoch errors are plotted
        when training finishes.

        Args:
            iterations: Number of optimizer steps per side per epoch
                (default 2000, the previously hard-coded value).
        """
        # Loss Functions: mean absolute difference to the original message.
        self.decrypt_err_eve = tf.reduce_mean(tf.abs(self.msg - self.eve_output))
        self.decrypt_err_bob = tf.reduce_mean(tf.abs(self.msg - self.bob_output))
        # Alice/Bob minimize Bob's error while pushing Eve's error toward 1.0.
        # NOTE(review): 1.0 is presumably Eve's error when guessing blindly,
        # assuming gen_data produces +/-1 bits -- confirm.
        self.loss_bob = self.decrypt_err_bob + (1. - self.decrypt_err_eve) ** 2.

        # Get training variables corresponding to each network. Matching is by
        # name substring, so every variable must carry the network's prefix
        # (true for the FC weights above; assumed for conv_layer's -- confirm).
        self.t_vars = tf.trainable_variables()
        self.alice_or_bob_vars = [var for var in self.t_vars
                                  if 'alice_' in var.name or 'bob_' in var.name]
        self.eve_vars = [var for var in self.t_vars if 'eve_' in var.name]

        # Build the optimizers
        self.bob_optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(
            self.loss_bob, var_list=self.alice_or_bob_vars)
        self.eve_optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(
            self.decrypt_err_eve, var_list=self.eve_vars)

        self.bob_errors, self.eve_errors = [], []

        # Begin Training. Initialize through the session we were given rather
        # than relying on it being registered as the default session.
        self.sess.run(tf.global_variables_initializer())
        for epoch in range(self.epochs):
            print('Training Alice and Bob, Epoch: %d' % (epoch + 1))
            bob_loss, _ = self._train('bob', iterations)
            self.bob_errors.append(bob_loss)

            print('Training Eve, Epoch: %d' % (epoch + 1))
            _, eve_loss = self._train('eve', iterations)
            self.eve_errors.append(eve_loss)

        self.plot_errors()

    def _train(self, network, iterations):
        """Run `iterations` optimizer steps for one side of the adversarial game.

        Args:
            network: 'bob' to train Alice and Bob jointly, 'eve' to train Eve.
            iterations: Number of minibatch steps to run.

        Returns:
            (bob_decrypt_error, eve_decrypt_error). The entry for the trained
            network is the lowest minibatch decryption error observed
            (float('inf') if iterations is 0); the other entry is the
            placeholder value 1.0.

        Raises:
            ValueError: if network is neither 'bob' nor 'eve'.
        """
        if network == 'bob':
            optimizer, decrypt_err_op = self.bob_optimizer, self.decrypt_err_bob
            bs = self.batch_size
        elif network == 'eve':
            optimizer, decrypt_err_op = self.eve_optimizer, self.decrypt_err_eve
            # Eve gets a minibatch twice as large as Alice/Bob's, to give it a
            # slight computational edge.
            bs = 2 * self.batch_size
        else:
            raise ValueError("network must be 'bob' or 'eve', got %r" % (network,))

        # Start at +inf rather than 1.0 so that a lowest error above 1.0 is
        # reported faithfully instead of being clamped to 1.0.
        best_error = float('inf')
        for _step in range(iterations):
            msg_in_val, key_val = gen_data(n=bs, msg_len=self.msg_len,
                                           key_len=self.key_len)
            _, decrypt_err = self.sess.run(
                [optimizer, decrypt_err_op],
                feed_dict={self.msg: msg_in_val, self.key: key_val})
            best_error = min(best_error, decrypt_err)

        if network == 'bob':
            return best_error, 1.
        return 1., best_error

    def plot_errors(self):
        """
        Plot Lowest Decryption Errors achieved by Bob and Eve per epoch
        """
        sns.set_style("darkgrid")
        plt.plot(self.bob_errors)
        plt.plot(self.eve_errors)
        plt.legend(['bob', 'eve'])
        plt.xlabel('Epoch')
        plt.ylabel('Lowest Decryption error achieved')
        plt.show()