import time
import os
import vn

import tensorflow as tf
import argparse

from mridata import VnMriReconstructionData, VnMriFilenameProducer

import tensorflow.contrib.icg as icg

class VnMriReconstructionCell(tf.contrib.icg.VnBasicCell):
    def mriForwardOpWithOS(self, u, coil_sens, sampling_mask):
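        """MRI forward operator A = M F C for data with readout oversampling.

        Zero-pads the image along the frequency-encoding direction to the
        oversampled FOV, multiplies by the coil sensitivities C, applies the
        centered 2D FFT F and masks the k-space with the sampling pattern M.
        """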
        with tf.variable_scope('mriForwardOp'):
            # zero-pad the image along the frequency-encoding direction to the
            # oversampled FOV of the sampling mask (asymmetric +/-1 padding,
            # matching the crop in the adjoint operator)
            pad_u = tf.cast(tf.multiply(tf.cast(tf.shape(sampling_mask)[1], tf.float32), 0.25) + 1, tf.int32)
            pad_l = tf.cast(tf.multiply(tf.cast(tf.shape(sampling_mask)[1], tf.float32), 0.25) - 1, tf.int32)
            u_pad = tf.pad(u, [[0, 0], [pad_u, pad_l], [0, 0]])
            u_pad = tf.expand_dims(u_pad, axis=1)
            # apply sensitivities
            coil_imgs = u_pad * coil_sens
            # centered Fourier transform
            Fu = tf.contrib.icg.fftc2d(coil_imgs)
            # apply sampling mask
            mask = tf.expand_dims(sampling_mask, axis=1)
            kspace = tf.complex(tf.real(Fu) * mask, tf.imag(Fu) * mask)
        return kspace

    def mriAdjointOpWithOS(self, f, coil_sens, sampling_mask):
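        """MRI adjoint operator A^H = C^H F^H M for data with readout oversampling.

        Masks the k-space data, applies the centered inverse 2D FFT, combines
        the coil images with the conjugate sensitivities and crops the
        frequency-encoding oversampling again.
        """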
        with tf.variable_scope('mriAdjointOp'):
            # variables to remove frequency encoding oversampling
            pad_u = tf.cast(tf.multiply(tf.cast(tf.shape(sampling_mask)[1], tf.float32), 0.25) + 1, tf.int32)
            pad_l = tf.cast(tf.multiply(tf.cast(tf.shape(sampling_mask)[1], tf.float32), 0.25) - 1, tf.int32)
            # apply mask and perform inverse centered Fourier transform
            mask = tf.expand_dims(sampling_mask, axis=1)
            Finv = tf.contrib.icg.ifftc2d(tf.complex(tf.real(f) * mask, tf.imag(f) * mask))
            # multiply coil images with sensitivities and sum up over channels
            img = tf.reduce_sum(Finv * tf.conj(coil_sens), 1)[:, pad_u:-pad_l, :]
        return img

    def mriForwardOp(self, u, coil_sens, sampling_mask):
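        """MRI forward operator A = M F C: coil-sensitivity weighting,
        centered 2D FFT and k-space sampling mask."""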
        with tf.variable_scope('mriForwardOp'):
            # apply sensitivities
            coil_imgs = tf.expand_dims(u, axis=1) * coil_sens
            # centered Fourier transform
            Fu = tf.contrib.icg.fftc2d(coil_imgs)
            # apply sampling mask
            mask = tf.expand_dims(sampling_mask, axis=1)
            kspace = tf.complex(tf.real(Fu) * mask, tf.imag(Fu) * mask)
        return kspace

    def mriAdjointOp(self, f, coil_sens, sampling_mask):
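        """MRI adjoint operator A^H = C^H F^H M: masked inverse centered 2D FFT
        followed by a conjugate-sensitivity-weighted sum over the coils."""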
        with tf.variable_scope('mriAdjointOp'):
            # apply mask and perform inverse centered Fourier transform
            mask = tf.expand_dims(sampling_mask, axis=1)
            Finv = tf.contrib.icg.ifftc2d(tf.complex(tf.real(f) * mask, tf.imag(f) * mask))
            # multiply coil images with sensitivities and sum up over channels
            img = tf.reduce_sum(Finv * tf.conj(coil_sens), 1)
        return img

    def call(self, t, inputs):
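        """Evaluate a single stage t of the variational network.

        Computes the gradient step
            u_t = u_{t-1} - R(u_{t-1}) - lambda_t * A^H(A u_{t-1} - f),
        where R is the learned regularizer (convolutions, RBF activations,
        transposed convolutions) and A is the MRI forward operator selected
        according to the sampling pattern.
        """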
        # get the variables
        u_t_1 = inputs[0][t]

        # extract constants
        f = self._constants['f']
        c = self._constants['coil_sens']
        m = self._constants['sampling_mask']

        # get the parameters
        param_idx = self.time_to_param_index(t)

        # datatermweight
        lambdaa = self._params['lambda'][param_idx]
        # activation function weights
        w = self._params['w1'][param_idx]
        # convolution kernels
        k = self._params['k1'][param_idx]

        # extract options
        vmin = self._options['vmin']
        vmax = self._options['vmax']
        pad = self._options['pad']

        # split kernels
        k_real = tf.real(k)
        k_imag = tf.imag(k)

        # define the cell
        # pad the image to avoid problems at the border
        u_p = tf.pad(tf.expand_dims(u_t_1, -1), [[0, 0], [pad, pad], [pad, pad], [0, 0]], 'SYMMETRIC')
        # split the image in real and imaginary part and perform convolution
        u_k_real = tf.nn.conv2d(tf.real(u_p), k_real, [1, 1, 1, 1], 'SAME')
        u_k_imag = tf.nn.conv2d(tf.imag(u_p), k_imag, [1, 1, 1, 1], 'SAME')
        # add up the convolution results
        u_k = u_k_real + u_k_imag
        # apply the activation functions
        f_u_k = icg.activation_rbf(u_k, w, v_min=vmin, v_max=vmax, num_weights=w.shape[1], feature_stride=1)
        # perform transpose convolution for real and imaginary part
        u_k_T_real = tf.nn.conv2d_transpose(f_u_k, tf.real(k), tf.shape(u_p), [1, 1, 1, 1], 'SAME')
        u_k_T_imag = tf.nn.conv2d_transpose(f_u_k, tf.imag(k), tf.shape(u_p), [1, 1, 1, 1], 'SAME')
        # rebuild complex image
        u_k_T = tf.complex(u_k_T_real, u_k_T_imag)
        # remove padding
        Ru = u_k_T[:, pad:-pad, pad:-pad, 0]

        # normalize regularizer by number of filters
        Ru /= self._options['num_filter']

        # select the dataterm operators according to the sampling pattern;
        # default to Cartesian sampling with readout oversampling if no
        # pattern is given (the original elif checked for the missing key
        # after already indexing it, which would raise a KeyError first)
        sampling_pattern = self._options.get('sampling_pattern', 'cartesian_with_os')
        if sampling_pattern == 'cartesian':
            print('mri op')
            forwardOp = self.mriForwardOp
            adjointOp = self.mriAdjointOp
        elif sampling_pattern == 'cartesian_with_os':
            print('mri op with OS')
            forwardOp = self.mriForwardOpWithOS
            adjointOp = self.mriAdjointOpWithOS
        else:
            raise ValueError("Selected sampling pattern '%s' does not exist!" % sampling_pattern)

        # build dataterm
        Au = forwardOp(u_t_1, c, m)
        At_Au_f = adjointOp(Au - f, c, m)
        Du = tf.complex(tf.real(At_Au_f)*lambdaa, tf.imag(At_Au_f)*lambdaa)

        # gradient step
        u_t = u_t_1 - Ru - Du

        return [u_t]

if __name__ == '__main__':
    # Add arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--training_config', type=str, default='./configs/training.yaml')
    parser.add_argument('--network_config', type=str, default='./configs/mri_vn.yaml')
    parser.add_argument('--data_config', type=str, default='./configs/data.yaml')
    parser.add_argument('--global_config', type=str, default='./configs/global.yaml')

    args = parser.parse_args()

    # Load the configs
    network_config, reg_config = tf.contrib.icg.utils.loadYaml(args.network_config, ['network', 'reg'])
    checkpoint_config, optimizer_config = tf.contrib.icg.utils.loadYaml(args.training_config, ['checkpoint_config', 'optimizer_config'])
    data_config = tf.contrib.icg.utils.loadYaml(args.data_config, ['data_config'])
    global_config = tf.contrib.icg.utils.loadYaml(args.global_config, ['global_config'])

    # Tensorflow config
    tf_config = tf.ConfigProto(log_device_placement=False)
    tf_config.gpu_options.allow_growth = global_config['tf_allow_gpu_growth']

    # define the output locations
    base_name = os.path.basename(args.network_config).split('.')[0]
    suffix = base_name + '_' + time.strftime('%Y-%m-%d--%H-%M-%S')
    vn.setupLogDirs(suffix, args, checkpoint_config)

    # load data
    filename_producer = VnMriFilenameProducer(data_config)
    data = VnMriReconstructionData(data_config, filename_dequeue_op=filename_producer.dequeue_op, queue_capacity=global_config['data_queue_capacity'])

    network_config['sampling_pattern'] = data_config['sampling_pattern']

    # Create a queue runner that runs data_num_threads threads in parallel to enqueue examples.
    qr_data = tf.train.QueueRunner(data.queue, [data.enqueue_op] * global_config['data_num_threads'])
    # Create a queue runner to produce the filenames
    qr_filenames = tf.train.QueueRunner(filename_producer.queue, [filename_producer.enqueue_op])

    # Create a coordinator, launch the queue runner threads.
    coord = tf.train.Coordinator()

    # define parameters
    params = tf.contrib.icg.utils.Params()
    const_params = tf.contrib.icg.utils.ConstParams()

    vn.paramdefinitions.add_convolution_params(params, const_params, reg_config['filter1'])
    vn.paramdefinitions.add_activation_function_params(params, reg_config['activation1'])
    vn.paramdefinitions.add_dataterm_weights(params, network_config)

    # setup the network
    vn_cell = VnMriReconstructionCell(params=params.get(),
                              const_params=const_params.get(),
                              inputs=[data.u],
                              constants=data.constants,
                              options=network_config)

    mrirecon_vn = tf.contrib.icg.VariationalNetwork(cell=vn_cell,
                                        num_stages=network_config['num_stages'],
                                        parallel_iterations=global_config['parallel_iterations'],
                                        swap_memory=global_config['swap_memory'])

    # get all images
    u_all = mrirecon_vn.get_outputs(stage_outputs=True)[0]
    u_T = tf.identity(u_all[-1], 'u_T')

    # define loss
    with tf.variable_scope('loss'):
        # MSE on smoothed magnitudes (the epsilon keeps the sqrt differentiable at zero)
        target_abs = tf.sqrt(tf.real(data.target * tf.conj(data.target)) + 1e-12)
        output_abs = tf.sqrt(tf.real(u_T * tf.conj(u_T)) + 1e-12)

        energy = tf.reduce_mean(tf.reduce_sum(((output_abs - target_abs) ** 2), axis=(1, 2)))

        # normalized rmse
        denominator = tf.reduce_sum(tf.real(data.target * tf.conj(data.target)), axis=(1, 2))
        numerator = tf.reduce_sum(tf.real((u_T - data.target) * tf.conj(u_T - data.target)), axis=(1, 2))
        rmse = tf.reduce_mean(tf.sqrt(numerator / denominator))

        # ssim
        output_abs = tf.expand_dims(tf.abs(u_T), -1)
        target_abs = tf.expand_dims(tf.abs(data.target), -1)
        L = (tf.reduce_max(target_abs, axis=(1, 2, 3), keepdims=True)
             - tf.reduce_min(target_abs, axis=(1, 2, 3), keepdims=True))
        ssim = vn.utils.ssim(output_abs, target_abs, L=L)

    # add images and energy to summary
    with tf.variable_scope('loss_summary'):
        tf.summary.scalar('energy', energy)
        tf.summary.scalar('rmse', rmse)
        tf.summary.scalar('ssim', ssim)

    # add images to tensorboard; put them into a separate 'images' collection
    # so that merge_all(key='images') below finds them and they are written
    # independently of the scalar summaries
    tf.summary.image('input', tf.abs(tf.expand_dims(data.u, -1)), max_outputs=10, collections=['images'])
    for i in range(network_config['num_stages']):
        tf.summary.image('u%d' % (i + 1), tf.abs(tf.expand_dims(u_all[i + 1], -1)), max_outputs=10, collections=['images'])
    tf.summary.image('target', tf.abs(tf.expand_dims(data.target, -1)), max_outputs=10, collections=['images'])

    # define the optimizer
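    # (iPALM: inertial proximal alternating linearized minimization)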
    optimizer = icg.optimizer.IPALMOptimizer(params, energy, optimizer_config)

    with tf.Session(config=tf_config) as sess:
        # initialize the variables
        init = tf.global_variables_initializer()
        sess.run(init)

        # memorize a few ops and placeholders to be used in evaluation
        tf.add_to_collection('energy_op', energy)
        tf.add_to_collection('ssim_op', ssim)
        tf.add_to_collection('rmse_op', rmse)
        tf.add_to_collection('u_op', u_all[-1])
        tf.add_to_collection('u_all_op', u_all)
        tf.add_to_collection('u_var', data.u)
        tf.add_to_collection('g_var', data.target)
        tf.add_to_collection('c_var', data.constants['coil_sens'])
        tf.add_to_collection('m_var', data.constants['sampling_mask'])
        tf.add_to_collection('f_var', data.constants['f'])

        # saver to store checkpoints; max_to_keep=0 keeps all of them
        saver = tf.train.Saver(max_to_keep=0)

        # initialize enqueuing threads
        enqueue_threads_filename_producer = qr_filenames.create_threads(sess, coord=coord, start=True)
        enqueue_threads_data = qr_data.create_threads(sess, coord=coord, start=True)

        # collect the summaries
        epoch_summaries = tf.summary.merge_all()
        image_summaries = tf.summary.merge_all(key='images')
        train_writer = tf.summary.FileWriter(checkpoint_config['log_dir'] + '/' + suffix + '/train/', sess.graph)

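        # trace full runtime statistics (op timings, memory) for TensorBoard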
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        iter_per_epoch = filename_producer.iter_per_epoch

        try:
            start_time = time.time()
            for epoch in range(0, optimizer_config['max_iter'] + 1):
                if coord.should_stop():
                    break

                # get next mini batch
                feed_dict = data.get_feed_dict(sess=sess)

                # run a single iteration
                optimizer.minimize(sess, epoch, feed_dict)

                feed_dict = data.get_eval_feed_dict()

                if (epoch % checkpoint_config['summary_modulo'] == 0) or epoch == optimizer_config['max_iter']:
                    summary = sess.run(epoch_summaries,
                                       feed_dict=feed_dict,
                                       options=run_options, run_metadata=run_metadata)
                    train_writer.add_run_metadata(run_metadata, 'step%d' % epoch)
                    train_writer.add_summary(summary, epoch)

                if (epoch % checkpoint_config['save_modulo'] == 0) or epoch == optimizer_config['max_iter']:
                    # update summary
                    summary = sess.run(image_summaries,
                                       feed_dict=feed_dict,
                                       options=run_options, run_metadata=run_metadata)
                    train_writer.add_run_metadata(run_metadata, 'images%d' % epoch)
                    train_writer.add_summary(summary, epoch)
                    # save variables to checkpoint
                    saver.save(sess, checkpoint_config['log_dir'] + '/' + suffix + '/checkpoints/' + 'checkpoint', global_step=epoch)

                # compute the current energy
                e_i = sess.run(energy, feed_dict=feed_dict)
                print("epoch:", epoch, "energy =", e_i)

            print('Elapsed training time:', time.time() - start_time)

        except KeyboardInterrupt:
            print('[KEYBOARD INTERRUPT]: Stop training.')
        except Exception as e:
            # Report exceptions to the coordinator.
            coord.request_stop(e)
        finally:
            # Terminate as usual. It is innocuous to request stop twice.
            coord.request_stop()
            coord.join(enqueue_threads_data)
            coord.join(enqueue_threads_filename_producer)