# -*- coding: utf-8 -*-
# Copyright: (c) 2019, Jordan Borean (@jborean93) <jborean93@gmail.com>
# MIT License (see LICENSE or https://opensource.org/licenses/MIT)

import errno
import functools
import logging
import select
import socket
import struct
import threading

from collections import (
    OrderedDict,
)

from smbprotocol.structure import (
    BytesField,
    IntField,
    Structure,
)

try:
    from queue import Queue
except ImportError:  # pragma: no cover
    from Queue import Queue

log = logging.getLogger(__name__)


class DirectTCPPacket(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.1 Transport
    Header used to frame an SMB2 message on the Direct TCP transport: a 4 byte
    big-endian length value followed by the message bytes themselves.
    """

    def __init__(self):
        # When not set explicitly the length defaults to the size of the
        # message stored in this packet.
        length_field = IntField(
            size=4,
            little_endian=False,
            default=lambda s: len(s['smb2_message']),
        )
        # The message is sized from the length value declared before it.
        message_field = BytesField(
            size=lambda s: s['stream_protocol_length'].get_value(),
        )

        # Keep the declaration order: the length first, then the message.
        fields = OrderedDict()
        fields['stream_protocol_length'] = length_field
        fields['smb2_message'] = message_field
        self.fields = fields

        super(DirectTCPPacket, self).__init__()


def socket_connect(func):
    """
    Decorator for transport methods that require an open socket.

    While ``self._connected`` is False, the wrapper opens a TCP connection to ``self.server:self.port`` (using
    ``self.timeout`` for the connection attempt), switches the socket to blocking mode and starts a daemon thread
    running ``self.recv_thread``. The decorated method is then called and its result is returned to the caller.

    :param func: The method to wrap, the first argument must be the transport instance.
    :return: The wrapping function, it raises ValueError if the connection cannot be established.
    """

    @functools.wraps(func)  # Keep the name and docstring of the wrapped method.
    def wrapped(self, *args, **kwargs):
        if not self._connected:
            log.info("Connecting to DirectTcp socket")
            try:
                self._sock = socket.create_connection((self.server, self.port), timeout=self.timeout)
            except (OSError, socket.gaierror) as err:
                raise ValueError("Failed to connect to '%s:%s': %s" % (self.server, self.port, str(err)))
            self._sock.settimeout(None)  # Make sure the socket is in blocking mode.

            self._t_recv = threading.Thread(target=self.recv_thread, name="recv-%s:%s" % (self.server, self.port))
            self._t_recv.daemon = True
            self._t_recv.start()
            self._connected = True

        # Hand the result back instead of silently dropping it.
        return func(self, *args, **kwargs)

    return wrapped


class Tcp(object):
    """
    Direct TCP transport that frames outgoing messages with a DirectTCPPacket header and pushes every received
    message onto a queue from a background thread.

    The socket is opened lazily on the first call to send().

    :param server: The host to connect to.
    :param port: The port to connect to.
    :param recv_queue: Object with a put() method, it receives the bytes of each incoming message and a final None
        once the receive thread stops.
    :param timeout: Timeout passed to socket.create_connection() for the connection attempt only.
    """

    # 2 ** 24 - 1, the largest message length that send() will accept.
    MAX_SIZE = 16777215

    def __init__(self, server, port, recv_queue, timeout=None):
        self.server = server
        self.port = port
        self.timeout = timeout
        self._connected = False
        self._sock = None
        self._recv_queue = recv_queue
        self._t_recv = None

    def close(self):
        """
        Shut down the socket, wait for the receive thread to finish and close the socket. Does nothing when the
        transport is not connected.
        """
        if not self._connected:
            return

        log.info("Disconnecting DirectTcp socket")
        # Send a shutdown to the socket so the select returns and wait until the thread is closed before actually
        # closing the socket.
        self._connected = False
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
        except (OSError, socket.error) as err:
            # ENOTCONN means the peer has already dropped the connection, we still need to join the thread and
            # close our side of the socket. Anything else is unexpected and is raised to the caller.
            if err.errno != errno.ENOTCONN:
                raise

        self._t_recv.join()
        self._sock.close()

    @socket_connect
    def send(self, header):
        """
        Wrap the message in a DirectTCPPacket and write all of it to the socket.

        :param header: The bytes of the message to send.
        :raises ValueError: When the message is larger than MAX_SIZE.
        """
        b_msg = header
        data_length = len(b_msg)
        if data_length > self.MAX_SIZE:
            raise ValueError("Data to be sent over Direct TCP size %d exceeds the max length allowed %d"
                             % (data_length, self.MAX_SIZE))

        tcp_packet = DirectTCPPacket()
        tcp_packet['smb2_message'] = b_msg

        # sendall keeps writing until the whole buffer has been sent or an error is raised.
        self._sock.sendall(tcp_packet.pack())

    def recv_thread(self):
        """
        Receive loop run in the background thread. Reads a 4 byte big-endian length followed by that many bytes
        and puts each message on the receive queue. A final None is always put on the queue when the loop exits.
        """
        try:
            while True:
                select.select([self._sock], [], [])
                b_packet_size = self._recv_exact(4)
                if b_packet_size is None:
                    # The socket was closed, nothing more to read.
                    return

                packet_size = struct.unpack(">L", b_packet_size)[0]
                b_data = self._recv_exact(packet_size)
                if b_data is None:
                    # The socket was closed part way through a message, the partial data is of no use.
                    if self._connected:
                        log.warning("DirectTcp socket was closed before the full message was received")
                    return

                self._recv_queue.put(b_data)
        except Exception as e:
            # Log a warning if the exception was raised while we were connected and not just some weird platform-ism
            # exception when reading from a closed socket.
            if self._connected:
                log.warning("Uncaught exception in socket recv thread: %s" % e)
            return
        finally:
            # Make sure we close the message processing thread in connection.py
            self._recv_queue.put(None)

    def _recv_exact(self, length):
        """
        Read exactly length bytes from the socket, recv() may return less than what was asked for.

        :param length: The number of bytes to read.
        :return: The bytes read or None if the socket was closed before length bytes were received.
        """
        b_buffer = bytearray()
        while len(b_buffer) < length:
            b_fragment = self._sock.recv(length - len(b_buffer))
            if not b_fragment:
                # An empty read means the other end has closed the socket, looping again would never finish.
                return None
            b_buffer.extend(b_fragment)

        return bytes(b_buffer)