""" Code for the MAML algorithm and network definitions. """
from __future__ import print_function
import numpy as np
import sys
import tensorflow as tf
try:
    import special_grads
except KeyError as e:
    print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e,
          file=sys.stderr)

from collections import OrderedDict
from tensorflow.python.platform import flags
from utils import xent, conv_block, normalize, bn_relu_conv_block

FLAGS = flags.FLAGS
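
# MAML in brief: for each task, adapt the shared weights theta with a few SGD
# steps on the task's support set (inputa/labela),
#     fast_weights = theta - update_lr * grad_theta(loss_a(theta)),
# then meta-optimize theta so the *adapted* weights do well on the query set
# (inputb/labelb), i.e. minimize the sum over tasks of loss_b(fast_weights),
# using Adam with learning rate meta_lr. See construct_model() below.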

class MAML:
    def __init__(self, dim_input=1, dim_output_train=1, dim_output_val=1, test_num_updates=5):
        """ must call construct_model() after initializing MAML! """
        self.dim_input = dim_input
        self.dim_output_train = dim_output_train
        self.dim_output_val = dim_output_val
        self.update_lr = FLAGS.update_lr
        self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
        self.classification = True
        self.test_num_updates = test_num_updates
        self.loss_func = xent
        if FLAGS.on_encodings:
            print('Meta-learning on encodings')
            self.dim_hidden = [FLAGS.num_filters] * FLAGS.num_hidden_layers
            print('hidden layers: {}'.format(self.dim_hidden))
            self.forward = self.forward_fc
            self.construct_weights = self.construct_fc_weights
        else:
            if FLAGS.conv:
                self.dim_hidden = FLAGS.num_filters
                if FLAGS.resnet:
                    # The two resnet variants share the same macro-structure;
                    # they differ only in the input conv (kernel size/stride).
                    if FLAGS.input_type == 'images_84x84':
                        self.forward = self.forward_resnet84
                        self.construct_weights = self.construct_resnet_weights84
                    elif FLAGS.input_type == 'images_224x224':
                        self.forward = self.forward_resnet224
                        self.construct_weights = self.construct_resnet_weights224
                    else:
                        raise ValueError('Unsupported input_type for resnet: {}'.format(FLAGS.input_type))
                    assert FLAGS.num_parts_per_res_block == 2
                    assert FLAGS.num_res_blocks == 4
                    self.num_parts_per_res_block = FLAGS.num_parts_per_res_block
                    # Four residual stages, each preceded by a 2x2 max-pool.
                    self.blocks = ['input', 'maxpool', 'res0', 'maxpool', 'res1',
                                   'maxpool', 'res2', 'maxpool', 'res3', 'output']
                else:
                    self.forward = self.forward_conv
                    self.construct_weights = self.construct_conv_weights
            else:
                self.dim_hidden = [1024, 512, 256, 128]
                print('hidden layers: {}'.format(self.dim_hidden))
                self.forward = self.forward_fc
                self.construct_weights = self.construct_fc_weights
        if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'omniglot':
            self.channels = 1
        else:
            self.channels = 3
        self.img_size = int(np.sqrt(self.dim_input/self.channels))
        if FLAGS.dataset not in ['mnist', 'omniglot', 'miniimagenet', 'celeba', 'imagenet']:
            raise ValueError('Unrecognized data source.')


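    # Typical usage, sketched (flag values and tensor shapes are up to the
    # caller, not fixed by this class):
    #
    #     model = MAML(dim_input, dim_output_train, dim_output_val)
    #     model.construct_model(input_tensors=train_tensors, prefix='metatrain_')
    #     model.construct_model(input_tensors=val_tensors, prefix='metaval_')
    #     sess.run(model.metatrain_op)  # one meta-training step
    #
    # A prefix containing 'train' builds the Adam meta-optimizer (metatrain_op,
    # when FLAGS.metatrain_iterations > 0); other prefixes only build the
    # metaval_* evaluation tensors.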
    def construct_model(self, input_tensors=None, prefix='metatrain_'):
        # a: training data for inner gradient, b: test data for meta gradient
        if prefix == 'metatrain_':
            inner_update_batch_size = FLAGS.inner_update_batch_size_train
        else:
            inner_update_batch_size = FLAGS.inner_update_batch_size_val
        outer_update_batch_size = FLAGS.outer_update_batch_size
        if input_tensors is None:
            self.inputa = tf.placeholder(tf.float32)
            self.inputb = tf.placeholder(tf.float32)
            self.labela = tf.placeholder(tf.float32)
            self.labelb = tf.placeholder(tf.float32)
        else:
            self.inputa = input_tensors['inputa']
            self.inputb = input_tensors['inputb']
            self.labela = input_tensors['labela']
            self.labelb = input_tensors['labelb']
            if prefix == 'metaval_':
                self.mv_inputa = self.inputa
                self.mv_inputb = self.inputb
                self.mv_labela = self.labela
                self.mv_labelb = self.labelb

        with tf.variable_scope('model', reuse=None) as training_scope:
            if 'weights' in dir(self):
                training_scope.reuse_variables()
                weights = self.weights
            else:
                # Define the weights
                self.weights = weights = self.construct_weights()
                print(weights.keys())

            # outputbs[i] and lossesb[i] are the output and loss after i+1 gradient updates
            lossesa, outputas, lossesb, outputbs = [], [], [], []
            accuraciesa, accuraciesb = [], []
            num_updates = max(self.test_num_updates, FLAGS.num_updates)
            outputbs = [[]]*num_updates
            lossesb = [[]]*num_updates
            accuraciesb = [[]]*num_updates
            if FLAGS.from_scratch:
                train_accuracies = [[]]*num_updates

            def task_metalearn(inp, reuse=True):
                """ Perform gradient descent for one task in the meta-batch. """
                inputa, inputb, labela, labelb = inp
                task_outputbs, task_lossesb = [], []

                if FLAGS.from_scratch:
                    task_outputas = []

                if self.classification:
                    task_accuraciesb = []
                    if FLAGS.from_scratch:
                        task_accuraciesa = []


                task_outputa = self.forward(inputa, weights, prefix, reuse=reuse)  # only reuse on the first iter
                if FLAGS.from_scratch:
                    task_outputas.append(task_outputa)
                task_lossa = self.loss_func(task_outputa, labela, inner_update_batch_size)

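                # One inner-loop SGD step on the support set:
                #     fast_weights = weights - update_lr * grad(task_lossa)
                # With FLAGS.stop_grad the gradients are treated as constants,
                # dropping second derivatives from the meta-gradient
                # (first-order MAML).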
                grads = tf.gradients(task_lossa, list(weights.values()))
                if FLAGS.stop_grad:
                    grads = [tf.stop_gradient(grad) for grad in grads]
                gradients = dict(zip(weights.keys(), grads))
                fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()]))
                output = self.forward(inputb, fast_weights, prefix, reuse=True)
                task_outputbs.append(output)
                task_lossesb.append(self.loss_func(output, labelb, outer_update_batch_size))

                for j in range(num_updates - 1):
                    outputa = self.forward(inputa, fast_weights, prefix, reuse=True)
                    loss = self.loss_func(outputa, labela, inner_update_batch_size)
                    if FLAGS.from_scratch:
                        task_outputas.append(outputa)

                    grads = tf.gradients(loss, list(fast_weights.values()))
                    if FLAGS.stop_grad:
                        grads = [tf.stop_gradient(grad) for grad in grads]
                    gradients = dict(zip(fast_weights.keys(), grads))
                    fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()]))
                    output = self.forward(inputb, fast_weights, prefix, reuse=True)
                    task_outputbs.append(output)
                    task_lossesb.append(self.loss_func(output, labelb, outer_update_batch_size))

                task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]

                if self.classification:
                    task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1), tf.argmax(labela, 1))
                    for j in range(num_updates):
                        task_accuraciesb.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1), tf.argmax(labelb, 1)))
                        if FLAGS.from_scratch:
                            task_accuraciesa.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputas[j]), 1), tf.argmax(labela, 1)))
                    task_output.extend([task_accuracya, task_accuraciesb])
                    if FLAGS.from_scratch:
                        task_output.extend([task_accuraciesa])

                return task_output

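            # map_fn applies task_metalearn to every task in the meta-batch;
            # out_dtype below mirrors the nested structure of task_metalearn's
            # return value so the per-task results can be stacked.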
            if FLAGS.norm != 'None':
                # Run once (unmapped) to create the batch/layer-norm variables.
                # TODO: fold this into the map_fn so task 0 isn't built twice.
                unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)

            out_dtype = [tf.float32, [tf.float32]*num_updates, tf.float32, [tf.float32]*num_updates]
            if self.classification:
                out_dtype.extend([tf.float32, [tf.float32]*num_updates])
                if FLAGS.from_scratch:
                    out_dtype.extend([[tf.float32] * num_updates])
            result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
            if self.classification:
                if FLAGS.from_scratch:
                    outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb, train_accuracies = result
                else:
                    outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
            else:
                outputas, outputbs, lossesa, lossesb = result

        ## Performance & Optimization
        if 'train' in prefix:
            self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            # after the map_fn
            self.outputas, self.outputbs = outputas, outputbs
            if self.classification:
                self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
                self.total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)

            if FLAGS.metatrain_iterations > 0:
                optimizer = tf.train.AdamOptimizer(self.meta_lr)
                self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[FLAGS.num_updates-1])
                if FLAGS.dataset == 'miniimagenet' or FLAGS.dataset == 'celeba' or FLAGS.dataset == 'imagenet':
                    gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
                self.metatrain_op = optimizer.apply_gradients(gvs)
        else:
            self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            if self.classification:
                self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
                self.metaval_total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
                self.mv_outputbs = outputbs
                if FLAGS.from_scratch:
                    self.metaval_train_accuracies = [tf.reduce_sum(train_accuracies[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
        ## Summaries
        tf.summary.scalar(prefix+'Pre-update loss', total_loss1)
        if self.classification:
            tf.summary.scalar(prefix+'Pre-update accuracy', total_accuracy1)

        for j in range(num_updates):
            tf.summary.scalar(prefix+'Post-update loss, step ' + str(j+1), total_losses2[j])
            if self.classification:
                tf.summary.scalar(prefix+'Post-update accuracy, step ' + str(j+1), total_accuracies2[j])

    ### Network construction functions (fc networks and conv networks)
    def construct_fc_weights(self):
        weights = {}
        weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
        for i in range(1,len(self.dim_hidden)):
            weights['w'+str(i+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[i-1], self.dim_hidden[i]], stddev=0.01))
            weights['b'+str(i+1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
        weights['w'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[-1], self.dim_output_train], stddev=0.01))
        weights['b'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.zeros([self.dim_output_train]))
        return weights

    def forward_fc(self, inp, weights, prefix, reuse=False):
        hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
        for i in range(1,len(self.dim_hidden)):
            hidden = normalize(tf.matmul(hidden, weights['w'+str(i+1)]) + weights['b'+str(i+1)], activation=tf.nn.relu, reuse=reuse, scope=str(i+1))
        logits = tf.matmul(hidden, weights['w'+str(len(self.dim_hidden)+1)]) + weights['b'+str(len(self.dim_hidden)+1)]
        if 'val' in prefix:
            logits = tf.gather(logits, tf.range(self.dim_output_val), axis=1)
        return logits

    def construct_conv_weights(self):
        weights = {}

        dtype = tf.float32
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        k = 3

        channels = self.channels
        weights['conv1'] = tf.get_variable('conv1', [k, k, channels, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
        if FLAGS.dataset == 'miniimagenet' or FLAGS.dataset == 'celeba' or FLAGS.dataset == 'imagenet':
            # assumes 2x2 VALID max pooling in each conv block, so an 84x84
            # input is downsampled 84 -> 42 -> 21 -> 10 -> 5 over four blocks
            weights['w5'] = tf.get_variable('w5', [self.dim_hidden*5*5, self.dim_output_train], initializer=fc_initializer)
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output_train]), name='b5')
        else:
            weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output_train]), name='w5')
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output_train]), name='b5')

        return weights

    def forward_conv(self, inp, weights, prefix, reuse=False, scope=''):
        # reuse is for the normalization parameters.
        channels = self.channels
        inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])

        hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope+'0')
        hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope+'1')
        hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope+'2')
        hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope+'3')
        if FLAGS.dataset == 'miniimagenet' or FLAGS.dataset == 'celeba' or FLAGS.dataset == 'imagenet':
            # flatten the final 5x5 x dim_hidden feature map into a vector
            # (must match the 5*5 sizing of w5 in construct_conv_weights)
            hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
        else:
            hidden4 = tf.reduce_mean(hidden4, [1, 2])

        logits = tf.matmul(hidden4, weights['w5']) + weights['b5']

        if 'val' in prefix:
            logits = tf.gather(logits, tf.range(self.dim_output_val), axis=1)
        return logits
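
    # conv_block and bn_relu_conv_block are imported from utils. conv_block is
    # assumed (as in the reference MAML implementation) to apply
    # conv -> normalize -> nonlinearity, with 2x2 max-pooling when
    # FLAGS.max_pool and a strided conv otherwise.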

    def construct_resnet_weights224(self):
        weights = OrderedDict()
        dtype = tf.float32

        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        bias_initializer = tf.zeros_initializer(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        def make_conv_layer_weights(weights, scope, k, filters_in, filters_out, bias=True):
            weights['{}/conv'.format(scope)] = tf.get_variable('{}/conv'.format(scope), [k, k, filters_in, filters_out], initializer=conv_initializer, dtype=dtype)
            if bias:
                weights['{}/bias'.format(scope)] = tf.get_variable('{}/bias'.format(scope), [filters_out], initializer=bias_initializer, dtype=dtype)
        def make_fc_layer_weights(weights, scope, dims_in, dims_out):
            weights['{}/fc'.format(scope)] = tf.get_variable('{}/fc'.format(scope), [dims_in, dims_out], initializer=fc_initializer, dtype=dtype)
            weights['{}/bias'.format(scope)] = tf.get_variable('{}/bias'.format(scope), [dims_out], initializer=bias_initializer, dtype=dtype)
        for block_name in self.blocks:
            if block_name == 'input':
                make_conv_layer_weights(weights, block_name, k=7, filters_in=self.channels, filters_out=64)
            elif 'res' in block_name:
                j = int(block_name[-1])
                last_block_filter = 64 if j == 0 else 64 * 2 ** (j-1)
                this_block_filter = 64 if j == 0 else last_block_filter * 2
                print(block_name, last_block_filter, this_block_filter)
                make_conv_layer_weights(weights, '{}/shortcut'.format(block_name), k=1, filters_in=last_block_filter,
                                        filters_out=this_block_filter, bias=False)
                for i in range(self.num_parts_per_res_block):
                    make_conv_layer_weights(weights, '{}/part{}'.format(block_name, i), k=3,
                                            filters_in=last_block_filter if i == 0 else this_block_filter,
                                            filters_out=this_block_filter)
            elif block_name == 'output':
                make_fc_layer_weights(weights, block_name, dims_in=512, dims_out=self.dim_output_train)
        return weights

    def construct_resnet_weights84(self):
        weights = OrderedDict()
        dtype = tf.float32

        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        bias_initializer = tf.zeros_initializer(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        def make_conv_layer_weights(weights, scope, k, filters_in, filters_out, bias=True):
            weights['{}/conv'.format(scope)] = tf.get_variable('{}/conv'.format(scope), [k, k, filters_in, filters_out], initializer=conv_initializer, dtype=dtype)
            if bias:
                weights['{}/bias'.format(scope)] = tf.get_variable('{}/bias'.format(scope), [filters_out], initializer=bias_initializer, dtype=dtype)
        def make_fc_layer_weights(weights, scope, dims_in, dims_out):
            weights['{}/fc'.format(scope)] = tf.get_variable('{}/fc'.format(scope), [dims_in, dims_out], initializer=fc_initializer, dtype=dtype)
            weights['{}/bias'.format(scope)] = tf.get_variable('{}/bias'.format(scope), [dims_out], initializer=bias_initializer, dtype=dtype)
        for block_name in self.blocks:
            if block_name == 'input':
                make_conv_layer_weights(weights, block_name, k=3, filters_in=self.channels, filters_out=64)
            elif 'res' in block_name:
                j = int(block_name[-1])
                last_block_filter = 64 if j == 0 else 64 * 2 ** (j - 1)
                this_block_filter = 64 if j == 0 else last_block_filter * 2
                print(block_name, last_block_filter, this_block_filter)
                make_conv_layer_weights(weights, '{}/shortcut'.format(block_name), k=1, filters_in=last_block_filter,
                                        filters_out=this_block_filter, bias=False)
                for i in range(self.num_parts_per_res_block):
                    make_conv_layer_weights(weights, '{}/part{}'.format(block_name, i), k=3,
                                            filters_in=last_block_filter if i == 0 else this_block_filter,
                                            filters_out=this_block_filter)
            elif block_name == 'output':
                make_fc_layer_weights(weights, block_name, dims_in=512, dims_out=self.dim_output_train)
        return weights
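
    # Note: the 84x84 and 224x224 weight constructors are identical except for
    # the input conv's kernel size (3 vs. 7); likewise the two forward passes
    # below differ only in the input conv's stride (1 vs. 2).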

    def forward_resnet224(self, inp, weights, prefix, reuse=False):

        inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels])

        for block_name in self.blocks:
            if block_name == 'input':
                conv = weights['{}/conv'.format(block_name)]
                bias = weights['{}/bias'.format(block_name)]
                inp = tf.nn.conv2d(inp, filter=conv, strides=[1, 2, 2, 1], padding="SAME") + bias
            elif 'res' in block_name:
                shortcut = inp
                conv = weights['{}/shortcut/conv'.format(block_name)]
                shortcut = tf.nn.conv2d(input=shortcut, filter=conv, strides=[1, 1, 1, 1], padding="SAME")
                for part in range(self.num_parts_per_res_block):
                    part_name = 'part{}'.format(part)
                    scope = '{}/{}'.format(block_name, part_name)
                    conv = weights['{}/{}/conv'.format(block_name, part_name)]
                    bias = weights['{}/{}/bias'.format(block_name, part_name)]
                    inp = bn_relu_conv_block(inp=inp, conv=conv, bias=bias, reuse=reuse, scope=scope)
                inp = shortcut + inp
            elif 'maxpool' in block_name:
                inp = tf.nn.max_pool(inp, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")
            elif 'output' in block_name:
                inp = tf.reduce_mean(inp, [1, 2])
                fc = weights['{}/fc'.format(block_name)]
                bias = weights['{}/bias'.format(block_name)]
                inp = tf.matmul(inp, fc) + bias
                if 'val' in prefix:
                    inp = tf.gather(inp, tf.range(self.dim_output_val), axis=1)
        return inp

    def forward_resnet84(self, inp, weights, prefix, reuse=False):

        inp = tf.reshape(inp, [-1, self.img_size, self.img_size, self.channels])

        for block_name in self.blocks:
            if block_name == 'input':
                conv = weights['{}/conv'.format(block_name)]
                bias = weights['{}/bias'.format(block_name)]
                inp = tf.nn.conv2d(inp, filter=conv, strides=[1, 1, 1, 1], padding="SAME") + bias
            elif 'res' in block_name:
                shortcut = inp
                conv = weights['{}/shortcut/conv'.format(block_name)]
                shortcut = tf.nn.conv2d(input=shortcut, filter=conv, strides=[1, 1, 1, 1], padding="SAME")
                for part in range(self.num_parts_per_res_block):
                    part_name = 'part{}'.format(part)
                    scope = '{}/{}'.format(block_name, part_name)
                    conv = weights['{}/{}/conv'.format(block_name, part_name)]
                    bias = weights['{}/{}/bias'.format(block_name, part_name)]
                    inp = bn_relu_conv_block(inp=inp, conv=conv, bias=bias, reuse=reuse, scope=scope)
                inp = shortcut + inp
            elif 'maxpool' in block_name:
                inp = tf.nn.max_pool(inp, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")
            elif 'output' in block_name:
                inp = tf.reduce_mean(inp, [1, 2])
                fc = weights['{}/fc'.format(block_name)]
                bias = weights['{}/bias'.format(block_name)]
                inp = tf.matmul(inp, fc) + bias
                if 'val' in prefix:
                    inp = tf.gather(inp, tf.range(self.dim_output_val), axis=1)
        return inp
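
    # Shape check for forward_resnet84: the input conv preserves 84x84, and
    # the four 2x2 VALID max-pools give 84 -> 42 -> 21 -> 10 -> 5; global
    # average pooling over the 5x5x512 feature map feeds the 512-d fc layer.
    # (forward_resnet224: the stride-2 input conv gives 112x112, then
    # 112 -> 56 -> 28 -> 14 -> 7 before the same pooling and fc.)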

    def wrap(self, inp, weights, prefix, reuse=False, scope=''):
        # Build once with reuse=False to create the normalization variables,
        # then return the graph that reuses them.
        unused = self.forward(inp, weights, prefix, reuse=False)
        return self.forward(inp, weights, prefix, reuse=True)


if __name__ == '__main__':
    import ipdb

    FLAGS = flags.FLAGS

    ## Dataset/method options
    flags.DEFINE_string('dataset', 'omniglot', 'omniglot or mnist or miniimagenet or celeba')
    flags.DEFINE_integer('num_encoding_dims', -1, 'dimensionality of the unsupervised representation')
    flags.DEFINE_string('encoder', 'acai', 'acai or bigan or deepcluster or infogan')

    ## Training options
    flags.DEFINE_integer('metatrain_iterations', 30000, 'number of metatraining iterations.')
    flags.DEFINE_integer('meta_batch_size', 8, 'number of tasks sampled per meta-update')
    flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the meta-optimizer')
    flags.DEFINE_float('update_lr', 0.05, 'step size alpha for inner gradient update.')
    flags.DEFINE_integer('inner_update_batch_size_train', 1,
                         'number of examples used for inner gradient update (K for K-shot learning).')
    flags.DEFINE_integer('inner_update_batch_size_val', 5, 'as above, but for meta-validation')
    flags.DEFINE_integer('outer_update_batch_size', 5, 'number of examples used for outer gradient update')
    flags.DEFINE_integer('num_updates', 5, 'number of inner gradient updates during training.')
    flags.DEFINE_string('mt_mode', 'gtgt', 'meta-training mode (for sampling, labeling): gtgt or encenc')
    flags.DEFINE_string('mv_mode', 'gtgt', 'meta-validation mode (for sampling, labeling): gtgt or encenc')
    flags.DEFINE_integer('num_classes_train', 5, 'number of classes used in classification for meta-training')
    flags.DEFINE_integer('num_classes_val', 5, 'number of classes used in classification for meta-validation.')
    flags.DEFINE_float('margin', 0.0, 'margin for generating partitions using random hyperplanes')
    flags.DEFINE_integer('num_partitions', 1, 'number of partitions, -1 for same as number of meta-training tasks')
    flags.DEFINE_string('partition_algorithm', 'kmeans', 'hyperplanes or kmeans')
    flags.DEFINE_integer('num_clusters', -1, 'number of clusters for kmeans')
    flags.DEFINE_boolean('scaled_encodings', True, 'if True, use randomly scaled encodings for kmeans')
    flags.DEFINE_boolean('on_encodings', False, 'if True, train MAML on top of encodings')
    flags.DEFINE_integer('num_hidden_layers', 2, 'number of mlp hidden layers')
    flags.DEFINE_integer('num_parallel_calls', 8, 'for loading data')
    flags.DEFINE_integer('gpu', 7, 'CUDA_VISIBLE_DEVICES=')

    ## Model options
    flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
    flags.DEFINE_integer('num_filters', 32, 'number of filters for each conv layer')
    flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network')
    flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
    flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')

    ## Logging, saving, and testing options
    flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.')
    flags.DEFINE_string('logdir', './log', 'directory for summaries and checkpoints.')
    flags.DEFINE_bool('resume', True, 'resume training if there is a model available')
    flags.DEFINE_bool('train', True, 'True to train, False to test.')
    flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)')
    flags.DEFINE_bool('test_set', False, 'Set to true to test on the test set, False for the validation set.')
    flags.DEFINE_integer('log_inner_update_batch_size_val', -1,
                         'inner_update_batch_size_val used in the log directory name (use to test with a different value)')
    flags.DEFINE_float('train_update_lr', -1,
                       'value of the inner gradient step size used during training (use to test with a different value)')
    flags.DEFINE_bool('save_checkpoints', False, 'if True, save model weights as checkpoints')
    flags.DEFINE_bool('debug', False, 'if True, use tf debugger')
    flags.DEFINE_string('suffix', '', 'suffix for an exp_string')
    flags.DEFINE_bool('from_scratch', False, 'fast-adapt from scratch')
    flags.DEFINE_integer('num_eval_tasks', 1000, 'number of tasks to meta-test on')

    # Imagenet
    flags.DEFINE_string('input_type', 'images_84x84',
                        'features or features_processed or images_fullsize or images_84x84')
    flags.DEFINE_string('data_dir', '/data3/kylehsu/data', 'location of data')
    flags.DEFINE_bool('resnet', False, 'use resnet architecture')
    flags.DEFINE_integer('num_res_blocks', 5, 'number of resnet blocks')
    flags.DEFINE_integer('num_parts_per_res_block', 2, 'number of bn-relu-conv parts in a res block')

    FLAGS.resnet = True  # override flag defaults for this smoke test
    FLAGS.num_res_blocks = 4  # MAML.__init__ asserts four res blocks

    maml = MAML(dim_input=3*84*84, dim_output_train=10, dim_output_val=5, test_num_updates=5)
    maml.channels = 3
    maml.img_size = 84
    weights = maml.construct_weights()  # bound to construct_resnet_weights84
    input_ph = tf.placeholder(tf.float32)
    # first call with reuse=False creates the batch-norm variables
    unused = maml.forward(input_ph, weights, 'hi', reuse=False)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    inputs = np.ones((1, 84 * 84 * 3), dtype=np.float32)

    # a 'val' prefix truncates the logits to the first dim_output_val classes
    y = sess.run(maml.forward(input_ph, weights, 'val', reuse=True), {input_ph: inputs})


    ipdb.set_trace()
    x = 1  # no-op landing line so the debugger has a statement to stop on