import argparse
import logging
import os
import sys
import time
import uuid
import warnings
from collections import OrderedDict


import zmq
from termcolor import colored
from zmq.utils import jsonapi

__all__ = ['set_logger', 'send_ndarray', 'get_args_parser',
           'check_tf_version', 'auto_bind', 'import_tf', 'TimeContext', 'CappedHistogram']


def set_logger(context, verbose=False):
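    """Return a console logger for context (NTLogger fallback on Windows); DEBUG level when verbose."""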
    if os.name == 'nt':  # for Windows
        return NTLogger(context, verbose)

    logger = logging.getLogger(context)
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)
    formatter = logging.Formatter(
        '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s',
        datefmt='%m-%d %H:%M:%S')
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
    console_handler.setFormatter(formatter)
    logger.handlers = []
    logger.addHandler(console_handler)
    return logger


class NTLogger:
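    """Minimal print-based logger used on Windows in place of the logging module."""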
    def __init__(self, context, verbose):
        self.context = context
        self.verbose = verbose

    def info(self, msg, **kwargs):
        print('I:%s:%s' % (self.context, msg), flush=True)

    def debug(self, msg, **kwargs):
        if self.verbose:
            print('D:%s:%s' % (self.context, msg), flush=True)

    def error(self, msg, **kwargs):
        print('E:%s:%s' % (self.context, msg), flush=True)

    def warning(self, msg, **kwargs):
        print('W:%s:%s' % (self.context, msg), flush=True)


def send_ndarray(src, dest, X, req_id=b'', flags=0, copy=True, track=False):
    """send a numpy array with metadata"""
    md = dict(dtype=str(X.dtype), shape=X.shape)
    return src.send_multipart([dest, jsonapi.dumps(md), X, req_id], flags, copy=copy, track=track)
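

# A minimal sketch of the matching receive side (a hypothetical helper, not
# used elsewhere in this module). The wire format produced above is
# [dest, json-metadata, raw-ndarray-bytes, req_id].
def _recv_ndarray_sketch(sock):
    import numpy as np
    dest, md_raw, buf, req_id = sock.recv_multipart()
    md = jsonapi.loads(md_raw)
    X = np.frombuffer(memoryview(buf), dtype=md['dtype']).reshape(md['shape'])
    return dest, X, req_id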


def check_max_seq_len(value):
    if value is None or value.lower() == 'none':
        return None
    try:
        ivalue = int(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError('%s is not a valid integer' % value)
    if ivalue <= 3:
        raise argparse.ArgumentTypeError('%s is invalid: max_seq_len must be an integer > 3 '
                                         '(to account for up to three special symbols in the BERT model) or NONE' % value)
    return ivalue
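
# e.g. check_max_seq_len('25') -> 25 and check_max_seq_len('NONE') -> None,
# while check_max_seq_len('3') raises argparse.ArgumentTypeError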


def get_args_parser():
    from . import __version__
    from .graph import PoolingStrategy

    parser = argparse.ArgumentParser(description='Start a BertServer for serving')

    group1 = parser.add_argument_group('File Paths',
                                       'config the path, checkpoint and filename of a pretrained/fine-tuned BERT model')
    group1.add_argument('-model_dir', type=str, required=True,
                        help='directory of a pretrained BERT model')
    group1.add_argument('-tuned_model_dir', type=str,
                        help='directory of a fine-tuned BERT model')
    group1.add_argument('-ckpt_name', type=str, default='bert_model.ckpt',
                        help='filename of the checkpoint file. By default it is "bert_model.ckpt", but \
                             for a fine-tuned model the name could be different.')
    group1.add_argument('-config_name', type=str, default='bert_config.json',
                        help='filename of the JSON config file for BERT model.')
    group1.add_argument('-graph_tmp_dir', type=str, default=None,
                        help='path to graph temp file')

    group2 = parser.add_argument_group('BERT Parameters',
                                       'config how BERT model and pooling works')
    group2.add_argument('-max_seq_len', type=check_max_seq_len, default=25,
                        help='maximum length of a sequence, longer sequence will be trimmed on the right side. '
                             'set it to NONE for dynamically using the longest sequence in a (mini)batch.')
    group2.add_argument('-cased_tokenization', dest='do_lower_case', action='store_false', default=True,
                        help='whether the tokenizer should skip the default lowercasing and accent removal. '
                             'Should be used for e.g. the multilingual cased pretrained BERT model.')
    group2.add_argument('-pooling_layer', type=int, nargs='+', default=[-2],
                        help='the encoder layer(s) to pool from. \
                        Give a list in order to concatenate several layers into one')
    group2.add_argument('-pooling_strategy', type=PoolingStrategy.from_string,
                        default=PoolingStrategy.REDUCE_MEAN, choices=list(PoolingStrategy),
                        help='the pooling strategy for generating encoding vectors')
    group2.add_argument('-mask_cls_sep', action='store_true', default=False,
                        help='masking the embedding on [CLS] and [SEP] with zero. \
                        When pooling_strategy is in {CLS_TOKEN, FIRST_TOKEN, SEP_TOKEN, LAST_TOKEN} \
                        then the embedding is preserved, otherwise the embedding is masked to zero before pooling')
    group2.add_argument('-no_special_token', action='store_true', default=False,
                        help='do not add [CLS] and [SEP] to every sequence; \
                        when set together with is_tokenized=True in the client, \
                        sequences are fed to the model without [CLS] and [SEP]')
    group2.add_argument('-show_tokens_to_client', action='store_true', default=False,
                        help='send tokenization results back to the client')
    group2.add_argument('-no_position_embeddings', action='store_true', default=False,
                        help='Whether to add position embeddings for the position of each token in the sequence.')
    group2.add_argument('-num_labels', type=int, default=2,
                        help='number of labels')
    
    group3 = parser.add_argument_group('Serving Configs',
                                       'config how server utilizes GPU/CPU resources')
    group3.add_argument('-port', '-port_in', '-port_data', type=int, default=5555,
                        help='server port for receiving data from client')
    group3.add_argument('-port_out', '-port_result', type=int, default=5556,
                        help='server port for sending result to client')
    group3.add_argument('-http_port', type=int, default=None,
                        help='server port for receiving HTTP requests')
    group3.add_argument('-http_max_connect', type=int, default=10,
                        help='maximum number of concurrent HTTP connections')
    group3.add_argument('-cors', type=str, default='*',
                        help='setting "Access-Control-Allow-Origin" for HTTP requests')
    group3.add_argument('-num_worker', type=int, default=1,
                        help='number of server instances')
    group3.add_argument('-max_batch_size', type=int, default=256,
                        help='maximum number of sequences handled by each worker')
    group3.add_argument('-priority_batch_size', type=int, default=16,
                        help='batches smaller than this size will be labeled as high priority '
                             'and jump forward in the job queue')
    group3.add_argument('-cpu', action='store_true', default=False,
                        help='running on CPU (default on GPU)')
    group3.add_argument('-xla', action='store_true', default=False,
                        help='enable XLA compiler (experimental)')
    group3.add_argument('-fp16', action='store_true', default=False,
                        help='use float16 precision (experimental)')
    group3.add_argument('-gpu_memory_fraction', type=float, default=0.5,
                        help='determine the fraction of the overall amount of memory \
                        that each visible GPU should be allocated per worker. \
                        Should be in range [0.0, 1.0]')
    group3.add_argument('-device_map', type=int, nargs='+', default=[],
                        help='specify the list of GPU device ids that will be used (id starts from 0). \
                        If num_worker > len(device_map), then device will be reused; \
                        if num_worker < len(device_map), then device_map[:num_worker] will be used')
    group3.add_argument('-prefetch_size', type=int, default=10,
                        help='the number of batches to prefetch on each worker. When running on a CPU-only machine, \
                        this is set to 0 for comparability')
    group3.add_argument('-fixed_embed_length', action='store_true', default=False,
                        help='when "max_seq_len" is set to None, the server determines the "max_seq_len" according to '
                             'the actual sequence lengths within each batch. When "pooling_strategy=NONE", '
                             'this may cause two ".encode()" calls from the same client to return different sizes [B, T, D]. '
                             'Turn this on to fix the "T" in [B, T, D] to "max_position_embeddings" in the BERT JSON config.')

    parser.add_argument('-verbose', action='store_true', default=False,
                        help='turn on tensorflow logging for debug')
    parser.add_argument('-version', action='version', version='%(prog)s ' + __version__)
    return parser
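
# Usage sketch (the model directory below is a placeholder):
#   args = get_args_parser().parse_args(['-model_dir', '/tmp/uncased_L-12_H-768_A-12',
#                                        '-num_worker', '2'])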


def check_tf_version():
    import tensorflow as tf
    tf_ver = tf.__version__.split('.')
    if (int(tf_ver[0]), int(tf_ver[1])) < (1, 10):
        raise ModuleNotFoundError('Tensorflow >=1.10 (one-point-ten) is required!')
    elif int(tf_ver[0]) > 1:
        warnings.warn('Tensorflow %s is not tested! It may or may not work. '
                      'Feel free to submit an issue at https://github.com/hanxiao/bert-as-service/issues/' % tf.__version__)
    return tf_ver


def import_tf(device_id=-1, verbose=False, use_fp16=False):
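    """Set per-process CUDA/TF environment variables, then import and return tensorflow."""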
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1' if device_id < 0 else str(device_id)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbose else '3'
    os.environ['TF_FP16_MATMUL_USE_FP32_COMPUTE'] = '0' if use_fp16 else '1'
    os.environ['TF_FP16_CONV_USE_FP32_COMPUTE'] = '0' if use_fp16 else '1'
    import tensorflow as tf
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG if verbose else tf.compat.v1.logging.ERROR)
    return tf


def auto_bind(socket):
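    """Bind the socket to a random endpoint (TCP on Windows, IPC elsewhere) and return the address."""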
    if os.name == 'nt':  # for Windows
        socket.bind_to_random_port('tcp://127.0.0.1')
    else:
        # Get the location for tmp file for sockets
        try:
            tmp_dir = os.environ['ZEROMQ_SOCK_TMP_DIR']
            if not os.path.exists(tmp_dir):
                raise ValueError('This directory for sockets ({}) does not seem to exist.'.format(tmp_dir))
            tmp_dir = os.path.join(tmp_dir, str(uuid.uuid1())[:8])
        except KeyError:
            tmp_dir = '*'

        socket.bind('ipc://{}'.format(tmp_dir))
    return socket.getsockopt(zmq.LAST_ENDPOINT).decode('ascii')
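
# Usage sketch (assumes an existing zmq.Context named ctx):
#   sock = ctx.socket(zmq.PUSH)
#   endpoint = auto_bind(sock)   # e.g. 'ipc:///tmp/sockets/1a2b3c4d', or a tcp:// address on Windows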


def get_run_args(parser_fn=get_args_parser, printed=True):
    args = parser_fn().parse_args()
    if printed:
        param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())])
        print('usage: %s\n%20s   %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str))
    return args


def get_benchmark_parser():
    parser = get_args_parser()
    parser.description = 'Benchmark BertServer locally'

    parser.set_defaults(num_client=1, client_batch_size=4096)

    group = parser.add_argument_group('Benchmark parameters', 'config the experiments of the benchmark')

    group.add_argument('-test_client_batch_size', type=int, nargs='*', default=[1, 16, 256, 4096])
    group.add_argument('-test_max_batch_size', type=int, nargs='*', default=[8, 32, 128, 512])
    group.add_argument('-test_max_seq_len', type=int, nargs='*', default=[32, 64, 128, 256])
    group.add_argument('-test_num_client', type=int, nargs='*', default=[1, 4, 16, 64])
    group.add_argument('-test_pooling_layer', type=int, nargs='*', default=[[-j] for j in range(1, 13)])

    group.add_argument('-wait_till_ready', type=int, default=30,
                       help='seconds to wait until server is ready to serve')
    group.add_argument('-client_vocab_file', type=str, default='README.md',
                       help='file path for building client vocabulary')
    group.add_argument('-num_repeat', type=int, default=10,
                       help='number of repeats per experiment (must be > 2), '
                            'as the first two results are omitted for the warm-up effect')
    return parser


def get_shutdown_parser():
    parser = argparse.ArgumentParser()
    parser.description = 'Shutting down a BertServer instance running on a specific port'

    parser.add_argument('-ip', type=str, default='localhost',
                        help='the ip address that a BertServer is running on')
    parser.add_argument('-port', '-port_in', '-port_data', type=int, required=True,
                        help='the port that a BertServer is running on')
    parser.add_argument('-timeout', type=int, default=5000,
                        help='timeout (ms) for connecting to a server')
    return parser


class TimeContext:
    def __init__(self, msg):
        self._msg = msg

    def __enter__(self):
        self.start = time.perf_counter()
        print(self._msg, end=' ...\t', flush=True)

    def __exit__(self, typ, value, traceback):
        self.duration = time.perf_counter() - self.start
        print(colored('    [%3.3f secs]' % self.duration, 'green'), flush=True)
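
# Usage sketch:
#   with TimeContext('optimizing graph'):
#       do_work()   # prints "optimizing graph ...", then e.g. "[1.234 secs]" on exit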


class CappedHistogram:
    """Space capped dict with aggregate stat tracking.

    Evicts using LRU policy when at capacity; evicted elements are added to aggregate stats.
    Arguments:
    capacity -- the capacity limit of the dict
    """
    def __init__(self, capacity):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.base_bins = 0
        self.base_count = 0
        self.base_min = float('inf')
        self.min_count = 0
        self.base_max = 0
        self.max_count = 0

    def __getitem__(self, key):
        if key in self.cache:
            return self.cache[key]
        return 0

    def __setitem__(self, key, value):
        if key in self.cache:
            del self.cache[key]
        self.cache[key] = value
        if len(self.cache) > self.capacity:
            self._evict()

    def total_size(self):
        return self.base_bins + len(self.cache)

    def __len__(self):
        return len(self.cache)

    def values(self):
        return self.cache.values()

    def _evict(self):
        _, val = self.cache.popitem(last=False)  # pop the least-recently-used bin into aggregate stats
        self.base_bins += 1
        self.base_count += val
        if val < self.base_min:
            self.base_min = val
            self.min_count = 1
        elif val == self.base_min:
            self.min_count += 1
        if val > self.base_max:
            self.base_max = val
            self.max_count = 1
        elif val == self.base_max:
            self.max_count += 1

    def get_stat_map(self, name):
        if len(self.cache) == 0:
            return {}
        counts = self.cache.values()
        avg = (self.base_count + sum(counts)) / (self.base_bins + len(counts))
        min_, max_ = min(counts), max(counts)
        num_min, num_max = 0, 0
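        # fold the evicted (base) min into the live cache min; when the two
        # values tie, occurrences from both sides are counted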
        if self.base_min <= min_:
            min_ = self.base_min
            num_min += self.min_count
        if self.base_min >= min_:
            num_min += sum(v == min_ for v in counts)

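        # same merge for the max side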
        if self.base_max >= max_:
            max_ = self.base_max
            num_max += self.max_count
        if self.base_max <= max_:
            num_max += sum(v == max_ for v in counts)

        return {
            'avg_%s' % name: avg,
            'min_%s' % name: min_,
            'max_%s' % name: max_,
            'num_min_%s' % name: num_min,
            'num_max_%s' % name: num_max,
        }
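
# Usage sketch: track per-key counts with bounded memory; with capacity=2 the
# third distinct key evicts the least-recently-used bin into the aggregate stats.
#   hist = CappedHistogram(capacity=2)
#   for size in (1, 2, 2, 3):
#       hist[size] += 1                  # unseen keys read as 0 via __getitem__
#   stats = hist.get_stat_map('len')     # {'avg_len': 1.33.., 'min_len': 1, ...}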