import json
from io import StringIO
import http.client as httplib
import os
import signal
import sys
import warnings
import uuid
import datetime
import random

import numpy as np
from pkg_resources import iter_entry_points as iep
import flask
import gunicorn.app.base

from vw_serving.sagemaker.gpu import get_num_gpus
from vw_serving.sagemaker import integration as integ
from vw_serving.sagemaker.error_handler import report_batch_inference_sdk_error, report_online_inference_sdk_error
from vw_serving.sagemaker.exceptions import convert_to_algorithm_error, raise_with_traceback, CustomerError, \
    AlgorithmError
from vw_serving.sagemaker.config.server_config import BaseServerConfig
import vw_serving.sagemaker.config.environment as environment

from vw_serving.vw_model import VWModel

# TODO: Add metrics publishing
# from vw_serving.metrics import metrics_wrapper
# from vw_serving.metrics.metrics import MetricsFactory

CONTENT_TYPE_JSON = 'application/json'
CONTENT_TYPE_JSONLINES = 'application/jsonlines'
CONTENT_TYPE_CSV = 'text/csv'
CONTENT_TYPE_RECORDIO = 'application/x-recordio-protobuf'

REDIS_PUBLISHER_CHANNEL = "EXPERIENCES"
KNOWN_CLI_ARGS = ['-r', '--resources', '-w']

MODEL_DIR = integ.ARTIFACTS_VOLUME


class InferenceCustomerError(CustomerError):
    def public_failure_message(self):
        return self.get_error_summary()


class InferenceAlgorithmError(AlgorithmError):
    def public_failure_message(self):
        return self.get_error_summary()


class GunicornApplication(gunicorn.app.base.Application):
    """Gunicorn application

    By extending base.Application, this class gets access to configuration via the
    GUNICORN_CMD_ARGS environment variable and the command line. The env variable makes it
    possible to configure gunicorn per endpoint.

    See: http://docs.gunicorn.org/en/stable/settings.html

    See: https://code.amazon.com/packages/Gunicorn/blobs/5ea7b077710253db3ed6676525090ec04c05e4b8/--/gunicorn/app/base.py#L158 # noqa: E501
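
    For example, per-endpoint tuning can be done without code changes (a sketch; the exact
    server invocation is deployment-specific and shown only for illustration)::

        GUNICORN_CMD_ARGS="--workers=2 --timeout=60" <start the serving container>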
    """

    def __init__(self, app, options=None):
        self.options = options or {}
        self.application = app
        super(GunicornApplication, self).__init__()

    def load_vars(self):
        """Load env and command line vars.

        This method has to be overridden because options passed by Brazil, such as
        '--resources' or '-r' (required for local training and serving), are also passed to
        Gunicorn. Those options break the superclass implementation, which uses parse_args().
        Copied and modified from the superclass.

        See: https://code.amazon.com/packages/Gunicorn/blobs/5ea7b077710253db3ed6676525090ec04c05e4b8/--/gunicorn/app/base.py#L137 # noqa: E501
        """

        def check_unknown_args(args):
            """Check arguments Gunicorn is not aware of, and remove the argument names
             and values the application is aware of. Fail if there are any arguments remaining.

            :param args: Remainder arguments from Gunicorn argparser
            """
            unrecognized_args = {arg for arg in args if arg.startswith("-") and arg not in KNOWN_CLI_ARGS}
            if unrecognized_args:
                raise AlgorithmError("Unrecognized variables used: {}".format(", ".join(unrecognized_args)))

        def set_config(args, cfg):
            for k, v in vars(args).items():
                if v is None:
                    continue
                if k == "args":
                    continue
                cfg.set(k.lower(), v)

        parser = self.cfg.parser()
        parsed_args = parser.parse_known_args()
        known_args = parsed_args[0]
        unknown_args = parsed_args[1]

        check_unknown_args(unknown_args)

        env_vars = self.cfg.get_cmd_args_from_env()
        if env_vars:
            env_args = parser.parse_args(env_vars)
            set_config(env_args, self.cfg)

        # Lastly, update the configuration with any command line
        # settings; these override matching settings parsed from the env vars.
        set_config(known_args, self.cfg)

    def load_config(self):
        """Load configuration.

        Configuration is loaded from the options passed in the constructor, then from the
        environment variable GUNICORN_CMD_ARGS, and finally from the command line.
        Command line arguments take precedence over the environment variable.

        See http://docs.gunicorn.org/en/stable/settings.html
        """
        for key, value in self.options.items():
            key = key.lower()
            if key in self.cfg.settings and value is not None:
                self.cfg.set(key, value)
        self.load_vars()

    def load(self):
        return self.application


class ScoringService(object):
    PORT = os.getenv(environment.SAGEMAKER_BIND_TO_PORT, "8080")

    # NOTE: 6 MB max content length by default
    MAX_CONTENT_LENGTH = int(os.getenv(environment.MAX_CONTENT_LENGTH, 6 * 1024 * 1024))

    BATCH_INFERENCE = os.getenv(environment.SAGEMAKER_BATCH, 'false') == 'true'

    EIA_PRESENT = os.getenv(environment.SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT, 'false') == 'true'

    DEFAULT_INVOCATIONS_ACCEPT = os.getenv(environment.SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT, "")

    LOG_INFERENCE_DATA = os.getenv(environment.LOG_INFERENCE_DATA, 'true').lower() == 'true'

    app = flask.Flask(__name__)
    request_iterators = {}
    response_encoders = {}

    _model_class = None
    _model_id = None
    _server_config = None
    _model = None
    _redis_client = None

    @classmethod
    def _report_sdk_error(cls, sdk_error):
        """Report SDK error.

        Use different mechanisms for online inference and batch inference.
        """
        if cls.BATCH_INFERENCE:
            report_batch_inference_sdk_error(sdk_error)
        else:
            report_online_inference_sdk_error(sdk_error)

    @classmethod
    def _load_class_entry_point(cls, group, name):
        entry_points = tuple(iep(group=group, name=name))
        if not entry_points:
            return None

        class_entry_point = entry_points[0]
        loaded_class = class_entry_point.load()
        cls.app.logger.info("loaded entry point class %s:%s", group, class_entry_point.name)
        return loaded_class

    @classmethod
    def eia_enabled(cls):
        return cls._server_config.eia_compatible and cls.EIA_PRESENT and get_num_gpus() <= 0

    @classmethod
    def _load_pre_worker_entry_points(cls):
        try:
            server_config_class = cls._load_class_entry_point("algorithm.serve.server_config", "config_api")
            if server_config_class:
                cls._server_config = server_config_class()
        except Exception:
            cls.app.logger.exception("Unable to load server_config entry point")

        if cls._server_config is None:
            cls._server_config = BaseServerConfig()

        cls.app.logger.info("loading entry points")
        for entry_point in iep(group="algorithm.io.data_handlers.serve"):
            cls.request_iterators[entry_point.name] = entry_point.load()
            cls.app.logger.info("loaded request iterator %s", entry_point.name)

        for entry_point in iep(group="algorithm.request_iterators"):
            warnings.warn("entrypoint algorithm.request_iterators is deprecated "
                          "in favor of algorithm.io.data_handlers.serve", DeprecationWarning)
            cls.request_iterators[entry_point.name] = entry_point.load()
            cls.app.logger.info("loaded request iterator %s", entry_point.name)

        for entry_point in iep(group="algorithm.response_encoders"):
            cls.response_encoders[entry_point.name] = entry_point.load()
            cls.app.logger.info("loaded response encoder %s", entry_point.name)

        try:
            cls._model_class = cls._load_class_entry_point("algorithm", "model")
        except Exception as e:
            raise_with_traceback(InferenceAlgorithmError("Unable to load algorithm.model entry point", caused_by=e))

    @classmethod
    def get_model(cls):
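        """Lazily load the VW model advertised in Redis and cache it on the class.

        Assumes the contract implied by the reads below (not independently verified):
        a Redis instance on localhost holds a "model_id" key plus "<model_id>:weights"
        and "<model_id>:metadata" keys pointing at the artifact locations.
        """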
        if cls._model is None:
            try:
                import redis
                redis_client = redis.Redis()
                cls._model_id = redis_client.get("model_id").decode()
                model_weights_loc = redis_client.get("{}:weights".format(cls._model_id)).decode()
                model_metadata_loc = redis_client.get("{}:metadata".format(cls._model_id)).decode()
                cls._model = VWModel.load_vw_model(metadata_loc=model_metadata_loc,
                                                   weights_loc=model_weights_loc,
                                                   test_only=True,
                                                   quiet_mode=True)
                cls.app.logger.info(f"Loaded weights successfully for Model ID:{cls._model_id}")
            except Exception as e:
                raise_with_traceback(InferenceCustomerError("Unable to load model", caused_by=e))
        return cls._model

    @classmethod
    def _get_server_config(cls):
        if not cls._server_config:
            cls._load_pre_worker_entry_points()
        return cls._server_config

    @classmethod
    def get_transform_configuration(cls):
        """Get transform (batch inference) configuration.

        :return: (dict) a dictionary with three entries:
          max_concurrent_transforms: (int) number of concurrent transform requests the
            platform may send to the container.
          batch_strategy: (str) how records are batched into each request (e.g. MULTI_RECORD).
          max_payload_size: (int) maximum size of payload on /invocations request in bytes.
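
        A sketch of a typical return value (actual values come from the server config):

          {
              "max_concurrent_transforms": 1,
              "batch_strategy": "MULTI_RECORD",
              "max_payload_size": 6291456,
          }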
        """
        server_config = cls._get_server_config()

        return {
            'max_concurrent_transforms': server_config.max_concurrent_transforms,
            'batch_strategy': server_config.batch_strategy,
            'max_payload_size': cls.MAX_CONTENT_LENGTH,
        }

    @staticmethod
    def _post_worker_init(worker):
        """
        Gunicorn server hook http://docs.gunicorn.org/en/stable/settings.html#post-worker-init
        :param worker:
        """
        # The model is loaded per worker because each worker communicates with the VW C++ CLI
        # through a pipe.
        try:
            if ScoringService.LOG_INFERENCE_DATA:
                import redis
                if ScoringService._redis_client is None:
                    ScoringService._redis_client = redis.Redis()
                    ScoringService.app.logger.info("Initiated redis client!")
        except Exception as e:
            sdk_error = convert_to_algorithm_error(e)
            ScoringService._report_sdk_error(sdk_error)
            sys.exit(sdk_error.exit_code)

        try:
            ScoringService.get_model()
            ScoringService._model.start()
        except Exception as e:
            sdk_error = convert_to_algorithm_error(e)
            ScoringService._report_sdk_error(sdk_error)
            sys.exit(sdk_error.exit_code)

    @staticmethod
    def _worker_exit(server, worker):
        """Do not cleanup resources on exit when memory profiler is enabled.

        Memory profiler imports multiprocessing module which causes exceptions on exit
        when used in the same process with fork().

        See:
          https://github.com/benoitc/gunicorn/issues/1391
          https://stackoverflow.com/questions/37692262
        """
        if os.getenv(environment.ENABLE_PROFILER):
            os._exit(0)

    @classmethod
    def parse_content_type(cls, content_type):
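        """Split a Content-Type header into its MIME type and a parameter dict.

        A sketch of the expected behavior (only the "shape" parameter gets list-valued
        treatment):

            >>> ScoringService.parse_content_type("text/csv; shape=2,3")
            ('text/csv', {'shape': [2, 3]})
            >>> ScoringService.parse_content_type(None)
            ('application/json', {})
        """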
        content_type = content_type.lower() if content_type else CONTENT_TYPE_JSON

        tokens = content_type.split(";")
        content_type = tokens[0].strip()
        parameters = {}
        for token in tokens[1:]:
            key, value = token.split("=", 1)
            key = key.strip()
            value = value.strip()
            parameters[key] = value

        if "shape" in parameters:
            parameters["shape"] = [int(s_i) for s_i in parameters["shape"].split(",")]

        return content_type, parameters

    @classmethod
    def get_num_workers(cls):
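        """Resolve the gunicorn worker count: a NUM_WORKERS env override wins, EIA-enabled
        endpoints are pinned to a single worker, and otherwise the server config decides.
        """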
        forced_num_workers = int(os.getenv('NUM_WORKERS', 0))

        if forced_num_workers > 0:
            return forced_num_workers
        if cls.eia_enabled():
            return 1
        else:
            return int(cls._get_server_config().number_of_workers)

    @classmethod
    def _initialize(cls, daemon=False):
        cls._load_pre_worker_entry_points()

        # NOTE: Stop Flask application when SIGTERM is received as a result of "docker stop" command.
        signal.signal(signal.SIGTERM, cls.stop)

        nworkers = min(cls.get_num_workers(), 4)  # gunicorn worker count is capped at 4
        timeout = int(cls._get_server_config().timeout)

        cls.app.config["MAX_CONTENT_LENGTH"] = cls.MAX_CONTENT_LENGTH
        gunicorn_options = {
            "bind": "{}:{}".format("0.0.0.0", cls.PORT),
            "workers": min(nworkers, 4),
            "timeout": timeout,
            "worker_exit": cls._worker_exit,
            "pidfile": environment.PIDFILE,
            "daemon": daemon
        }
        # If prefork model loading were chosen, the model loading function would be called here
        # immediately; instead it is registered in the options map, so that each worker process
        # can call it after being forked.

        gunicorn_options["post_worker_init"] = cls._post_worker_init
        cls.app.logger.info("Number of server workers: %s", nworkers)

        return gunicorn_options

    @classmethod
    def start(cls, daemon=False):
        integ.setup_logging()
        integ.write_trusted_log_info("worker started")

        try:
            options = cls._initialize(daemon=daemon)
        except Exception as e:
            sdk_error = convert_to_algorithm_error(e)
            ScoringService._report_sdk_error(sdk_error)
            sys.exit(sdk_error.exit_code)

        GunicornApplication(cls.app, options).run()

    @staticmethod
    def stop(*args, **kwargs):
        integ.write_trusted_log_info("worker closed")
        # flask.Flask has no shutdown() method; exit the process and let gunicorn clean up.
        sys.exit(0)


@ScoringService.app.errorhandler(httplib.INTERNAL_SERVER_ERROR)
def internal_server_error(e):
    sdk_error = convert_to_algorithm_error(e)
    ScoringService._report_sdk_error(sdk_error)
    return flask.Response(response="Internal Server Error", status=httplib.INTERNAL_SERVER_ERROR)


@ScoringService.app.route("/ping", methods=["GET"])
def ping():
    # TODO: implement health checks
    return flask.Response(status=httplib.OK)


def _error_predicate(return_value, exception):
    # Use floor division: with "/", a 201 response yields 2.01 != 2 in Python 3 and
    # would be misclassified as an error.
    status_code_is_not_2xx = return_value and return_value.status_code // 100 != 2
    return exception or status_code_is_not_2xx


def _score_json(model, observation, response_content_type=CONTENT_TYPE_JSON):
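    """Sample an action for one observation and serialize the result.

    The model returns a vector of action probabilities; it is renormalized, a 1-indexed
    action is sampled from it, and the result is serialized as JSON or CSV. A sketch of a
    JSON response (field values are illustrative only):

        {"action": 2, "action_prob": 0.6, "event_id": ..., "timestamp": ...,
         "sample_prob": 0.13, "model_id": "..."}
    """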
    event_id = uuid.uuid1().int
    # strftime("%s") is a non-portable glibc extension and ignores tzinfo; use timestamp().
    timestamp = int(datetime.datetime.now().timestamp())
    action_probs = model.predict(observation)
    nchoices = len(action_probs)
    action_probs = (action_probs / action_probs.sum())
    action = np.random.choice(nchoices, p=action_probs) + 1
    action_prob = action_probs[action - 1]
    # add sample_prob for later dataset sampling
    sample_prob = random.uniform(0.0, 1.0)
    if response_content_type in (CONTENT_TYPE_JSON, CONTENT_TYPE_JSONLINES):
        # convert to JSON
        response_payload = json.dumps({"action": action,
                                       "action_prob": action_prob,
                                       "event_id": event_id,
                                       "timestamp": timestamp,
                                       "sample_prob": sample_prob,
                                       "model_id": ScoringService._model_id})
    else:
        # convert to CSV
        response_payload = ",".join([
            str(x) for x in
            [action, action_prob, event_id, timestamp, sample_prob, ScoringService._model_id]
        ])
    blob_to_log = json.dumps({"action": action,
                              "action_prob": action_prob,
                              "event_id": event_id,
                              "observation": observation,
                              "timestamp": timestamp,
                              "model_id": ScoringService._model_id,
                              "sample_prob": sample_prob,
                              "type": "actions"})
    if ScoringService.LOG_INFERENCE_DATA:
        # TODO: Log state, action, eventID
        ScoringService._redis_client.publish(REDIS_PUBLISHER_CHANNEL, blob_to_log)
    return response_payload


@ScoringService.app.route("/invocations", methods=["POST"])
# @METRICS.count("invocations")
# @METRICS.count_when("invocations_error", _error_predicate)
def invocations():
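    """Handle scoring requests.

    For application/json, the optional "request_type" field selects one of three paths.
    Example payloads (a sketch; field values are illustrative only):

        {"observation": [0.1, 0.5]}                                  # default: score an observation
        {"request_type": "reward", "event_id": 123, "reward": 1.0}   # log a reward
        {"request_type": "model_id"}                                 # query the current model

    text/csv and application/jsonlines payloads are treated as batch input, one observation
    per line.
    """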
    content_type, content_parameters = ScoringService.parse_content_type(flask.request.content_type)
    response_content_type = flask.request.headers.get("Accept", CONTENT_TYPE_JSON)
    if content_type not in [CONTENT_TYPE_JSON, CONTENT_TYPE_JSONLINES, CONTENT_TYPE_CSV]:
        sdk_error = InferenceCustomerError("Content-type {} not supported".format(content_type))
        ScoringService._report_sdk_error(sdk_error)
        return flask.Response(
            response="content-type {} not supported".format(content_type), status=httplib.UNSUPPORTED_MEDIA_TYPE
        )

    payload = flask.request.data
    if len(payload) == 0:
        return flask.Response(response="", status=httplib.NO_CONTENT)

    try:
        model = ScoringService.get_model()
    except Exception as e:
        sdk_error = convert_to_algorithm_error(e)
        ScoringService._report_sdk_error(sdk_error)
        return flask.Response(response="unable to load model", status=httplib.INTERNAL_SERVER_ERROR,
                              mimetype="application/json", content_type="application/json")

    if content_type == CONTENT_TYPE_JSON:
        data = json.loads(payload.decode("utf-8"))

        request_type = data.get("request_type", "observation").lower()

        if request_type == "reward":
            event_id = data["event_id"]
            reward = data["reward"]
            blob_to_log = {"event_id": int(event_id), "reward": float(reward), "type": "rewards"}
            blob_to_log = json.dumps(blob_to_log)
            if ScoringService.LOG_INFERENCE_DATA:
                ScoringService._redis_client.publish(REDIS_PUBLISHER_CHANNEL, blob_to_log)
                status = "success"
            else:
                status = "failure"
            return flask.Response(response='{"status": "%s"}' % status, status=httplib.OK)

        elif request_type == "model_id":
            model_info_payload = json.dumps({"model_id": ScoringService._model_id,
                                             "soft_model_update_status": "TBD: To be used for indicating rollbacks"})
            return flask.Response(response=model_info_payload, status=httplib.OK, mimetype="application/json",
                                  content_type="application/json")

        else:
            observation = data["observation"]

            response_payload = _score_json(model, observation)
            return flask.Response(response=response_payload, status=httplib.OK, mimetype="application/json",
                                  content_type="application/json")
    elif content_type == CONTENT_TYPE_JSONLINES:
        #  Content type is application/jsonlines, which means this is Batch Inference mode
        data = payload.decode("utf-8")
        response = [
            _score_json(model, json.loads(line), response_content_type=response_content_type)
            for line in StringIO(data)
        ]
        response_payload = "\n".join(response)
        return flask.Response(response=response_payload, status=httplib.OK, mimetype=response_content_type,
                              content_type=response_content_type)
    else:
        # content type is csv, batch inference
        data = payload.decode("utf-8")
        rows_ = np.genfromtxt(StringIO(data), delimiter=',').astype(float)
        rows = rows_.tolist()
        # np.genfromtxt returns a 1-D array for a single record; the loop below expects
        # a list of lists, so wrap it.
        if rows_.ndim == 1:
            rows = [rows]
        response = [
            _score_json(
                model,
                row,
                response_content_type=response_content_type
            )
            for row in rows
        ]
        response_payload = "\n".join(response)
        return flask.Response(response=response_payload, status=httplib.OK, mimetype=response_content_type,
                              content_type=response_content_type)


@ScoringService.app.route("/execution-parameters", methods=["GET"])
def execution_parameters():
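    """Report batch-transform tuning parameters. A sketch of the JSON response (actual
    values depend on the server config):

        {"MaxConcurrentTransforms": 1, "BatchStrategy": "MULTI_RECORD", "MaxPayloadInMB": 6}
    """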
    service_config = ScoringService.get_transform_configuration()
    try:
        parameters = {
            "MaxConcurrentTransforms": service_config["max_concurrent_transforms"],
            "BatchStrategy": service_config["batch_strategy"],
            "MaxPayloadInMB": service_config["max_payload_size"] // (1024 * 1024)  # convert bytes to MB
        }
    except KeyError as e:
        sdk_error = convert_to_algorithm_error(e)
        ScoringService._report_sdk_error(sdk_error)
        return flask.Response(response="unable to determine execution parameters", status=httplib.INTERNAL_SERVER_ERROR)

    response_text = json.dumps(parameters)
    return flask.Response(response=response_text, status=httplib.OK, mimetype="application/json")


if __name__ == "__main__":
    ScoringService.start()