import tensorflow as tf import os from tensorflow.python.tools import freeze_graph from tensorflow.python.framework import graph_util from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import utils as saved_model_utils class CNN(object): def __init__(self, input_size, num_classes, optimizer): self.num_classes = num_classes self.input_size = input_size self.optimizer = optimizer self.learning_rate = tf.placeholder(tf.float32, shape=[], name='learning_rate') self.dropout_rate = tf.placeholder(tf.float32, shape=[], name='dropout_rate') self.input = tf.placeholder(tf.float32, [None] + self.input_size, name='input') self.label = tf.placeholder(tf.float32, [None, self.num_classes], name='label') self.output = self.network_initializer() self.loss = self.loss_initializer() self.optimization = self.optimizer_initializer() self.saver = tf.train.Saver() self.sess = tf.Session() self.sess.run(tf.global_variables_initializer()) def network(self, input, dropout_rate): conv1 = tf.layers.conv2d(inputs=input, filters=64, kernel_size=[3, 3], padding='same', activation=tf.nn.relu, name='conv1') conv2 = tf.layers.conv2d(inputs=conv1, filters=64, kernel_size=[3, 3], padding='same', activation=tf.nn.relu, name='conv2') pool1 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=[2, 2], name='pool1') pool1_dropout = tf.layers.dropout(inputs=pool1, rate=dropout_rate, training=True, name='pool1_dropout') conv3 = tf.layers.conv2d(inputs=pool1_dropout, filters=128, kernel_size=[3, 3], padding='same', activation=tf.nn.relu, name='conv3') conv4 = tf.layers.conv2d(inputs=conv3, filters=128, kernel_size=[3, 3], padding='same', activation=tf.nn.relu, name='conv4') pool2 = tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=[2, 2], name='pool2') pool2_dropout = tf.layers.dropout(inputs=pool2, rate=dropout_rate, training=True, name='pool2_dropout') conv5 = tf.layers.conv2d(inputs=pool2_dropout, filters=256, kernel_size=[3, 3], padding='same', activation=tf.nn.relu, name='conv5') pool3 = tf.layers.max_pooling2d(inputs=conv5, pool_size=[2, 2], strides=[2, 2], name='pool3') pool3_dropout = tf.layers.dropout(inputs=pool3, rate=dropout_rate, training=True, name='pool3_dropout') flat = tf.layers.flatten(inputs=pool3_dropout, name='flat') fc1 = tf.layers.dense(inputs=flat, units=256, activation=tf.nn.relu, name='fc1') fc1_dropout = tf.layers.dropout(inputs=fc1, rate=dropout_rate, training=True, name='fc1_dropout') fc2 = tf.layers.dense(inputs=fc1_dropout, units=self.num_classes, activation=None, name='fc2') # Give output node a output = tf.identity(fc2, name='output') return output def network_initializer(self): with tf.variable_scope('cnn') as scope: ouput = self.network(input=self.input, dropout_rate=self.dropout_rate) return ouput def loss_initializer(self): with tf.variable_scope('loss') as scope: cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2( labels=self.label, logits=self.output, name='cross_entropy') cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy_mean') return cross_entropy_mean def optimizer_initializer(self): if self.optimizer == 'Adam': optimizer = tf.train.AdamOptimizer( learning_rate=self.learning_rate).minimize(self.loss) else: optimizer = tf.train.GradientDescentOptimizer( learning_rate=self.learning_rate).minimize(self.loss) return optimizer def train(self, data, label, learning_rate, dropout_rate): _, train_loss = self.sess.run( [self.optimization, self.loss], feed_dict={ self.input: data, self.label: label, self.learning_rate: learning_rate, self.dropout_rate: dropout_rate }) return train_loss def validate(self, data, label): output, validate_loss = self.sess.run([self.output, self.loss], feed_dict={ self.input: data, self.label: label, self.dropout_rate: 0.0 }) return output, validate_loss def test(self, data): output = self.sess.run(self.output, feed_dict={ self.input: data, self.dropout_rate: 0.0 }) return output def save(self, directory, filename): if not os.path.exists(directory): os.makedirs(directory) filepath = os.path.join(directory, filename + '.ckpt') self.saver.save(self.sess, filepath) return filepath def save_signature(self, directory): signature = signature_def_utils.build_signature_def( inputs={ 'input': saved_model_utils.build_tensor_info(self.input), 'dropout_rate': saved_model_utils.build_tensor_info(self.dropout_rate) }, outputs={ 'output': saved_model_utils.build_tensor_info(self.output) }, method_name=signature_constants.PREDICT_METHOD_NAME) signature_map = { signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature } model_builder = saved_model_builder.SavedModelBuilder(directory) model_builder.add_meta_graph_and_variables( self.sess, tags=[tag_constants.SERVING], signature_def_map=signature_map, clear_devices=True) model_builder.save(as_text=False) def save_as_pb(self, directory, filename): if not os.path.exists(directory): os.makedirs(directory) # Save check point for graph frozen later ckpt_filepath = self.save(directory=directory, filename=filename) pbtxt_filename = filename + '.pbtxt' pbtxt_filepath = os.path.join(directory, pbtxt_filename) pb_filepath = os.path.join(directory, filename + '.pb') # This will only save the graph but the variables will not be saved. # You have to freeze your model first. tf.train.write_graph(graph_or_graph_def=self.sess.graph_def, logdir=directory, name=pbtxt_filename, as_text=True) # Freeze graph # Method 1 freeze_graph.freeze_graph(input_graph=pbtxt_filepath, input_saver='', input_binary=False, input_checkpoint=ckpt_filepath, output_node_names='cnn/output', restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=pb_filepath, clear_devices=True, initializer_nodes='') # Method 2 ''' graph = tf.get_default_graph() input_graph_def = graph.as_graph_def() output_node_names = ['cnn/output'] output_graph_def = graph_util.convert_variables_to_constants(self.sess, input_graph_def, output_node_names) with tf.gfile.GFile(pb_filepath, 'wb') as f: f.write(output_graph_def.SerializeToString()) ''' return pb_filepath def load(self, filepath): if os.path.splitext(filepath)[1] != '.ckpt': filepath += '.ckpt' self.saver.restore(self.sess, filepath)