"""Extract network features for the test set, project them with t-SNE, and plot them.

The three stages communicate through ``.mat`` files under ``../data``:
``get_feature`` writes ``feature.mat``, ``tsne`` reads it and writes
``proj.mat``, and ``plot_tsne`` reads ``proj.mat``.
"""
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import tensorflow as tf
from sklearn.manifold import TSNE

from networks import googlenet

# Number of test samples pushed through the network per ``sess.run`` call.
BATCH_SIZE = 50


def get_feature():
    """Run the restored network over the test set and save the features.

    Reads ``test`` and ``testlabels`` from ``../data/dataset.mat``, restores
    the checkpoint, evaluates the feature tensor batch by batch with
    ``is_training`` fed as ``False``, and writes ``feature`` and
    ``testlabels`` to ``../data/feature.mat``.
    """
    inputs = tf.placeholder("float", [None, 64, 64, 1])
    is_training = tf.placeholder("bool")
    # Only the second value returned by googlenet is used here.
    _, feature = googlenet(inputs, is_training)
    # Drop axes 1 and 2 so each sample becomes one flat row
    # (presumably 1x1 spatial dims after pooling -- confirm in networks.py).
    feature = tf.squeeze(feature, [1, 2])
    saver = tf.train.Saver()

    data = sio.loadmat("../data/dataset.mat")
    # x / 127.5 - 1 maps [0, 255] onto [-1, 1]
    # (assumes "test" holds 8-bit intensities -- TODO confirm).
    testdata = data["test"] / 127.5 - 1.0
    testlabels = data["testlabels"]

    nums_test = testdata.shape[0]
    features = np.zeros([nums_test, 1024])

    # The context manager guarantees the session is closed, even on error.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, "../save_para/.\\model.ckpt")
        # A strided range covers the final partial batch too (slices past the
        # end are clipped) and never feeds an empty batch to the network.
        for start in range(0, nums_test, BATCH_SIZE):
            stop = start + BATCH_SIZE
            features[start:stop] = sess.run(
                feature,
                feed_dict={inputs: testdata[start:stop], is_training: False},
            )

    sio.savemat("../data/feature.mat", {"feature": features, "testlabels": testlabels})


def tsne():
    """Project the saved features with t-SNE and store the embedding.

    Reads ``feature`` from ``../data/feature.mat``, fits ``TSNE()`` with its
    default settings, and writes the result as ``proj`` to
    ``../data/proj.mat``.
    """
    feature_test = sio.loadmat("../data/feature.mat")["feature"]
    proj = TSNE().fit_transform(feature_test)
    sio.savemat("../data/proj.mat", {"proj": proj})


def plot_tsne():
    """Scatter-plot the saved t-SNE embedding, one colour per class label.

    Reads ``proj`` from ``../data/proj.mat`` and the labels from the first
    row of ``testlabels`` in ``../data/dataset.mat``. Points whose label
    equals ``i`` are drawn in the ``i``-th palette colour; labels outside
    ``0..len(palette) - 1`` are not drawn.
    """
    proj = sio.loadmat("../data/proj.mat")["proj"]
    colors = [
        'darkorchid', 'darkred', 'darksalmon', 'darkseagreen', 'darkslateblue',
        'darkslategray', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue',
        'dimgray', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen',
        'fuchsia', 'gainsboro', 'ghostwhite', 'gold', 'goldenrod',
        'gray', 'green', 'greenyellow', 'honeydew', 'hotpink',
        'indianred', 'indigo', 'ivory', 'khaki', 'lavenderblush',
    ]
    data = sio.loadmat("../data/dataset.mat")
    labels = data["testlabels"][0, :]
    for class_id, class_color in enumerate(colors):
        # One boolean mask per class selects both coordinates.
        members = labels == class_id
        # NOTE(review): every series gets the same label "gmt" and no legend
        # is drawn -- confirm whether a per-class label was intended.
        plt.plot(proj[members, 0], proj[members, 1], ".", c=class_color, label="gmt")
    plt.show()


if __name__ == "__main__":
    # Only the plotting stage runs by default; get_feature() and tsne() must
    # already have produced ../data/feature.mat and ../data/proj.mat.
    plot_tsne()