import os
import glob

import tensorflow as tf


def from_cached_tfrecords(args):
    """ Use tf.Dataset, but feeding it using a placeholder w/ the whole dataset. """
    # this may seem a bit weird
    # we take tfrecords but load them into placeholders during training
    # we found that it loaded faster this way when this was first implemented
    # letting tf.Dataset loading all simulataneously is conceptually better

    res, nch = args.input_res, args.nchannels

    x = tf.placeholder(args.dtype, (None, res, res, nch))
    y = tf.placeholder('int64', (None,))

    dataset = tf.data.Dataset.from_tensor_slices((x, y))

    # inputs are complex numbers
    # magnitude is ray length
    # phase is angle between ray and normal
    # we found it best to treat magnitude and phase independently, though
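    # note: tf.imag(x / |x|) equals sin(phase), and the 1e-8 epsilon guards
    # against division by zero for zero-magnitude inputs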
    dataset = dataset.map(lambda x, y: (tf.concat([tf.abs(x),
                                                   tf.imag(x / (tf.cast(tf.abs(x), 'complex64') + 1e-8))],
                                                  axis=-1), y))

    # we use the same batch size for train/val/test
    dataset = dataset.batch(args.train_bsize)
    iterator = dataset.make_initializable_iterator()

    fnames = {}
    for t in ['train', 'test', 'val']:
        fnames[t] = glob.glob(os.path.join(args.dset_dir, '{}*.tfrecord'.format(t)))

    out = {'x': x, 'y': y, 'fnames': fnames}
    print('loading dataset; number of tfrecords: {}'
          .format({k: len(v) for k, v in out['fnames'].items()}))

    return iterator, out
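

# Hedged usage sketch (not called by this module): because the dataset is fed
# from placeholders, training code has to (re)initialize the iterator with the
# decoded arrays before iterating. `decode_tfrecords_to_arrays` is a
# hypothetical helper, assumed to read the listed .tfrecord files into a pair
# of numpy arrays shaped like the placeholders above.
def _example_initialize_iterator(args, decode_tfrecords_to_arrays):
    iterator, out = from_cached_tfrecords(args)
    next_x, next_y = iterator.get_next()
    xs, ys = decode_tfrecords_to_arrays(out['fnames']['train'])
    with tf.Session() as sess:
        # feed the whole cached split once; batches then come from get_next()
        sess.run(iterator.initializer, feed_dict={out['x']: xs, out['y']: ys})
        batch_x, batch_y = sess.run([next_x, next_y])
    return batch_x, batch_y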


def load(args):
    # resolve the loader by name; stripping the extension lets args.dset be
    # either a function name or a filename such as 'from_cached_tfrecords.py'
    dset_fun = globals().get(os.path.splitext(args.dset)[0])
    if dset_fun is None:
        raise ValueError('unknown dataset: {}'.format(args.dset))

    return dset_fun(args)
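

# A minimal smoke-test sketch, assuming an argparse-style namespace. Every
# value below is a placeholder assumption rather than a project default; this
# only builds the graph and prints the resulting placeholders.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(dset='from_cached_tfrecords', dset_dir='./data',
                     input_res=64, nchannels=1, train_bsize=32,
                     dtype=tf.complex64)  # assumed values, see note above
    iterator, out = load(args)
    print('placeholders:', out['x'], out['y'])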