import math
import multiprocessing
import os
import queue
import subprocess
from argparse import ArgumentParser, Namespace

import torch
import yaml
from tqdm import tqdm

import utils
from dataset import load_data
from train import build_model
from utils import misc_utils


class DescriptionGenerator(object):
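    """Single-process generator: loads a trained checkpoint and decodes
    one source sequence at a time (greedy or beam search)."""
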
    def __init__(self, config, **opt):
        # Load config used for training and merge with testing options
        with open(config, "r") as f:
            self.config = yaml.safe_load(f)
        self.config = Namespace(**{**self.config, **opt})

        # Load training data.pkl for src and tgt vocabs
        self.data = load_data(self.config)

        # Load trained model checkpoints
        device, _ = misc_utils.set_cuda(self.config)
        self.model, _ = build_model(None, self.config, device)
        self.model.eval()

    def predict(self, original_src: list) -> list:
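        """Decode one tokenized source sequence and return the predicted
        target tokens (with unk replacement when enabled)."""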
        src_vocab = self.data["src_vocab"]
        tgt_vocab = self.data["tgt_vocab"]
        srcIds = src_vocab.convertToIdx(list(original_src), utils.UNK_WORD)
        src = torch.LongTensor(srcIds).unsqueeze(0)
        src_len = torch.LongTensor([len(srcIds)])

        if self.config.use_cuda:
            src = src.cuda()
            src_len = src_len.cuda()

        with torch.no_grad():
            if self.config.beam_size > 1:
                samples, alignments = self.model.beam_sample(
                    src, src_len, beam_size=self.config.beam_size, eval_=False
                )
            else:
                samples, alignments = self.model.sample(src, src_len)

        assert len(samples) == 1
        candidates = [tgt_vocab.convertToLabels(samples[0], utils.EOS)]

        # Replace each unk with the source token its attention aligns to;
        # fall back to the raw candidate when unk replacement is disabled
        cand = candidates[0]
        if self.config.unk and self.config.attention != "None":
            s = original_src
            c = candidates[0]
            align = alignments[0]
            cand = []
            for word, idx in zip(c, align):
                if word == utils.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except IndexError:
                        cand.append(word)
                        print("%d %d" % (len(s), idx))
                else:
                    cand.append(word)
        return cand


class DescriptionGeneratorProxy(object):
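    """Wraps a `python api.py` worker pinned to one GPU via
    CUDA_VISIBLE_DEVICES; requests and replies travel one line at a time
    over the worker's stdin/stdout."""
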
    @staticmethod
    def enqueue_output(out, queue):
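        """Forward lines from `out` into `queue` so the parent can poll
        the worker's stdout without blocking."""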
        for line in iter(out.readline, ""):  # text-mode pipe: EOF yields ""
            queue.put(line)
        out.close()

    def __init__(self, gpu_id):
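        # Expose only the assigned GPU to the worker before it starts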
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        self.process = subprocess.Popen(
            ["python", "api.py"],
            env=env,
            stdout=subprocess.PIPE,
            stdin=subprocess.PIPE,
            universal_newlines=True,
        )
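        # Read the worker's stdout on a helper process so recv() can use a
        # timeout; passing the pipe's file object across processes relies on
        # the fork start method (the default on Linux)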
        self.stdout_reader = multiprocessing.Queue()
        self.stdout_reader_process = multiprocessing.Process(
            target=DescriptionGeneratorProxy.enqueue_output,
            args=(self.process.stdout, self.stdout_reader),
            daemon=True,
        )
        self.stdout_reader_process.start()

    def send(self, src):
        self.process.stdin.write(f"{src}\n")
        self.process.stdin.flush()

    def recv(self, timeout=None):
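        """Return the next line from the worker's stdout, or "" if nothing
        arrives within `timeout` seconds (blocks indefinitely when None)."""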
        try:
            stdout = self.stdout_reader.get(timeout=timeout)
            return stdout.strip()
        except queue.Empty:
            return ""

    def flush(self):
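        """Block until the worker prints COMPLETE, i.e. until its model
        has finished loading."""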
        while True:
            line = self.recv()
            if line == "COMPLETE":
                break


class DescriptionGeneratorMultiprocessing(object):
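    """Fans prediction requests out over n_gpus * n_process_per_gpu worker
    subprocesses, one source per worker per round."""
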
    def __init__(self, n_gpus=8, n_process_per_gpu=8, **kwargs):
        self.proxies = []
        for gpu_id in range(n_gpus):
            for _ in range(n_process_per_gpu):
                self.proxies.append(DescriptionGeneratorProxy(gpu_id))
        for proxy in self.proxies:
            proxy.flush()

    def _predict_batch(self, src_list):
        """Batch size = n_gpu * n_process_per_gpu"""
        assert len(src_list) <= len(self.proxies)
        for proxy, src in zip(self.proxies, src_list):
            proxy.send(src)
        return [proxy.recv() for proxy, _ in zip(self.proxies, src_list)]

    def predict_all(self, src_list):
        tgt_list = []
        batch_size = len(self.proxies)
        for idx in tqdm(range(math.ceil(len(src_list) / batch_size))):
            tgt_list += self._predict_batch(
                src_list[idx * batch_size : (idx + 1) * batch_size]
            )
        return tgt_list
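
# Usage sketch (hypothetical driver; assumes this module is saved as api.py,
# the script each proxy launches):
#
#     pool = DescriptionGeneratorMultiprocessing(n_gpus=8, n_process_per_gpu=8)
#     tgt_list = pool.predict_all(src_list)  # one prediction per source string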


if __name__ == "__main__":
    g = DescriptionGenerator(
        config="yaml/title_summary_item_filter_t2t.yaml",
        gpus="0",
        restore=False,
        pretrain="experiments/3.8-finetune-big/best_bleu_checkpoint.pt",
        mode="eval",
        batch_size=1,
        beam_size=10,
        # Workaround for a refactoring issue; remove these options once fixed:
        scale=1,
        char=False,
        use_cuda=True,
        seed=1234,
        model="tensor2tensor",
    )
    # Smoke test with a sample Chinese sentence ("this thing is really dumb");
    # flush explicitly because stdout is block-buffered when piped to a parent
    print("".join(g.predict(list("这东西真智障"))), flush=True)
    # Interactive loop for the multiprocessing proxies: signal readiness,
    # then serve one prediction per stdin line until the parent closes the pipe
    print("COMPLETE", flush=True)
    while True:
        try:
            src = input()
        except EOFError:
            break
        print("".join(g.predict(list(src))), flush=True)