# downloads/extracts datasets described in the README.md

import os
import sys
import errno
import tarfile

if sys.version_info >= (3,):
    from urllib.request import urlretrieve
else:
    from urllib import urlretrieve

DATA_DIR = 'Data'


# http://stackoverflow.com/questions/273192/how-to-check-if-a-directory-exists-and-create-it-if-necessary
def make_sure_path_exists(path):
    try:
        os.makedirs(path)
    except OSError as exception:
        if exception.errno != errno.EEXIST:
            raise


def create_data_paths():
    if not os.path.isdir(DATA_DIR):
        raise EnvironmentError('Needs to be run from project directory containing ' + DATA_DIR)
    needed_paths = [
        os.path.join(DATA_DIR, 'samples'),
        os.path.join(DATA_DIR, 'val_samples'),
        os.path.join(DATA_DIR, 'Models'),
    ]
    for p in needed_paths:
        make_sure_path_exists(p)


# adapted from http://stackoverflow.com/questions/51212/how-to-write-a-download-progress-indicator-in-python
def dl_progress_hook(count, blockSize, totalSize):
    percent = int(count * blockSize * 100 / totalSize)
    sys.stdout.write("\r" + "...%d%%" % percent)
    sys.stdout.flush()


def download_dataset(data_name):
    if data_name == 'flowers':
        print('== Flowers dataset ==')
        flowers_dir = os.path.join(DATA_DIR, 'flowers')
        flowers_jpg_tgz = os.path.join(flowers_dir, '102flowers.tgz')
        make_sure_path_exists(flowers_dir)

        # the original google drive link at https://drive.google.com/file/d/0B0ywwgffWnLLcms2WWJQRFNSWXM/view
        # from https://github.com/reedscot/icml2016 is problematic to download automatically, so included
        # the text_c10 directory from that archive as a bzipped file in the repo
        captions_tbz = os.path.join(DATA_DIR, 'flowers_text_c10.tar.bz2')
        print('Extracting ' + captions_tbz)
        captions_tar = tarfile.open(captions_tbz, 'r:bz2')
        captions_tar.extractall(flowers_dir)

        flowers_url = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
        print('Downloading ' + flowers_jpg_tgz + ' from ' + flowers_url)
        urlretrieve(flowers_url, flowers_jpg_tgz,
                    reporthook=dl_progress_hook)
        print('Extracting ' + flowers_jpg_tgz)
        flowers_jpg_tar = tarfile.open(flowers_jpg_tgz, 'r:gz')
        flowers_jpg_tar.extractall(flowers_dir)  # archive contains jpg/ folder

    elif data_name == 'skipthoughts':
        print('== Skipthoughts models ==')
        SKIPTHOUGHTS_DIR = os.path.join(DATA_DIR, 'skipthoughts')
        SKIPTHOUGHTS_BASE_URL = 'http://www.cs.toronto.edu/~rkiros/models/'
        make_sure_path_exists(SKIPTHOUGHTS_DIR)

        # following https://github.com/ryankiros/skip-thoughts#getting-started
        skipthoughts_files = [
            'dictionary.txt', 'utable.npy', 'btable.npy', 'uni_skip.npz', 'uni_skip.npz.pkl', 'bi_skip.npz',
            'bi_skip.npz.pkl',
        ]
        for filename in skipthoughts_files:
            src_url = SKIPTHOUGHTS_BASE_URL + filename
            print('Downloading ' + src_url)
            urlretrieve(src_url, os.path.join(SKIPTHOUGHTS_DIR, filename),
                        reporthook=dl_progress_hook)

    elif data_name == 'nltk_punkt':
        import nltk
        print('== NLTK pre-trained Punkt tokenizer for English ==')
        nltk.download('punkt')

    elif data_name == 'pretrained_model':
        print('== Pretrained model ==')
        MODEL_DIR = os.path.join(DATA_DIR, 'Models')
        pretrained_model_filename = 'latest_model_flowers_temp.ckpt'
        src_url = 'https://bitbucket.org/paarth_neekhara/texttomimagemodel/raw/74a4bbaeee26fe31e148a54c4f495694680e2c31/' + pretrained_model_filename
        print('Downloading ' + src_url)
        urlretrieve(
            src_url,
            os.path.join(MODEL_DIR, pretrained_model_filename),
            reporthook=dl_progress_hook,
        )

    else:
        raise ValueError('Unknown dataset name: ' + data_name)


def main():
    create_data_paths()
    # TODO: make configurable via command-line
    download_dataset('flowers')
    download_dataset('skipthoughts')
    download_dataset('nltk_punkt')
    download_dataset('pretrained_model')
    print('Done')


if __name__ == '__main__':
    main()