#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for downloading and building data.

These can be replaced if your particular file system does not support them.
"""

import importlib
import json
import time
import datetime
import os
import requests
import shutil
import hashlib
import tqdm
import math
import traceback
from multiprocessing import Pool


def built(path, version_string=None):
    """
    Check if '.built' flag has been set for that task.

    If a version_string is provided, this has to match, or the version
    is regarded as not built.
    """
    if version_string:
        fname = os.path.join(path, '.built')
        if not os.path.isfile(fname):
            return False
        else:
            with open(fname, 'r') as read:
                text = read.read().split('\n')
            return len(text) > 1 and text[1] == version_string
    else:
        return os.path.isfile(os.path.join(path, '.built'))


def mark_done(path, version_string=None):
    """
    Mark this path as prebuilt.

    Marks the path as done by adding a '.built' file with the current timestamp
    plus a version description string if specified.

    :param str path:
        The file path to mark as built.

    :param str version_string:
        The version of this dataset.
    """
    with open(os.path.join(path, '.built'), 'w') as write:
        write.write(str(datetime.datetime.today()))
        if version_string:
            write.write('\n' + version_string)
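
# A minimal usage sketch (kept as a comment so nothing runs at import time):
# the usual pattern is to check built() before doing any work and to call
# mark_done() once the data is in place. The path and version are hypothetical.
#
#   dpath = '/tmp/ParlAI/data/my_task'  # hypothetical location
#   version = 'v1.0'
#   if not built(dpath, version_string=version):
#       make_dir(dpath)
#       # ... download and unpack the task files here ...
#       mark_done(dpath, version_string=version)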


def download(url, path, fname, redownload=False):
    """
    Download file using `requests`.

    If ``redownload`` is set to false, then
    will not download tar file again if it is present (default ``True``).
    """
    outfile = os.path.join(path, fname)
    # only download if the file is missing or a redownload was requested
    downloading = not os.path.isfile(outfile) or redownload
    print("[ downloading: " + url + " to " + outfile + " ]")
    retry = 5
    exp_backoff = [2 ** r for r in reversed(range(retry))]

    pbar = tqdm.tqdm(unit='B', unit_scale=True, desc='Downloading {}'.format(fname))

    while downloading and retry >= 0:
        resume_file = outfile + '.part'
        resume = os.path.isfile(resume_file)
        if resume:
            resume_pos = os.path.getsize(resume_file)
            mode = 'ab'
        else:
            resume_pos = 0
            mode = 'wb'
        response = None

        with requests.Session() as session:
            try:
                header = (
                    {'Range': 'bytes=%d-' % resume_pos, 'Accept-Encoding': 'identity'}
                    if resume
                    else {}
                )
                response = session.get(url, stream=True, timeout=5, headers=header)

                # the server signals no range support with 'none' or by omitting the header
                if resume and response.headers.get('Accept-Ranges', 'none') == 'none':
                    resume_pos = 0
                    mode = 'wb'

                CHUNK_SIZE = 32768
                total_size = int(response.headers.get('Content-Length', -1))
                # server returns remaining size if resuming, so adjust total
                total_size += resume_pos
                pbar.total = total_size
                done = resume_pos

                with open(resume_file, mode) as f:
                    for chunk in response.iter_content(CHUNK_SIZE):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                        if total_size > 0:
                            done += len(chunk)
                            if total_size < done:
                                # don't freak out if content-length was too small
                                total_size = done
                                pbar.total = total_size
                            pbar.update(len(chunk))
                    break
            except requests.exceptions.ConnectionError:
                retry -= 1
                pbar.clear()
                if retry >= 0:
                    print('Connection error, retrying. (%d retries left)' % retry)
                    time.sleep(exp_backoff[retry])
                else:
                    print('Retried too many times, stopped retrying.')
            finally:
                if response:
                    response.close()
    if retry < 0:
        raise RuntimeWarning('Connection broken too many times. Stopped retrying.')

    if downloading and retry >= 0:
        pbar.update(done - pbar.n)
        if done < total_size:
            raise RuntimeWarning(
                'Received less data than specified in '
                + 'Content-Length header for '
                + url
                + '.'
                + ' There may be a download problem.'
            )
        move(resume_file, outfile)

    pbar.close()
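
# A hedged example of calling download() directly (comment only); the URL and
# destination below are hypothetical, and redownload=False keeps any existing
# copy of the file.
#
#   url = 'http://example.com/some_dataset.tar.gz'  # hypothetical URL
#   dpath = '/tmp/ParlAI/data/my_task'
#   make_dir(dpath)
#   download(url, dpath, 'some_dataset.tar.gz', redownload=False)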


def make_dir(path):
    """Make the directory and any nonexistent parent directories (`mkdir -p`)."""
    # the current working directory is a fine path
    if path != '':
        os.makedirs(path, exist_ok=True)


def move(path1, path2):
    """Rename the given file."""
    shutil.move(path1, path2)


def remove_dir(path):
    """Remove the given directory, if it exists."""
    shutil.rmtree(path, ignore_errors=True)


def untar(path, fname, deleteTar=True):
    """
    Unpack the given archive file to the same directory.

    :param str path:
        The folder containing the archive. Will contain the contents.

    :param str fname:
        The filename of the archive file.

    :param bool deleteTar:
        If true, the archive will be deleted after extraction.
    """
    print('unpacking ' + fname)
    fullpath = os.path.join(path, fname)
    shutil.unpack_archive(fullpath, path)
    if deleteTar:
        os.remove(fullpath)
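
# Sketch of the usual download-then-unpack sequence (comment only); the URL and
# filename are hypothetical. untar() relies on shutil.unpack_archive, so both
# .tar.gz and .zip archives are handled, and the archive is removed afterwards
# unless deleteTar=False is passed.
#
#   download('http://example.com/data.tar.gz', dpath, 'data.tar.gz')
#   untar(dpath, 'data.tar.gz')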


def cat(file1, file2, outfile, deleteFiles=True):
    """Concatenate two files to an outfile, possibly deleting the originals."""
    with open(outfile, 'wb') as wfd:
        for f in [file1, file2]:
            with open(f, 'rb') as fd:
                shutil.copyfileobj(fd, wfd, 1024 * 1024 * 10)
                # 10MB per writing chunk to avoid reading big file into memory.
    if deleteFiles:
        os.remove(file1)
        os.remove(file2)


def _get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def download_from_google_drive(gd_id, destination):
    """Use the requests package to download a file from Google Drive."""
    URL = 'https://docs.google.com/uc?export=download'

    with requests.Session() as session:
        response = session.get(URL, params={'id': gd_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            response.close()
            params = {'id': gd_id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        CHUNK_SIZE = 32768
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
        response.close()
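
# Example call (comment only); the Google Drive file id is a placeholder and
# the destination path is hypothetical.
#
#   download_from_google_drive('FILE_ID_GOES_HERE', '/tmp/ParlAI/data/file.tgz')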


def get_model_dir(datapath):
    return os.path.join(datapath, 'models')


def download_models(
    opt, fnames, model_folder, version='v1.0', path='aws', use_model_type=False
):
    """
    Download models into the ParlAI model zoo from a url.

    :param fnames: list of filenames to download
    :param model_folder: models will be downloaded into models/model_folder/model_type
    :param path: url for downloading models; defaults to downloading from AWS
    :param use_model_type: whether models are categorized by type in AWS
    """
    model_type = opt.get('model_type', None)
    if model_type is not None:
        dpath = os.path.join(opt['datapath'], 'models', model_folder, model_type)
    else:
        dpath = os.path.join(opt['datapath'], 'models', model_folder)

    if not built(dpath, version):
        for fname in fnames:
            print('[building data: ' + dpath + '/' + fname + ']')
        if built(dpath):
            # An older version exists, so remove these outdated files.
            remove_dir(dpath)
        make_dir(dpath)

        # Download the data.
        for fname in fnames:
            if path == 'aws':
                url = 'http://parl.ai/downloads/_models/'
                url += model_folder + '/'
                if use_model_type:
                    url += model_type + '/'
                url += fname
            else:
                url = path + '/' + fname
            download(url, dpath, fname)
            if '.tgz' in fname or '.gz' in fname or '.zip' in fname:
                untar(dpath, fname)
        # Mark the data as built.
        mark_done(dpath, version)
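
# A sketch of how download_models() is typically invoked (comment only). The
# opt dict normally comes from ParlAI's option parser; here it is stubbed with
# just the fields this function reads, and the folder and filenames are
# hypothetical.
#
#   opt = {'datapath': '/tmp/ParlAI/data', 'model_type': None}
#   download_models(opt, ['model.tgz'], 'my_model_folder', version='v1.0')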


def modelzoo_path(datapath, path):
    """
    Map pretrain models filenames to their path on disk.

    If path starts with 'models:', then we remap it to the model zoo path
    within the data directory (default is ParlAI/data/models).
    We download models from the model zoo if they are not here yet.
    """
    if path is None:
        return None
    if (
        not path.startswith('models:')
        and not path.startswith('zoo:')
        and not path.startswith('izoo:')
    ):
        return path
    elif path.startswith('models:') or path.startswith('zoo:'):
        zoo = path.split(':')[0]
        zoo_len = len(zoo) + 1
        # Check if we need to download the model
        animal = path[zoo_len : path.rfind('/')].replace('/', '.')
        if '.' not in animal:
            animal += '.build'
        module_name = 'parlai.zoo.{}'.format(animal)
        try:
            my_module = importlib.import_module(module_name)
            my_module.download(datapath)
        except (ImportError, AttributeError):
            pass

        return os.path.join(datapath, 'models', path[zoo_len:])
    else:
        # Internal path (starts with "izoo:") -- useful for non-public
        # projects.  Save the path to your internal model zoo in
        # parlai_internal/.internal_zoo_path
        # TODO: test the internal zoo.
        zoo_path = 'parlai_internal/zoo/.internal_zoo_path'
        if not os.path.isfile(zoo_path):
            raise RuntimeError(
                'Please specify the path to your internal zoo in the '
                'file parlai_internal/zoo/.internal_zoo_path in your '
                'internal repository.'
            )
        else:
            with open(zoo_path, 'r') as f:
                zoo = f.read().split('\n')[0]
            return os.path.join(zoo, path[5:])
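
# Example of resolving a zoo path (comment only); the datapath and model name
# are hypothetical. A real zoo model would also be downloaded on first use via
# its parlai.zoo build module; this made-up one simply resolves to a path under
# datapath/models, while any plain filesystem path is returned unchanged.
#
#   local = modelzoo_path('/tmp/ParlAI/data', 'zoo:my_model/model')
#   # -> '/tmp/ParlAI/data/models/my_model/model'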


def download_multiprocess(
    urls, path, num_processes=32, chunk_size=100, dest_filenames=None, error_path=None
):
    """
    Download items in parallel (e.g. for an image + dialogue task)

    :param urls: Array of urls to download
    :param path: directory to save items in
    :param num_processes: number of processes to use
    :param chunk_size: chunk size to use
    :param dest_filenames: optional array of same length as url with filenames.
     Images will be saved as path + dest_filename
    :param error_path: where to save error logs
    :return: array of tuples of (destination filename, http status code, error
    message if any). Note that upon failure, file may not actually be created.
    """

    pbar = tqdm.tqdm(total=len(urls), position=0)

    # TODO: calling isfile() for every URL may be slow; consider tracking
    # completed downloads in a temporary file instead.
    if dest_filenames:
        if len(dest_filenames) != len(urls):
            raise Exception(
                'If specified, dest_filenames must be the same length as urls.'
            )
    else:

        def _naming_fn(url, url_metadata=None):
            return hashlib.md5(url.encode('utf-8')).hexdigest()

        dest_filenames = [_naming_fn(url) for url in urls]

    items = zip(urls, dest_filenames)
    remaining_items = [
        it for it in items if not os.path.isfile(os.path.join(path, it[1]))
    ]
    print(
        f'Of {len(urls)} items, {len(urls) - len(remaining_items)} already existed; only going to download {len(remaining_items)} items.'
    )
    pbar.update(len(urls) - len(remaining_items))

    pool_chunks = (
        (remaining_items[i : i + chunk_size], path, _download_multiprocess_single)
        for i in range(0, len(remaining_items), chunk_size)
    )
    remaining_chunks_count = math.ceil(len(remaining_items) / chunk_size)
    print(
        f'Going to download {remaining_chunks_count} chunks with {chunk_size} images per chunk using {num_processes} processes.'
    )

    pbar.desc = 'Downloading'
    all_results = []
    collected_errors = []

    with Pool(num_processes) as pool:
        for idx, chunk_result in enumerate(
            pool.imap_unordered(_download_multiprocess_map_chunk, pool_chunks, 2)
        ):
            all_results.extend(chunk_result)
            for dest_file, http_status_code, error_msg in chunk_result:
                if http_status_code != 200:
                    # record the failure so it can be written to the error log
                    collected_errors.append(
                        {
                            'dest_file': dest_file,
                            'status_code': http_status_code,
                            'error': error_msg,
                        }
                    )
                    print(
                        f'Bad download - chunk: {idx}, dest_file: {dest_file}, http status code: {http_status_code}, error_msg: {error_msg}'
                    )
            pbar.update(len(chunk_result))
    pbar.close()

    if error_path:
        now = time.strftime("%Y%m%d-%H%M%S")
        error_filename = os.path.join(
            error_path, 'parlai_download_multiprocess_errors_%s.log' % now
        )

        with open(error_filename, 'w+') as error_file:
            error_file.write(json.dumps(collected_errors))
            print('Summary of errors written to %s' % error_filename)

    print(
        'Of %d downloads attempted, %d had errors.'
        % (len(remaining_items), len(collected_errors))
    )

    print('Finished downloading chunks.')
    return all_results
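
# Usage sketch (comment only); the URLs and path are placeholders. Destination
# filenames default to the md5 hash of each URL, and the returned list can be
# inspected for non-200 status codes. Because this uses a multiprocessing Pool,
# on platforms with the 'spawn' start method it should be called from under an
# `if __name__ == '__main__':` guard.
#
#   urls = ['http://example.com/img1.jpg', 'http://example.com/img2.jpg']
#   results = download_multiprocess(urls, '/tmp/images', num_processes=4)
#   failed = [r for r in results if r[1] != 200]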


def _download_multiprocess_map_chunk(pool_tup):
    """
    Helper function for Pool imap_unordered.

    Apparently function must be pickable (which apparently means must be
    defined at the top level of a module and can't be a lamdba) to be used in
    imap_unordered. Has to do with how it's passed to the subprocess.

    :param pool_tup: is a tuple where first arg is an array of tuples of url
    and dest file name for the current chunk and second arg is function to be
    called.
    :return: an array of tuples
    """
    items, path, fn = pool_tup
    return [fn(url, path, dest) for url, dest in items]


def _download_multiprocess_single(url, path, dest_fname):
    """
    Helper function to download an individual item.

    Unlike download() above, does not deal with downloading chunks of a big
    file, does not support retries (and does not fail if retries are exhausted).

    :param url: URL to download from
    :param path: directory to save in
    :param dest_fname: destination file name of image
    :return tuple (dest_fname, http status)
    """

    status = None
    error_msg = None
    try:
        # 'User-Agent' header may need to be specified
        headers = {}

        # Use a short timeout so unresponsive servers are skipped quickly,
        # which can also cause some downloads to fail
        response = requests.get(
            url, stream=False, timeout=10, allow_redirects=True, headers=headers
        )
    except Exception as e:
        # requests.get() raised, most likely a timeout or connection error
        status = 500
        error_msg = '[Exception during download] ' + str(e)
        return dest_fname, status, error_msg

    if response.ok:
        try:
            with open(os.path.join(path, dest_fname), 'wb+') as out_file:
                # Some sites respond with gzip transport encoding
                response.raw.decode_content = True
                out_file.write(response.content)
            status = 200
        except Exception as e:
            # writing or decoding the response body failed
            status = 500
            error_msg = '[Exception during decoding or writing] ' + str(e)
    else:
        # the server returned an error status code (e.g. 404), possibly with
        # an HTML error page as the body
        status = response.status_code
        error_msg = '[Response not OK] Response: %s' % response

    return dest_fname, status, error_msg