#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import os
import sys
import torch
from subprocess import check_call
from glob import glob
from tempfile import NamedTemporaryFile as TempFile
import time
import multiprocessing as mp
from utils import check_last_line, count_line
import tqdm


def translate_files_slurm(args, cmds, expected_output_files):
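    """Submit one sbatch job per translation command, then poll until every
    expected output file ends with a "finished" marker."""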
    conda_env = '/private/home/pipibjc/.conda/envs/fairseq-20190509'
    for cmd in cmds:
        with TempFile('w') as script:
            sh = f"""#!/bin/bash
            source activate {conda_env}
            {cmd}
            """
            print(sh)
            script.write(sh)
            script.flush()
            cmd = f"sbatch --gres=gpu:1 -c {args.cpu + 2} {args.sbatch_args} --time=15:0:0 {script.name}"
            import sys
            print(cmd, file=sys.stderr)
            check_call(cmd, shell=True)

    # wait until every expected output file reports "finished"
    num_finished = 0
    while num_finished < len(expected_output_files):
        num_finished = 0
        for output_file in expected_output_files:
            num_finished += 1 if check_finished(output_file) else 0
        if num_finished < len(expected_output_files):
            print("sleeping for 3m ...")
            time.sleep(3 * 60)


def check_finished(output_file):
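    """A translation shard is done when its log ends with the "finished" marker."""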
    return check_last_line(output_file, "finished")


def get_output_file(dest_dir, file):
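    """Path of the translation log for ``file`` under ``dest_dir``."""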
    return f"{dest_dir}/{os.path.basename(file)}.log"


def translate(arg_list):
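    """Run one translation command on a GPU id checked out from the shared queue."""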
    (q, cmd) = arg_list
    i = q.get()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(i)
    cmd = f"CUDA_VISIBLE_DEVICES={i} {cmd}"
    print(f"executing:\n{cmd}")
    try:
        check_call(cmd, shell=True)
    finally:
        # return the GPU id to the pool even if the command fails
        q.put(i)


def translate_files_local(args, cmds):
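    """Run the translation commands locally, one worker process per visible GPU."""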
    m = mp.Manager()
    gpu_queue = m.Queue()
    for i in args.cuda_visible_device_ids:
        gpu_queue.put(i)
    with mp.Pool(processes=len(args.cuda_visible_device_ids)) as pool:
        for _ in tqdm.tqdm(pool.imap_unordered(translate, [(gpu_queue, cmd) for cmd in cmds]), total=len(cmds)):
            pass


def translate_files(args, dest_dir, input_files):
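    """Build one fairseq-interactive command per input shard and dispatch them
    to the selected backend (local GPUs or slurm)."""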
    cmd_template = f"""fairseq-interactive \
        {args.databin} \
        --source-lang {args.source_lang} --target-lang {args.target_lang} \
        --path {args.model} \
        --lenpen {args.lenpen} \
        --max-len-a {args.max_len_a} \
        --max-len-b {args.max_len_b} \
        --buffer-size {args.buffer_size} \
        --max-tokens {args.max_tokens} \
        --num-workers {args.cpu} > {{output_file}} && \
    echo "finished" >> {{output_file}}
    """
    cmds = []
    expected_output_files = []
    for input_file in input_files:
        output_file = get_output_file(dest_dir, input_file)
        cmds.append(f"cat {input_file} | " + cmd_template.format(output_file=output_file))
        expected_output_files.append(output_file)
    if args.backend == 'local':
        translate_files_local(args, cmds)
    elif args.backend == 'slurm':
        translate_files_slurm(args, cmds, expected_output_files)


def main():
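    """Split the monolingual data into chunks, translate each chunk with
    fairseq-interactive, and aggregate the results into parallel
    generated.src / generated.hypo files."""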
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', '-d', required=True, help='Path to file to translate')
    parser.add_argument('--model', '-m', required=True, help='Model checkpoint')
    parser.add_argument('--lenpen', default=1.2, type=float, help='Length penalty')
    parser.add_argument('--beam', default=5, type=int, help='Beam size')
    parser.add_argument('--max-len-a', type=float, default=0, help='max-len-a parameter when back-translating')
    parser.add_argument('--max-len-b', type=int, default=200, help='max-len-b parameter when back-translating')
    parser.add_argument('--cpu', type=int, default=4, help='Number of CPU workers for fairseq-interactive')
    parser.add_argument('--cuda-visible-device-ids', '-gids', default=None, nargs='*', help='List of CUDA visible device ids, space separated')
    parser.add_argument('--dest', help='Output path for the intermediate and translated files')
    parser.add_argument('--max-tokens', type=int, default=12000, help='max tokens')
    parser.add_argument('--buffer-size', type=int, default=10000, help='Buffer size')
    parser.add_argument('--chunks', type=int, default=100)
    parser.add_argument('--source-lang', type=str, default=None, help='Source language. Will be inferred from the model if not set')
    parser.add_argument('--target-lang', type=str, default=None, help='Target language. Will be inferred from the model if not set')
    parser.add_argument('--databin', type=str, default=None, help='Data-bin directory for fairseq-interactive. Will be inferred from the model if not set')
    parser.add_argument('--sbatch-args', default='', help='Extra SBATCH arguments')

    parser.add_argument('--backend', type=str, default='local', choices=['local', 'slurm'])
    args = parser.parse_args()

    args.cuda_visible_device_ids = args.cuda_visible_device_ids or list(range(torch.cuda.device_count()))

    # read the source/target languages and databin from the checkpoint when not given explicitly
    chkpnt = torch.load(args.model, map_location='cpu')
    model_args = chkpnt['args']
    args.source_lang = args.source_lang or model_args.source_lang
    args.target_lang = args.target_lang or model_args.target_lang
    args.databin = args.databin or model_args.data

    root_dir = os.path.dirname(os.path.realpath(__file__))
    translation_dir = os.path.join(args.dest or root_dir, 'translations', f'{args.source_lang}-{args.target_lang}')

    tempdir = os.path.join(translation_dir, 'splits')
    os.makedirs(tempdir, exist_ok=True)
    split_files = glob(f'{tempdir}/mono_data*')

    if len(split_files) != args.chunks:
        if len(split_files) != 0:
            print("number of split files does not match the number of chunks; removing files and re-splitting")
            for f in os.listdir(tempdir):
                os.remove(os.path.join(tempdir, f))
        print("splitting files ...")
        check_call(f'split -n "r/{args.chunks}" -a3 -d {args.data} {tempdir}/mono_data', shell=True)
        split_files = glob(f'{tempdir}/mono_data*')
    else:
        print("number of split files matches the specified chunks, skipping split")

    translated_files = []
    files_to_translate = []
    for file in split_files:
        # skip the translation job if it's finished
        output_file = get_output_file(translation_dir, file)
        translated_files.append(output_file)
        if check_finished(output_file):
            print(f"{output_file} is translated")
            continue
        files_to_translate.append(file)

    print(f"{len(files_to_translate)} files to translate")

    translate_files(args, translation_dir, files_to_translate)

    # aggregate translated files into parallel source/hypothesis files
    dest_dir = args.dest or root_dir
    generated_src = f'{dest_dir}/generated.src'
    generated_tgt = f'{dest_dir}/generated.hypo'
    if count_line(generated_src) != count_line(generated_tgt) or count_line(generated_src) <= 0:
        print(f"aggregating translated {len(translated_files)} files")
        with TempFile() as fout:
            files = " ".join(translated_files)
            check_call(f"cat {files}", shell=True, stdout=fout)
            # keep the source (S-*) and hypothesis (H-*) lines from the fairseq-interactive logs as parallel files
            check_call(f'cat {fout.name} | grep "^S" | cut -f2 > {generated_src}', shell=True)
            check_call(f'cat {fout.name} | grep "^H" | cut -f3 > {generated_tgt}', shell=True)
    assert count_line(generated_src) == count_line(generated_tgt)
    print(f"output generated files to {generated_src}, {generated_tgt}")


if __name__ == '__main__':
    main()