# Copyright 2014, 2015 SAP SE
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import struct
###
from pyhdb.protocol import constants
from pyhdb.protocol.headers import MessageHeader
from pyhdb.protocol.segments import ReplySegment
from pyhdb.lib.tracing import trace


class BaseMessage(object):
    """
    Message - basic frame for sending to and receiving data from HANA db.

    A message is a fixed-size header (described by ``header_struct``) followed by
    one or more segments.
    """
    # Header layout, in order: session id (I8), packet count (I4), payload length (UI4),
    # varpart/segment size limit (UI4), number of segments (I2), packet options (I1),
    # followed by 9 filler bytes.
    # The '<' prefix pins little-endian byte order and disables native alignment, so the
    # on-the-wire header is identical regardless of the host CPU's native byte order.
    header_struct = struct.Struct('<qiIIhb9x')  # I8 I4 UI4 UI4 I2 I1 x[9]
    header_size = header_struct.size
    assert header_size == constants.general.MESSAGE_HEADER_SIZE  # Ensures that the constant defined there is correct!
    # Attributes included when the message is passed to the tracing helper.
    __tracing_attrs__ = ['header', 'segments']

    def __init__(self, session_id, packet_count, segments=(), autocommit=False, header=None):
        """
        :param session_id: session id of the connection the message belongs to
        :param packet_count: packet sequence number of the message
        :param segments: a single segment instance, or a list/tuple of segment instances;
                         a lone segment is wrapped into a 1-tuple
        :param autocommit: autocommit flag stored on the message
        :param header: an already-built header namedtuple, or None if not yet known
        """
        self.session_id = session_id
        self.packet_count = packet_count
        self.autocommit = autocommit
        # Normalise to a sequence so callers can always iterate / take len().
        self.segments = segments if isinstance(segments, (list, tuple)) else (segments, )
        self.header = header


class RequestMessage(BaseMessage):
    """Request message sent from the client to the HANA db."""

    def build_payload(self, payload):
        """ Build payload of message.

        Packs every segment, in order, into ``payload``; the message's autocommit
        flag is forwarded to each segment as ``commit``.
        :param payload: writable binary stream positioned just after the reserved header
        """
        for segment in self.segments:
            segment.pack(payload, commit=self.autocommit)

    def pack(self):
        """ Pack message to binary stream.

        The header is written last because it contains the payload length, which is
        only known once all segments have been packed.
        :returns: BytesIO containing header + payload, positioned at its end
        """
        payload = io.BytesIO()
        # Advance num bytes equal to header size - the header is written later
        # after the payload of all segments and parts has been written:
        payload.seek(self.header_size, io.SEEK_CUR)

        # Write out payload of segments and parts:
        self.build_payload(payload)

        # seek() returns the new absolute position, i.e. the total stream size.
        # This avoids getvalue(), which would copy the entire buffer just to measure it.
        total_size = payload.seek(0, io.SEEK_END)
        packet_length = total_size - self.header_size
        self.header = MessageHeader(self.session_id, self.packet_count, packet_length, constants.MAX_SEGMENT_SIZE,
                                    num_segments=len(self.segments), packet_options=0)
        packed_header = self.header_struct.pack(*self.header)

        # Go back to beginning of payload for writing message header:
        payload.seek(0)
        payload.write(packed_header)
        # Leave the stream positioned at its end, as callers received it before.
        payload.seek(0, io.SEEK_END)

        trace(self)

        return payload

    @classmethod
    def new(cls, connection, segments=()):
        """Return a new request message instance - extracts required data from connection object
        :param connection: connection object
        :param segments: a single segment instance, or a list/tuple of segment instances
        :returns: RequestMessage instance
        """
        # Note: get_next_packet_count() is called here, so each new() call consumes a packet number.
        return cls(connection.session_id, connection.get_next_packet_count(), segments,
                   autocommit=connection.autocommit)


class ReplyMessage(BaseMessage):
    """Reply message received from the HANA db in answer to a request."""

    @classmethod
    def unpack_reply(cls, header, payload):
        """Create a reply message from an already unpacked header and its binary payload.

        :param header: header namedtuple (e.g. as produced by ``header_from_raw_header_data``)
        :param payload: BytesIO holding the message body following the header
        :returns: instance of ``cls`` whose ``segments`` is a tuple of ReplySegment objects
        """
        segment_count = header.num_segments
        unpacked_segments = ReplySegment.unpack_from(payload, expected_segments=segment_count)
        message = cls(
            header.session_id,
            header.packet_count,
            segments=tuple(unpacked_segments),
            header=header,
        )
        trace(message)
        return message

    @classmethod
    def header_from_raw_header_data(cls, raw_header):
        """Decode the binary message header of a reply received from HANA.

        :param raw_header: bytes holding exactly one packed message header
        :returns: MessageHeader namedtuple for convenient field access
        :raises Exception: if ``raw_header`` does not match ``header_struct``
        """
        try:
            header_fields = cls.header_struct.unpack(raw_header)
        except struct.error:
            raise Exception("Invalid message header received")
        return MessageHeader(*header_fields)