##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ## Created by: Yaoyao Liu ## Modified from: https://github.com/cbfinn/maml ## Tianjin University ## liuyaoyao@tju.edu.cn ## Copyright (c) 2019 ## ## This source code is licensed under the MIT-style license found in the ## LICENSE file in the root directory of this source tree ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ """ Additional utility functions. """ import numpy as np import os import cv2 import random import tensorflow as tf from matplotlib.pyplot import imread from tensorflow.contrib.layers.python import layers as tf_layers from tensorflow.python.platform import flags FLAGS = flags.FLAGS def get_smallest_k_index(input_, k): """The function to get the smallest k items' indices. Args: input_: the list to be processed. k: the number of indices to return. Return: The index list with k dimensions. """ input_copy = np.copy(input_) k_list = [] for idx in range(k): this_index = np.argmin(input_copy) k_list.append(this_index) input_copy[this_index]=np.max(input_copy) return k_list def one_hot(inp): """The function to make the input to one-hot vectors. Arg: inp: the input numpy array. Return: The reorganized one-shot array. """ n_class = inp.max() + 1 n_sample = inp.shape[0] out = np.zeros((n_sample, n_class)) for idx in range(n_sample): out[idx, inp[idx]] = 1 return out def one_hot_class(inp, n_class): """The function to make the input to n-class one-hot vectors. Args: inp: the input numpy array. n_class: the number of classes. Return: The reorganized n-class one-shot array. """ n_sample = inp.shape[0] out = np.zeros((n_sample, n_class)) for idx in range(n_sample): out[idx, inp[idx]] = 1 return out def process_batch(input_filename_list, input_label_list, dim_input, batch_sample_num): """The function to process a part of an episode. Args: input_filename_list: the image files' directory list. input_label_list: the image files' corressponding label list. dim_input: the dimension number of the images. batch_sample_num: the sample number of the inputed images. Returns: img_array: the numpy array of processed images. label_array: the numpy array of processed labels. """ new_path_list = [] new_label_list = [] for k in range(batch_sample_num): class_idxs = list(range(0, FLAGS.way_num)) random.shuffle(class_idxs) for class_idx in class_idxs: true_idx = class_idx*batch_sample_num + k new_path_list.append(input_filename_list[true_idx]) new_label_list.append(input_label_list[true_idx]) img_list = [] for filepath in new_path_list: this_img = imread(filepath) this_img = np.reshape(this_img, [-1, dim_input]) this_img = this_img / 255.0 img_list.append(this_img) img_array = np.array(img_list).reshape([FLAGS.way_num*batch_sample_num, dim_input]) label_array = one_hot(np.array(new_label_list)).reshape([FLAGS.way_num*batch_sample_num, -1]) return img_array, label_array def process_batch_augmentation(input_filename_list, input_label_list, dim_input, batch_sample_num): """The function to process a part of an episode. All the images will be augmented by flipping. Args: input_filename_list: the image files' directory list. input_label_list: the image files' corressponding label list. dim_input: the dimension number of the images. batch_sample_num: the sample number of the inputed images. Returns: img_array: the numpy array of processed images. label_array: the numpy array of processed labels. """ new_path_list = [] new_label_list = [] for k in range(batch_sample_num): class_idxs = list(range(0, FLAGS.way_num)) random.shuffle(class_idxs) for class_idx in class_idxs: true_idx = class_idx*batch_sample_num + k new_path_list.append(input_filename_list[true_idx]) new_label_list.append(input_label_list[true_idx]) img_list = [] img_list_h = [] for filepath in new_path_list: this_img = imread(filepath) this_img_h = cv2.flip(this_img, 1) this_img = np.reshape(this_img, [-1, dim_input]) this_img = this_img / 255.0 img_list.append(this_img) this_img_h = np.reshape(this_img_h, [-1, dim_input]) this_img_h = this_img_h / 255.0 img_list_h.append(this_img_h) img_list_all = img_list + img_list_h label_list_all = new_label_list + new_label_list img_array = np.array(img_list_all).reshape([FLAGS.way_num*batch_sample_num*2, dim_input]) label_array = one_hot(np.array(label_list_all)).reshape([FLAGS.way_num*batch_sample_num*2, -1]) return img_array, label_array def get_images(paths, labels, nb_samples=None, shuffle=True): """The function to get the image files' directories with given class labels. Args: paths: the base path for the images. labels: the class name labels. nb_samples: the number of samples. shuffle: whether shuffle the generated image list. Return: The list for the image files' directories. """ if nb_samples is not None: sampler = lambda x: random.sample(x, nb_samples) else: sampler = lambda x: x images = [(i, os.path.join(path, image)) \ for i, path in zip(labels, paths) \ for image in sampler(os.listdir(path))] if shuffle: random.shuffle(images) return images def get_pretrain_images(path, label): """The function to get the image files' directories for pre-train phase. Args: paths: the base path for the images. labels: the class name labels. is_val: whether the images are for the validation phase during pre-training. Return: The list for the image files' directories. """ images = [] for image in os.listdir(path): images.append((label, os.path.join(path, image))) return images def get_images_tc(paths, labels, nb_samples=None, shuffle=True, is_val=False): """The function to get the image files' directories with given class labels for pre-train phase. Args: paths: the base path for the images. labels: the class name labels. nb_samples: the number of samples. shuffle: whether shuffle the generated image list. is_val: whether the images are for the validation phase during pre-training. Return: The list for the image files' directories. """ if nb_samples is not None: sampler = lambda x: random.sample(x, nb_samples) else: sampler = lambda x: x if is_val is False: images = [(i, os.path.join(path, image)) \ for i, path in zip(labels, paths) \ for image in sampler(os.listdir(path)[0:500])] else: images = [(i, os.path.join(path, image)) \ for i, path in zip(labels, paths) \ for image in sampler(os.listdir(path)[500:])] if shuffle: random.shuffle(images) return images ## Network helpers def leaky_relu(x, leak=0.1): """The leaky relu function. Args: x: the input feature maps. leak: the parameter for leaky relu. Return: The feature maps processed by non-liner layer. """ return tf.maximum(x, leak*x) def resnet_conv_block(inp, cweight, bweight, reuse, scope, activation=leaky_relu): """The function to forward a conv layer. Args: inp: the input feature maps. cweight: the filters' weights for this conv layer. bweight: the biases' weights for this conv layer. reuse: whether reuse the variables for the batch norm. scope: the label for this conv layer. activation: the activation function for this conv layer. Return: The processed feature maps. """ stride, no_stride = [1,2,2,1], [1,1,1,1] if FLAGS.activation == 'leaky_relu': activation = leaky_relu elif FLAGS.activation == 'relu': activation = tf.nn.relu else: activation = None conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight normed = normalize(conv_output, activation, reuse, scope) return normed def resnet_nob_conv_block(inp, cweight, reuse, scope): """The function to forward a conv layer without biases, normalization and non-liner layer. Args: inp: the input feature maps. cweight: the filters' weights for this conv layer. reuse: whether reuse the variables for the batch norm. scope: the label for this conv layer. Return: The processed feature maps. """ stride, no_stride = [1,2,2,1], [1,1,1,1] conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') return conv_output def normalize(inp, activation, reuse, scope): """The function to forward the normalization. Args: inp: the input feature maps. reuse: whether reuse the variables for the batch norm. scope: the label for this conv layer. activation: the activation function for this conv layer. Return: The processed feature maps. """ if FLAGS.norm == 'batch_norm': return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) elif FLAGS.norm == 'layer_norm': return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) elif FLAGS.norm == 'None': if activation is not None: return activation(inp) return inp else: raise ValueError('Please set correct normalization.') ## Loss functions def mse(pred, label): """The MSE loss function. Args: pred: the predictions. label: the ground truth labels. Return: The Loss. """ pred = tf.reshape(pred, [-1]) label = tf.reshape(label, [-1]) return tf.reduce_mean(tf.square(pred-label)) def softmaxloss(pred, label): """The softmax cross entropy loss function. Args: pred: the predictions. label: the ground truth labels. Return: The Loss. """ return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label)) def xent(pred, label): """The softmax cross entropy loss function. The losses will be normalized by the shot number. Args: pred: the predictions. label: the ground truth labels. Return: The Loss. Note: with tf version <=0.12, this loss has incorrect 2nd derivatives """ return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.shot_num