import os
import shutil
import urllib.request
from pathlib import Path
import sys
import progressbar

BASE_URL = 'https://storage.googleapis.com/at16k-ce/models'
AVAILABLE_MODELS = ['en_8k', 'en_16k']

PROGRESS_BAR = None


def show_progress(block_num, block_size, total_size):
    global PROGRESS_BAR
    if PROGRESS_BAR is None:
        PROGRESS_BAR = progressbar.ProgressBar(maxval=total_size)
        PROGRESS_BAR.start()

    downloaded = block_num * block_size
    if downloaded < total_size:
        PROGRESS_BAR.update(downloaded)
    else:
        PROGRESS_BAR.finish()
        PROGRESS_BAR = None


def setup_home():
    """
    Init home directory to store assets and models
    """
    if 'AT16K_RESOURCES_DIR' in os.environ:
        at16k_model_dir = os.environ['AT16K_RESOURCES_DIR']
    else:
        home_dir = str(Path.home())
        at16k_model_dir = os.path.join(home_dir, '.at16k')
    if not os.path.exists(at16k_model_dir):
        os.mkdir(at16k_model_dir)
    return at16k_model_dir


def download_model(remote_path, local_path):
    """
    Download file
    """
    if not os.path.exists(local_path):
        print('Downloading from %s' % remote_path)
        urllib.request.urlretrieve(remote_path, local_path, show_progress)


def unarchive(local_path, base_dir):
    """
    Unarchive zipped file
    """
    shutil.unpack_archive(local_path, base_dir)


def main():
    """
    Main
    """
    assert len(
        sys.argv) > 1, ('Please specify model name: one of en_8k, en_16k, all')
    name = sys.argv[1]
    if name == 'all':
        name = AVAILABLE_MODELS
    else:
        assert name in AVAILABLE_MODELS, (
            'Please specify a valid model name: one of en_8k, en_16k, all')
        name = [name]
    base_dir = setup_home()
    for item in name:
        file_name = '%s.tar.gz' % item
        remote_path = os.path.join(BASE_URL, file_name)
        local_path = os.path.join(base_dir, file_name)
        download_model(remote_path, local_path)
        unarchive(local_path, base_dir)
        os.remove(local_path)
        print('Downloaded model: %s' % item)


if __name__ == '__main__':
    main()