""" The classifier of linked dynamic graph CNN. @author: Kuangen Zhang """ import tensorflow as tf import numpy as np import sys import os BASE_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append(BASE_DIR) sys.path.append(os.path.join(BASE_DIR, '../utils')) import tf_util def placeholder_inputs(batch_size, num_feature): pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_feature)) labels_pl = tf.placeholder(tf.int32, shape=(batch_size)) return pointclouds_pl, labels_pl def get_model(feature, is_training, bn_decay=None): # Fully connected layers: classifier layers = {} feature = tf.squeeze(feature) layer_name = 'ft_' # B: batch size; C: channels; # feature: B*C # net: B*512 net = tf_util.fully_connected(feature, 512, bn=True, is_training=is_training, scope=layer_name + 'fc2', bn_decay=bn_decay, activation_fn = tf.nn.relu) layers[layer_name + 'fc2'] = net net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope=layer_name + 'dp2') # net: B*256 net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope=layer_name + 'fc3', bn_decay=bn_decay, activation_fn = tf.nn.relu) layers[layer_name + 'fc3'] = net net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope=layer_name + 'dp3') # net: B*40 net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc4') layers[layer_name + 'fc4'] = net return net, layers def get_loss(pred, label): """ pred: B*NUM_CLASSES, label: B, """ # Change the label from an integer to the one_hot vector. labels = tf.one_hot(indices=label, depth=40) # Calculate the loss based on cross entropy method. loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=pred, label_smoothing=0.2) # Calculate the mean loss of a batch input. classify_loss = tf.reduce_mean(loss) return classify_loss if __name__=='__main__': batch_size = 2 num_pt = 124 pos_dim = 3 input_feed = np.random.rand(batch_size, num_pt, pos_dim) label_feed = np.random.rand(batch_size) label_feed[label_feed>=0.5] = 1 label_feed[label_feed<0.5] = 0 label_feed = label_feed.astype(np.int32) with tf.Graph().as_default(): input_pl, label_pl = placeholder_inputs(batch_size, num_pt) pos, ftr = get_model(input_pl, tf.constant(True)) # loss = get_loss(logits, label_pl, None) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) feed_dict = {input_pl: input_feed, label_pl: label_feed} res1, res2 = sess.run([pos, ftr], feed_dict=feed_dict) print (res1.shape) print (res1) print (res2.shape) print (res2)