from __future__ import absolute_import

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.types import base_sdk_types as _base_sdk_types
from flytekit.models import types as _idl_types, literals as _literals
from google.protobuf import reflection as _proto_reflection

import base64 as _base64
import six as _six


def create_protobuf(pb_type):
    """
    Build a Protobuf SDK type class bound to the given protobuf message class.

    :param T pb_type: must be an instance of GeneratedProtocolMessageType
    :rtype: ProtobufType
    :raises: flytekit.common.exceptions.user.FlyteTypeException if pb_type is not an instance of
        GeneratedProtocolMessageType
    """
    required_metaclass = _proto_reflection.GeneratedProtocolMessageType

    if isinstance(pb_type, required_metaclass):
        # A fresh subclass per call, carrying the message class as a class attribute.
        class _Protobuf(Protobuf):
            _pb_type = pb_type
        return _Protobuf

    raise _user_exceptions.FlyteTypeException(
        expected_type=required_metaclass,
        received_type=type(pb_type),
        received_value=pb_type
    )


class ProtobufType(_base_sdk_types.FlyteSdkType):
    """
    Metaclass for Protobuf SDK types. The properties below are evaluated on the class itself and
    read the ``_pb_type`` attribute that the class is expected to define.
    """

    @property
    def pb_type(cls):
        """
        The protobuf message class stored on this SDK type as ``_pb_type``.

        :rtype: GeneratedProtocolMessageType
        """
        return cls._pb_type

    @property
    def descriptor(cls):
        """
        Dotted "<module>.<name>" string identifying the wrapped protobuf message class.

        :rtype: Text
        """
        message_cls = cls.pb_type
        return "{module}.{name}".format(
            module=message_cls.__module__,
            name=message_cls.__name__
        )

    @property
    def tag(cls):
        """
        The descriptor prefixed with Protobuf.TAG_PREFIX.

        :rtype: Text
        """
        prefix = Protobuf.TAG_PREFIX
        return "{prefix}{descriptor}".format(prefix=prefix, descriptor=cls.descriptor)


class Protobuf(_six.with_metaclass(ProtobufType, _base_sdk_types.FlyteSdkValue)):
    """
    SDK value that stores a protobuf message as a binary scalar literal. The binary payload is the
    serialized message and its tag is derived from the message class (see ProtobufType.tag).
    Concrete subclasses are expected to define ``_pb_type``.
    """

    # Key used in the literal type metadata and as the prefix of the binary tag.
    PB_FIELD_KEY = "pb_type"
    TAG_PREFIX = "{}=".format(PB_FIELD_KEY)

    def __init__(self, pb_object):
        """
        Serialize the given message and store it as a tagged binary scalar.

        :param T pb_object: message object exposing SerializeToString()
        """
        data = pb_object.SerializeToString()
        super(Protobuf, self).__init__(
            scalar=_literals.Scalar(
                binary=_literals.Binary(
                    # On Python 2 the payload is explicitly passed through bytes().
                    value=bytes(data) if _six.PY2 else data,
                    tag=type(self).tag
                )
            )
        )

    @classmethod
    def from_string(cls, string_value):
        """
        Build an instance from a base64-encoded serialized message.

        :param Text string_value: b64 encoded string of bytes
        :rtype: Protobuf
        :raises: flytekit.common.exceptions.user.FlyteValueException if the string is not valid base64
        """
        try:
            decoded = _base64.b64decode(string_value)
        except (TypeError, ValueError):
            # Python 2 signals malformed base64 with TypeError; Python 3 raises binascii.Error,
            # which is a ValueError subclass. Both must map to the user-facing exception.
            raise _user_exceptions.FlyteValueException(string_value, "The string is not valid base64-encoded.")
        pb_obj = cls.pb_type()
        pb_obj.ParseFromString(decoded)
        return cls(pb_obj)

    @classmethod
    def is_castable_from(cls, other):
        """
        Report whether ``other`` is a Protobuf SDK type wrapping the very same message class.

        :param flytekit.common.types.base_literal_types.FlyteSdkType other:
        :rtype: bool
        """
        return isinstance(other, ProtobufType) and other.pb_type is cls.pb_type

    @classmethod
    def from_python_std(cls, t_value):
        """
        Wrap a Python-native value. None becomes Void; an instance of this type's message class is
        wrapped; anything else is rejected.

        :param T t_value: It is up to each individual object as to whether or not this value can be cast.
        :rtype: _base_sdk_types.FlyteSdkValue
        :raises: flytekit.common.exceptions.user.FlyteTypeException
        """
        if t_value is None:
            return _base_sdk_types.Void()
        elif isinstance(t_value, cls.pb_type):
            return cls(t_value)
        else:
            raise _user_exceptions.FlyteTypeException(
                type(t_value),
                cls.pb_type,
                received_value=t_value
            )

    @classmethod
    def to_flyte_literal_type(cls):
        """
        Describe this type as a BINARY simple type whose metadata records the message descriptor
        under PB_FIELD_KEY.

        :rtype: flytekit.models.types.LiteralType
        """
        return _idl_types.LiteralType(
            simple=_idl_types.SimpleType.BINARY,
            metadata={
                cls.PB_FIELD_KEY: cls.descriptor
            }
        )

    @classmethod
    def promote_from_model(cls, literal_model):
        """
        Creates an object of this type from the model primitive defining it. The binary tag of the
        literal must equal this type's tag.

        :param flytekit.models.literals.Literal literal_model:
        :rtype: Protobuf
        :raises: flytekit.common.exceptions.user.FlyteTypeException if the tags do not match
        """
        if literal_model.scalar.binary.tag != cls.tag:
            raise _user_exceptions.FlyteTypeException(
                literal_model.scalar.binary.tag,
                cls.pb_type,
                # The raw payload is base64-encoded before being placed in the error.
                received_value=_base64.b64encode(literal_model.scalar.binary.value),
                additional_msg="Can not deserialize as proto tags don't match."
            )
        pb_obj = cls.pb_type()
        pb_obj.ParseFromString(literal_model.scalar.binary.value)
        return cls(pb_obj)

    @classmethod
    def short_class_string(cls):
        """
        Short human-readable name of this type, including the message descriptor.

        :rtype: Text
        """
        return "Types.Proto({})".format(cls.descriptor)

    def to_python_std(self):
        """
        Parse the stored binary payload back into a new message object.

        :returns: The protobuf object as defined by the user.
        :rtype: T
        """
        pb_obj = type(self).pb_type()
        pb_obj.ParseFromString(self.scalar.binary.value)
        return pb_obj

    def short_string(self):
        """
        String form of the deserialized message.

        :rtype: Text
        """
        return "{}".format(self.to_python_std())