# -*- coding: utf-8 -*-
# Copyright: (c) 2019, Jordan Borean (@jborean93) <jborean93@gmail.com>
# MIT License (see LICENSE or https://opensource.org/licenses/MIT)

from __future__ import division

import binascii
import hashlib
import hmac
import logging
import math
import os
import struct
import time
import threading

from collections import (
    OrderedDict,
)

from cryptography.exceptions import (
    UnsupportedAlgorithm,
)

from cryptography.hazmat.backends import (
    default_backend,
)

from cryptography.hazmat.primitives import (
    cmac,
)

from cryptography.hazmat.primitives.ciphers import (
    aead,
    algorithms,
)

from datetime import (
    datetime,
)

from threading import (
    Lock,
)

from smbprotocol import (
    Commands,
    Dialects,
    MAX_PAYLOAD_SIZE,
)

from smbprotocol._text import (
    to_native,
    to_text,
)

from smbprotocol.exceptions import (
    NtStatus,
    SMB2SymbolicLinkErrorResponse,
    SMBException,
    SMBResponseException,
)

from smbprotocol.open import (
    Open,
)

from smbprotocol.structure import (
    BytesField,
    DateTimeField,
    EnumField,
    FlagField,
    IntField,
    ListField,
    Structure,
    StructureField,
    UuidField,
)

from smbprotocol.transport import (
    Tcp,
)

try:
    from queue import Queue
except ImportError:  # pragma: no cover
    from Queue import Queue

log = logging.getLogger(__name__)


class Smb2Flags(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.1.2 SMB2 Packet Header - SYNC Flags
    Indicates various processing rules that need to be done on the SMB2 packet.
    """
    SMB2_FLAGS_SERVER_TO_REDIR = 0x00000001
    SMB2_FLAGS_ASYNC_COMMAND = 0x00000002
    SMB2_FLAGS_RELATED_OPERATIONS = 0x00000004
    SMB2_FLAGS_SIGNED = 0x00000008
    SMB2_FLAGS_PRIORITY_MASK = 0x00000070
    SMB2_FLAGS_DFS_OPERATIONS = 0x10000000
    SMB2_FLAGS_REPLAY_OPERATIONS = 0x20000000


class SecurityMode(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 NEGOTIATE Request SecurityMode
    Indicates whether SMB signing is enabled or required by the client.
    """
    SMB2_NEGOTIATE_SIGNING_ENABLED = 0x0001
    SMB2_NEGOTIATE_SIGNING_REQUIRED = 0x0002


class Capabilities(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 NEGOTIATE Request Capabilities
    Used in SMB3.x and above, used to specify the capabilities supported.
    """
    SMB2_GLOBAL_CAP_DFS = 0x00000001
    SMB2_GLOBAL_CAP_LEASING = 0x00000002
    SMB2_GLOBAL_CAP_LARGE_MTU = 0x00000004
    SMB2_GLOBAL_CAP_MULTI_CHANNEL = 0x00000008
    SMB2_GLOBAL_CAP_PERSISTENT_HANDLES = 0x00000010
    SMB2_GLOBAL_CAP_DIRECTORY_LEASING = 0x00000020
    SMB2_GLOBAL_CAP_ENCRYPTION = 0x00000040


class NegotiateContextType(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1 SMB2 NEGOTIATE_CONTENT Request ContextType
    Specifies the type of context in an SMB2 NEGOTIATE_CONTEXT message.
    """
    SMB2_PREAUTH_INTEGRITY_CAPABILITIES = 0x0001
    SMB2_ENCRYPTION_CAPABILITIES = 0x0002


class HashAlgorithms(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.1 SMB2_PREAUTH_INTEGRITY_CAPABILITIES
    16-bit integer IDs that specify the integrity hash algorithm supported
    """
    SHA_512 = 0x0001

    @staticmethod
    def get_algorithm(hash):
        return {
            HashAlgorithms.SHA_512: hashlib.sha512
        }[hash]


class Ciphers(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.2 SMB2_ENCRYPTION_CAPABILITIES
    16-bit integer IDs that specify the supported encryption algorithms.
    """
    AES_128_CCM = 0x0001
    AES_128_GCM = 0x0002

    @staticmethod
    def get_cipher(cipher):
        return {
            Ciphers.AES_128_CCM: aead.AESCCM,
            Ciphers.AES_128_GCM: aead.AESGCM
        }[cipher]

    @staticmethod
    def get_supported_ciphers():
        supported_ciphers = []
        try:
            aead.AESGCM(b"\x00" * 16)
            supported_ciphers.append(Ciphers.AES_128_GCM)
        except UnsupportedAlgorithm:  # pragma: no cover
            pass
        try:
            aead.AESCCM(b"\x00" * 16)
            supported_ciphers.append(Ciphers.AES_128_CCM)
        except UnsupportedAlgorithm:  # pragma: no cover
            pass
        return supported_ciphers


class SMB2HeaderAsync(Structure):
    """
    [MS-SMB2] 2.2.1.1 SMB2 Packer Header - ASYNC
    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/ea4560b7-90da-4803-82b5-344754b92a79

    The SMB2 Packet header for async commands.
    """

    def __init__(self):
        self.fields = OrderedDict([
            ('protocol_id', BytesField(
                size=4,
                default=b"\xfeSMB",
            )),
            ('structure_size', IntField(
                size=2,
                default=64,
            )),
            ('credit_charge', IntField(size=2)),
            ('channel_sequence', IntField(size=2)),
            ('reserved', IntField(size=2)),
            ('command', EnumField(
                size=2,
                enum_type=Commands,
            )),
            ('credit_request', IntField(size=2)),
            ('flags', FlagField(
                size=4,
                flag_type=Smb2Flags,
            )),
            ('next_command', IntField(size=4)),
            ('message_id', IntField(size=8)),
            ('async_id', IntField(size=8)),
            ('session_id', IntField(size=8)),
            ('signature', BytesField(
                size=16,
                default=b"\x00" * 16,
            )),
            ('data', BytesField())
        ])
        super(SMB2HeaderAsync, self).__init__()


class SMB2HeaderRequest(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.1.2 SMB2 Packet Header - SYNC
    This is the header definition that contains the ChannelSequence/Reserved
    instead of the Status field used for a Packet request.
    """

    def __init__(self):
        self.fields = OrderedDict([
            ('protocol_id', BytesField(
                size=4,
                default=b"\xfeSMB",
            )),
            ('structure_size', IntField(
                size=2,
                default=64,
            )),
            ('credit_charge', IntField(size=2)),
            ('channel_sequence', IntField(size=2)),
            ('reserved', IntField(size=2)),
            ('command', EnumField(
                size=2,
                enum_type=Commands
            )),
            ('credit_request', IntField(size=2)),
            ('flags', FlagField(
                size=4,
                flag_type=Smb2Flags,
            )),
            ('next_command', IntField(size=4)),
            ('message_id', IntField(size=8)),
            ('process_id', IntField(size=4)),
            ('tree_id', IntField(size=4)),
            ('session_id', IntField(size=8)),
            ('signature', BytesField(
                size=16,
                default=b"\x00" * 16,
            )),
            ('data', BytesField())
        ])
        super(SMB2HeaderRequest, self).__init__()


class SMB2HeaderResponse(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.1.2 SMB2 Packet Header - SYNC
    The header definition for an SMB Response that contains the Status field
    instead of the ChannelSequence/Reserved used for a Packet response.
    """

    def __init__(self):
        self.fields = OrderedDict([
            ('protocol_id', BytesField(
                size=4,
                default=b'\xfeSMB',
            )),
            ('structure_size', IntField(
                size=2,
                default=64,
            )),
            ('credit_charge', IntField(size=2)),
            ('status', EnumField(
                size=4,
                enum_type=NtStatus,
                enum_strict=False
            )),
            ('command', EnumField(
                size=2,
                enum_type=Commands,
                enum_strict=False,
            )),
            ('credit_response', IntField(size=2)),
            ('flags', FlagField(
                size=4,
                flag_type=Smb2Flags,
            )),
            ('next_command', IntField(size=4)),
            ('message_id', IntField(size=8)),
            ('reserved', IntField(size=4)),
            ('tree_id', IntField(size=4)),
            ('session_id', IntField(size=8)),
            ('signature', BytesField(
                size=16,
                default=b"\x00" * 16,
            )),
            ('data', BytesField()),
        ])
        super(SMB2HeaderResponse, self).__init__()


class SMB2NegotiateRequest(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 Negotiate Request
    The SMB2 NEGOTIATE Request packet is used by the client to notify the
    server what dialects of the SMB2 Protocol the client understands. This is
    only used if the client explicitly sets the Dialect to use to a version
    less than 3.1.1. Dialect 3.1.1 added support for negotiate_context and
    SMB3NegotiateRequest should be used to support that.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        self.fields = OrderedDict([
            ('structure_size', IntField(
                size=2,
                default=36,
            )),
            ('dialect_count', IntField(
                size=2,
                default=lambda s: len(s['dialects'].get_value()),
            )),
            ('security_mode', FlagField(
                size=2,
                flag_type=SecurityMode
            )),
            ('reserved', IntField(size=2)),
            ('capabilities', FlagField(
                size=4,
                flag_type=Capabilities,
            )),
            ('client_guid', UuidField()),
            ('client_start_time', IntField(size=8)),
            ('dialects', ListField(
                size=lambda s: s['dialect_count'].get_value() * 2,
                list_count=lambda s: s['dialect_count'].get_value(),
                list_type=EnumField(size=2, enum_type=Dialects),
            )),
        ])

        super(SMB2NegotiateRequest, self).__init__()


class SMB3NegotiateRequest(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 Negotiate Request
    Like SMB2NegotiateRequest but with support for setting a list of
    Negotiate Context values. This is used by default and is for Dialects 3.1.1
    or greater.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        self.fields = OrderedDict([
            ('structure_size', IntField(
                size=2,
                default=36,
            )),
            ('dialect_count', IntField(
                size=2,
                default=lambda s: len(s['dialects'].get_value()),
            )),
            ('security_mode', FlagField(
                size=2,
                flag_type=SecurityMode,
            )),
            ('reserved', IntField(size=2)),
            ('capabilities', FlagField(
                size=4,
                flag_type=Capabilities,
            )),
            ('client_guid', UuidField()),
            ('negotiate_context_offset', IntField(
                size=4,
                default=lambda s: self._negotiate_context_offset_value(s),
            )),
            ('negotiate_context_count', IntField(
                size=2,
                default=lambda s: len(s['negotiate_context_list'].get_value()),
            )),
            ('reserved2', IntField(size=2)),
            ('dialects', ListField(
                size=lambda s: s['dialect_count'].get_value() * 2,
                list_count=lambda s: s['dialect_count'].get_value(),
                list_type=EnumField(size=2, enum_type=Dialects),
            )),
            ('padding', BytesField(
                size=lambda s: self._padding_size(s),
                default=lambda s: b"\x00" * self._padding_size(s),
            )),
            ('negotiate_context_list', ListField(
                list_count=lambda s: s['negotiate_context_count'].get_value(),
                unpack_func=lambda s, d: self._negotiate_context_list(s, d),
            )),
        ])
        super(SMB3NegotiateRequest, self).__init__()

    def _negotiate_context_offset_value(self, structure):
        # The offset from the beginning of the SMB2 header to the first, 8-byte
        # aligned, negotiate context
        header_size = 64
        negotiate_size = structure['structure_size'].get_value()
        dialect_size = len(structure['dialects'])
        padding_size = self._padding_size(structure)
        return header_size + negotiate_size + dialect_size + padding_size

    def _padding_size(self, structure):
        # Padding between the end of the buffer value and the first Negotiate
        # context value so that the first value is 8-byte aligned. Padding is
        # 4 is there are no dialects specified
        mod = (structure['dialect_count'].get_value() * 2) % 8
        return 0 if mod == 0 else mod

    def _negotiate_context_list(self, structure, data):
        context_count = structure['negotiate_context_count'].get_value()
        context_list = []
        for idx in range(0, context_count):
            field, data = self._parse_negotiate_context_entry(data, idx)
            context_list.append(field)

        return context_list

    def _parse_negotiate_context_entry(self, data, idx):
        data_length = struct.unpack("<H", data[2:4])[0]
        negotiate_context = SMB2NegotiateContextRequest()
        negotiate_context.unpack(data[:data_length + 8])
        return negotiate_context, data[8 + data_length:]


class SMB2NegotiateContextRequest(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1 SMB2 NEGOTIATE_CONTEXT Request Values
    The SMB2_NEGOTIATE_CONTEXT structure is used by the SMB2 NEGOTIATE Request
    and the SMB2 NEGOTIATE Response to encode additional properties.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        self.fields = OrderedDict([
            ('context_type', EnumField(
                size=2,
                enum_type=NegotiateContextType,
            )),
            ('data_length', IntField(
                size=2,
                default=lambda s: len(s['data'].get_value()),
            )),
            ('reserved', IntField(size=4)),
            ('data', StructureField(
                size=lambda s: s['data_length'].get_value(),
                structure_type=lambda s: self._data_structure_type(s)
            )),
            # not actually a field but each list entry must start at the 8 byte
            # alignment
            ('padding', BytesField(
                size=lambda s: self._padding_size(s),
                default=lambda s: b"\x00" * self._padding_size(s),
            ))
        ])
        super(SMB2NegotiateContextRequest, self).__init__()

    def _data_structure_type(self, structure):
        con_type = structure['context_type'].get_value()
        if con_type == \
                NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES:
            return SMB2PreauthIntegrityCapabilities
        elif con_type == NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES:
            return SMB2EncryptionCapabilities

    def _padding_size(self, structure):
        data_size = len(structure['data'])
        return 8 - data_size if data_size <= 8 else 8 - (data_size % 8)


class SMB2PreauthIntegrityCapabilities(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.1 SMB2_PREAUTH_INTEGRITY_CAPABILITIES
    The SMB2_PREAUTH_INTEGRITY_CAPABILITIES context is specified in an SMB2
    NEGOTIATE request by the client to indicate which preauthentication
    integrity hash algorithms it supports and to optionally supply a
    preauthentication integrity hash salt value.
    """

    def __init__(self):
        self.fields = OrderedDict([
            ('hash_algorithm_count', IntField(
                size=2,
                default=lambda s: len(s['hash_algorithms'].get_value()),
            )),
            ('salt_length', IntField(
                size=2,
                default=lambda s: len(s['salt']),
            )),
            ('hash_algorithms', ListField(
                size=lambda s: s['hash_algorithm_count'].get_value() * 2,
                list_count=lambda s: s['hash_algorithm_count'].get_value(),
                list_type=EnumField(size=2, enum_type=HashAlgorithms),
            )),
            ('salt', BytesField(
                size=lambda s: s['salt_length'].get_value(),
            )),
        ])
        super(SMB2PreauthIntegrityCapabilities, self).__init__()


class SMB2EncryptionCapabilities(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.2 SMB2_ENCRYPTION_CAPABILITIES
    The SMB2_ENCRYPTION_CAPABILITIES context is specified in an SMB2 NEGOTIATE
    request by the client to indicate which encryption algorithms the client
    supports.
    """

    def __init__(self):
        self.fields = OrderedDict([
            ('cipher_count', IntField(
                size=2,
                default=lambda s: len(s['ciphers'].get_value()),
            )),
            ('ciphers', ListField(
                size=lambda s: s['cipher_count'].get_value() * 2,
                list_count=lambda s: s['cipher_count'].get_value(),
                list_type=EnumField(size=2, enum_type=Ciphers),
            )),
        ])
        super(SMB2EncryptionCapabilities, self).__init__()


class SMB2NegotiateResponse(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.4 SMB2 NEGOTIATE Response
    The SMB2 NEGOTIATE Response packet is sent by the server to notify the
    client of the preferred common dialect.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        self.fields = OrderedDict([
            ('structure_size', IntField(
                size=2,
                default=65,
            )),
            ('security_mode', FlagField(
                size=2,
                flag_type=SecurityMode,
            )),
            ('dialect_revision', EnumField(
                size=2,
                enum_type=Dialects,
            )),
            ('negotiate_context_count', IntField(
                size=2,
                default=lambda s: self._negotiate_context_count_value(s),
            )),
            ('server_guid', UuidField()),
            ('capabilities', FlagField(
                size=4,
                flag_type=Capabilities
            )),
            ('max_transact_size', IntField(size=4)),
            ('max_read_size', IntField(size=4)),
            ('max_write_size', IntField(size=4)),
            ('system_time', DateTimeField()),
            ('server_start_time', DateTimeField()),
            ('security_buffer_offset', IntField(
                size=2,
                default=128,  # (header size 64) + (structure size 64)
            )),
            ('security_buffer_length', IntField(
                size=2,
                default=lambda s: len(s['buffer'].get_value()),
            )),
            ('negotiate_context_offset', IntField(
                size=4,
                default=lambda s: self._negotiate_context_offset_value(s),
            )),
            ('buffer', BytesField(
                size=lambda s: s['security_buffer_length'].get_value(),
            )),
            ('padding', BytesField(
                size=lambda s: self._padding_size(s),
                default=lambda s: b"\x00" * self._padding_size(s),
            )),
            ('negotiate_context_list', ListField(
                list_count=lambda s: s['negotiate_context_count'].get_value(),
                unpack_func=lambda s, d:
                self._negotiate_context_list(s, d),
            )),
        ])
        super(SMB2NegotiateResponse, self).__init__()

    def _negotiate_context_count_value(self, structure):
        # If the dialect_revision is SMBv3.1.1, this field specifies the
        # number of negotiate contexts in negotiate_context_list; otherwise
        # this field must not be used and must be reserved (0).
        if structure['dialect_revision'].get_value() == Dialects.SMB_3_1_1:
            return len(structure['negotiate_context_list'].get_value())
        else:
            return None

    def _negotiate_context_offset_value(self, structure):
        # If the dialect_revision is SMBv3.1.1, this field specifies the offset
        # from the beginning of the SMB2 header to the first 8-byte
        # aligned negotiate context entry in negotiate_context_list; otherwise
        # this field must not be used and must be reserved (0).
        if structure['dialect_revision'].get_value() == Dialects.SMB_3_1_1:
            buffer_offset = structure['security_buffer_offset'].get_value()
            buffer_size = structure['security_buffer_length'].get_value()
            padding_size = self._padding_size(structure)
            return buffer_offset + buffer_size + padding_size
        else:
            return None

    def _padding_size(self, structure):
        # Padding between the end of the buffer value and the first Negotiate
        # context value so that the first value is 8-byte aligned. Padding is
        # not required if there are not negotiate contexts
        if structure['negotiate_context_count'].get_value() == 0:
            return 0

        mod = structure['security_buffer_length'].get_value() % 8
        return 0 if mod == 0 else 8 - mod

    def _negotiate_context_list(self, structure, data):
        context_count = structure['negotiate_context_count'].get_value()
        context_list = []
        for idx in range(0, context_count):
            field, data = self._parse_negotiate_context_entry(data)
            context_list.append(field)

        return context_list

    def _parse_negotiate_context_entry(self, data):
        data_length = struct.unpack("<H", data[2:4])[0]
        negotiate_context = SMB2NegotiateContextRequest()
        negotiate_context.unpack(data[:data_length + 8])
        padded_size = data_length % 8
        if padded_size != 0:
            padded_size = 8 - padded_size

        return negotiate_context, data[8 + data_length + padded_size:]


class SMB2Echo(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.28 SMB2 Echo Request/Response
    Request and response for an SMB ECHO message.
    """
    COMMAND = Commands.SMB2_ECHO

    def __init__(self):
        self.fields = OrderedDict([
            ('structure_size', IntField(
                size=2,
                default=4
            )),
            ('reserved', IntField(size=2))
        ])
        super(SMB2Echo, self).__init__()


class SMB2CancelRequest(Structure):
    """
    [MS-SMB2] 2.2.30 - SMB2 CANCEL Request
    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/91913fc6-4ec9-4a83-961b-370070067e63

    The SMB2 CANCEL Request packet is sent by the client to cancel a previously sent message on the same SMB2 transport
    connection.
    """
    COMMAND = Commands.SMB2_CANCEL

    def __init__(self):
        self.fields = OrderedDict([
            ('structure_size', IntField(
                size=2,
                default=4,
            )),
            ('reserved', IntField(size=2)),
        ])
        super(SMB2CancelRequest, self).__init__()


class SMB2TransformHeader(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.41 SMB2 TRANSFORM_HEADER
    The SMB2 Transform Header is used by the client or server when sending
    encrypted message. This is only valid for the SMB.x dialect family.
    """

    def __init__(self):
        self.fields = OrderedDict([
            ('protocol_id', BytesField(
                size=4,
                default=b"\xfdSMB"
            )),
            ('signature', BytesField(
                size=16,
                default=b"\x00" * 16
            )),
            ('nonce', BytesField(size=16)),
            ('original_message_size', IntField(size=4)),
            ('reserved', IntField(size=2, default=0)),
            ('flags', IntField(
                size=2,
                default=1
            )),
            ('session_id', IntField(size=8)),
            ('data', BytesField())  # not in spec
        ])
        super(SMB2TransformHeader, self).__init__()


def _worker_running(func):
    """ Ensures the message worker thread is still running and hasn't failed for any reason. """
    def wrapped(self, *args, **kwargs):
        self._check_worker_running()
        return func(self, *args, **kwargs)
    return wrapped


class Connection(object):

    def __init__(self, guid, server_name, port=445, require_signing=True):
        """
        [MS-SMB2] v53.0 2017-09-15

        3.2.1.2 Per SMB2 Transport Connection
        Used as the transport interface for a server. Some values have been
        omitted as they can be retrieved by the Server object stored in
        self.server

        :param guid: A unique guid that represents the client
        :param server_name: The server to start the connection
        :param port: The port to use for the transport, default is 445
        :param require_signing: Whether signing is required on SMB messages
            sent over this connection
        """
        log.info("Initialising connection, guid: %s, require_singing: %s, "
                 "server_name: %s, port: %d"
                 % (guid, require_signing, server_name, port))
        self.server_name = server_name
        self.port = port
        self.transport = None  # Instanciated in .connect()

        # Table of Session entries, the order is important for smbclient.
        self.session_table = OrderedDict()

        # Table of sessions that have not completed authentication, indexed by
        # session_id
        self.preauth_session_table = {}

        # Table of Requests that have yet to be picked up by the application,
        # it MAY contain a response from the server as well
        self.outstanding_requests = dict()

        # Table of available sequence numbers
        self.sequence_window = dict(
            low=0,
            high=1
        )
        self.sequence_lock = Lock()

        # Byte array containing the negotiate token and remembered for
        # authentication
        self.gss_negotiate_token = None

        self.server_guid = None
        self.max_transact_size = None
        self.max_read_size = None
        self.max_write_size = None
        self.require_signing = require_signing

        # SMB 2.1+
        self.dialect = None
        self.supports_file_leasing = None
        # just go with False as a default for Dialect 2.0.2
        self.supports_multi_credit = False
        self.client_guid = guid

        # SMB 3.x+
        self.salt = None
        self.supports_directory_leasing = None
        self.supports_multi_channel = None
        self.supports_persistent_handles = None
        self.supports_encryption = None

        # used for SMB 3.x for secure negotiate verification on tree connect
        self.negotiated_dialects = []
        self.client_capabilities = Capabilities.SMB2_GLOBAL_CAP_LARGE_MTU | \
            Capabilities.SMB2_GLOBAL_CAP_ENCRYPTION

        self.client_security_mode = \
            SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED if \
            require_signing else SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED
        self.server_security_mode = None
        self.server_capabilities = None

        # SMB 3.1.1+
        # The hashing algorithm object that was negotiated
        self.preauth_integrity_hash_id = None

        # Preauth integrity hash value computed for the SMB2 NEGOTIATE request
        # contains the messages used to compute the hash
        self.preauth_integrity_hash_value = []

        # The cipher object that was negotiated
        self.cipher_id = None

        # Keep track of the message processing thread's potential traceback that it may raise.
        self._t_exc = None

    def connect(self, dialect=None, timeout=60):
        """
        Will connect to the target server and negotiate the capabilities
        with the client. Once setup, the client MUST call the disconnect()
        function to close the listener thread. This function will populate
        various connection properties that denote the capabilities of the
        server.

        :param dialect: If specified, forces the dialect that is negotiated
            with the server, if not set, then the newest dialect supported by
            the server is used up to SMB 3.1.1
        :param timeout: The timeout in seconds to wait for the initial
            negotiation process to complete
        """
        log.info("Setting up transport connection")
        message_queue = Queue()
        self.transport = Tcp(self.server_name, self.port, message_queue, timeout)
        t_worker = threading.Thread(target=self._process_message_thread, args=(message_queue,),
                                    name="msg_worker-%s:%s" % (self.server_name, self.port))
        t_worker.daemon = True
        t_worker.start()

        log.info("Starting negotiation with SMB server")
        smb_response = self._send_smb2_negotiate(dialect, timeout)
        log.info("Negotiated dialect: %s"
                 % str(smb_response['dialect_revision']))
        self.dialect = smb_response['dialect_revision'].get_value()
        self.max_transact_size = smb_response['max_transact_size'].get_value()
        self.max_read_size = smb_response['max_read_size'].get_value()
        self.max_write_size = smb_response['max_write_size'].get_value()
        self.server_guid = smb_response['server_guid'].get_value()
        self.gss_negotiate_token = smb_response['buffer'].get_value()

        if not self.require_signing and \
                smb_response['security_mode'].has_flag(
                    SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED):
            self.require_signing = True
        log.info("Connection require signing: %s" % self.require_signing)
        capabilities = smb_response['capabilities']

        # SMB 2.1
        if self.dialect >= Dialects.SMB_2_1_0:
            self.supports_file_leasing = \
                capabilities.has_flag(Capabilities.SMB2_GLOBAL_CAP_LEASING)
            self.supports_multi_credit = \
                capabilities.has_flag(Capabilities.SMB2_GLOBAL_CAP_LARGE_MTU)

        # SMB 3.x
        if self.dialect >= Dialects.SMB_3_0_0:
            self.supports_directory_leasing = capabilities.has_flag(
                Capabilities.SMB2_GLOBAL_CAP_DIRECTORY_LEASING)
            self.supports_multi_channel = capabilities.has_flag(
                Capabilities.SMB2_GLOBAL_CAP_MULTI_CHANNEL)

            # TODO: SMB2_GLOBAL_CAP_PERSISTENT_HANDLES
            self.supports_persistent_handles = False
            self.supports_encryption = capabilities.has_flag(
                Capabilities.SMB2_GLOBAL_CAP_ENCRYPTION) \
                and self.dialect < Dialects.SMB_3_1_1
            self.server_capabilities = capabilities
            self.server_security_mode = \
                smb_response['security_mode'].get_value()

            # TODO: Check/add server to server_list in Client Page 203

        # SMB 3.1
        if self.dialect >= Dialects.SMB_3_1_1:
            for context in smb_response['negotiate_context_list']:
                if context['context_type'].get_value() == \
                        NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES:
                    cipher_id = context['data']['ciphers'][0]
                    self.cipher_id = Ciphers.get_cipher(cipher_id)
                    self.supports_encryption = self.cipher_id != 0
                else:
                    hash_id = context['data']['hash_algorithms'][0]
                    self.preauth_integrity_hash_id = \
                        HashAlgorithms.get_algorithm(hash_id)

    def disconnect(self, close=True):
        """
        Closes the connection as well as logs off any of the
        Disconnects the TCP connection and shuts down the socket listener
        running in a thread.

        :param close: Will close all sessions in the connection as well as the
            tree connections of each session.
        """
        if close:
            for session in list(self.session_table.values()):
                session.disconnect(True)

        log.info("Disconnecting transport connection")
        self.transport.close()

    def send(self, message, sid=None, tid=None, credit_request=None, message_id=None, async_id=None):
        """
        Will send a message to the server that is passed in. The final unencrypted header is returned to the function
        that called this.

        :param message: An SMB message structure to send.
        :param sid: A session_id that the message is sent for.
        :param tid: A tree_id object that the message is sent for.
        :param credit_request: Specifies extra credits to be requested with the SMB header.
        :param message_id: The message_id for the header, only useful for a cancel request.
        :param async_id: The async_id for the header, only useful for a cancel request.
        :return: Request of the message that was sent.
        """
        return self._send([message], session_id=sid, tree_id=tid, message_id=message_id, credit_request=credit_request,
                          async_id=async_id)[0]

    def send_compound(self, messages, sid, tid, related=False):
        """
        Sends multiple messages within 1 TCP request, will fail if the size of the total length exceeds the maximum of
        the transport max.

        :param messages: A list of messages to send to the server.
        :param sid: The session_id that the request is sent for.
        :param tid: A tree_id object that the message is sent for.
        :param related: Whether each message is related to each other, sets the Session, Tree, and File Id to the same
            value as the first message.
        :return: List<Request> for each request that was sent, each entry in the list is in the same order of the
            message list that was passed in.
        """
        return self._send(messages, session_id=sid, tree_id=tid, related=related)

    @_worker_running
    def receive(self, request, wait=True, timeout=None, resolve_symlinks=True):
        """
        Polls the message buffer of the TCP connection and waits until a valid
        message is received based on the message_id passed in.

        :param request: The Request object to wait get the response for
        :param wait: Wait for the final response in the case of a STATUS_PENDING response, the pending response is
            returned in the case of wait=False
        :param timeout: Set a timeout used while waiting for the final response from the server.
        :param resolve_symlinks: Set to automatically resolve symlinks in the path when opening a file or directory.
        :return: SMB2HeaderResponse of the received message
        """
        start_time = time.time()
        while True:
            iter_timeout = int(max(timeout - (time.time() - start_time), 1)) if timeout is not None else None
            if not request.response_event.wait(timeout=iter_timeout):
                raise SMBException("Connection timeout of %d seconds exceeded while waiting for a message id %s "
                                   "response from the server" % (timeout, request.message['message_id'].get_value()))

            # Use a lock on the request so that in the case of a pending response we have exclusive lock on the event
            # flag and can clear it without the future pending response taking it over before we first clear the flag.
            with request.response_event_lock:
                self._check_worker_running()  # The worker may have failed while waiting for the response, check again

                response = request.response
                status = response['status'].get_value()
                if status == NtStatus.STATUS_PENDING and wait:
                    # Received a pending message, clear the response_event flag and wait again.
                    request.response_event.clear()
                    continue
                elif status == NtStatus.STATUS_STOPPED_ON_SYMLINK and resolve_symlinks:
                    # Received when we do an Open on a path that contains a symlink. Need to capture all related
                    # requests and resend the Open + others with the redirected path. First we need to resolve the
                    # symlink path. This will fail if the symlink is pointing to a location that is not in the same
                    # tree/share as the original request.

                    # First wait for the other remaining requests to be processed. Their status will also fail and we
                    # need to make sure we update the old request with the new one properly.
                    related_requests = [self.outstanding_requests[i] for i in request.related_ids]
                    [r.response_event.wait() for r in related_requests]

                    # Now create a new request with the new path the symlink points to.
                    session = self.session_table[request.session_id]
                    tree = session.tree_connect_table[request.message['tree_id'].get_value()]

                    old_create = request.get_message_data()
                    tree_share_name = tree.share_name + u'\\'
                    original_path = tree_share_name + to_text(old_create['buffer_path'].get_value(),
                                                              encoding='utf-16-le')

                    exp = SMBResponseException(response, status)
                    reparse_buffer = next((e for e in exp.error_details
                                           if isinstance(e, SMB2SymbolicLinkErrorResponse)))
                    new_path = reparse_buffer.resolve_path(original_path)[len(tree_share_name):]

                    new_open = Open(tree, new_path)
                    create_req = new_open.create(
                        old_create['impersonation_level'].get_value(),
                        old_create['desired_access'].get_value(),
                        old_create['file_attributes'].get_value(),
                        old_create['share_access'].get_value(),
                        old_create['create_disposition'].get_value(),
                        old_create['create_options'].get_value(),
                        create_contexts=old_create['buffer_contexts'].get_value(),
                        send=False
                    )[0]

                    # Now add all the related requests (if any) to send as a compound request.
                    new_msgs = [create_req] + [r.get_message_data() for r in related_requests]
                    new_requests = self.send_compound(new_msgs, session.session_id, tree.tree_connect_id, related=True)

                    # Verify that the first request was successful before updating the related requests with the new
                    # info.
                    error = None
                    try:
                        new_response = self.receive(new_requests[0], wait=wait, timeout=timeout, resolve_symlinks=True)
                    except SMBResponseException as exc:
                        # We need to make sure we fix up the remaining responses before throwing this.
                        error = exc
                    [r.response_event.wait() for r in new_requests]

                    # Update the old requests with the new response information
                    for i, old_request in enumerate([request] + related_requests):
                        del self.outstanding_requests[old_request.message['message_id'].get_value()]
                        old_request.update_request(new_requests[i])

                    if error:
                        raise error

                    return new_response
                elif status not in [NtStatus.STATUS_SUCCESS, NtStatus.STATUS_PENDING]:
                    raise SMBResponseException(response, status)
                else:
                    break

        # now we have a retrieval request for the response, we can delete
        # the request from the outstanding requests
        message_id = request.message['message_id'].get_value()
        del self.outstanding_requests[message_id]

        return response

    def echo(self, sid=0, timeout=60, credit_request=1):
        """
        Sends an SMB2 Echo request to the server. This can be used to request
        more credits from the server with the credit_request param.

        On a Samba server, the sid can be 0 but for a Windows SMB Server, the
        sid of an authenticated session must be passed into this function or
        else the socket will close.

        :param sid: When talking to a Windows host this must be populated with
            a valid session_id from a negotiated session
        :param timeout: The timeout in seconds to wait for the Echo Response
        :param credit_request: The number of credits to request
        :return: the credits that were granted by the server
        """
        log.info("Sending Echo request with a timeout of %d and credit "
                 "request of %d" % (timeout, credit_request))

        echo_msg = SMB2Echo()
        log.debug(echo_msg)
        req = self.send(echo_msg, sid=sid, credit_request=credit_request)

        log.info("Receiving Echo response")
        response = self.receive(req, timeout=timeout)
        log.info("Credits granted from the server echo response: %d"
                 % response['credit_response'].get_value())
        echo_resp = SMB2Echo()
        echo_resp.unpack(response['data'].get_value())
        log.debug(echo_resp)

        return response['credit_response'].get_value()

    def verify_signature(self, header, session_id, force=False):
        """
        Verifies the SMB2 Header request/response signature.

        :param header: The SMB2Header that will have its signature verified against the signing key specified.
        :param session_id: The Session Id to denote what session security verifies the message.
        :param force: Force verification of the header even if it does not match the criteria required in normal
            scenarios.
        """
        message_id = header['message_id'].get_value()
        flags = header['flags']
        status = header['status'].get_value()
        command = header['command'].get_value()

        if not force and (message_id == 0xFFFFFFFFFFFFFFFF or
                          not flags.has_flag(Smb2Flags.SMB2_FLAGS_SIGNED) or
                          status == NtStatus.STATUS_PENDING or
                          command == Commands.SMB2_SESSION_SETUP):
            return

        session = self.session_table.get(session_id, None)
        if session is None:
            raise SMBException("Failed to find session %s for message verification" % session_id)

        expected = self._generate_signature(header.pack(), session.signing_key)
        actual = header['signature'].get_value()
        if actual != expected:
            raise SMBException("Server message signature could not be verified: %s != %s"
                               % (to_native(binascii.hexlify(actual)), to_native(binascii.hexlify(expected))))

    def _check_worker_running(self):
        """ Checks that the message worker thread is still running and raises it's exception if it has failed. """
        if self._t_exc is not None:
            self.disconnect(False)
            raise self._t_exc

    @_worker_running
    def _send(self, messages, session_id=None, tree_id=None, message_id=None, credit_request=None, related=False,
              async_id=None):
        send_data = b""
        requests = []
        session = self.session_table.get(session_id, None)
        tree = None
        if tree_id and session:
            if tree_id not in session.tree_connect_table:
                raise SMBException("Cannot find Tree with the ID %d in the session tree table" % tree_id)
            tree = session.tree_connect_table[tree_id]

        total_requests = len(messages)
        for i, message in enumerate(messages):
            if i == total_requests - 1:
                next_command = 0
                padding = b""
            else:
                # each compound message must start at the 8-byte boundary
                msg_length = 64 + len(message)
                mod = msg_length % 8
                padding_length = 8 - mod if mod > 0 else 0
                next_command = msg_length + padding_length
                padding = b"\x00" * padding_length

            # When running with multiple threads we need to ensure that getting the message id and adjusting the
            # sequence windows is done in a thread safe manner so we use a lock to ensure only 1 thread accesses the
            # sequence window at a time.
            with self.sequence_lock:
                sequence_window_low = self.sequence_window['low']
                sequence_window_high = self.sequence_window['high']
                credit_charge = self._calculate_credit_charge(message)
                credits_available = sequence_window_high - sequence_window_low
                if credit_charge > credits_available:
                    raise SMBException("Request requires %d credits but only %d credits are available"
                                       % (credit_charge, credits_available))

                current_id = message_id or sequence_window_low
                if message.COMMAND != Commands.SMB2_CANCEL:
                    self.sequence_window['low'] += credit_charge if credit_charge > 0 else 1

            if async_id is None:
                header = SMB2HeaderRequest()
                header['tree_id'] = tree_id or 0
            else:
                header = SMB2HeaderAsync()
                header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_ASYNC_COMMAND)
                header['async_id'] = async_id

            header['credit_charge'] = credit_charge
            header['command'] = message.COMMAND
            header['credit_request'] = credit_request if credit_request else credit_charge
            header['message_id'] = current_id
            header['session_id'] = session_id or 0
            header['data'] = message.pack()
            header['next_command'] = next_command

            if i != 0 and related:
                header['session_id'] = b"\xff" * 8
                header['tree_id'] = b"\xff" * 4
                header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_RELATED_OPERATIONS)

            if session and session.signing_required and session.signing_key:
                header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_SIGNED)
                b_header = header.pack() + padding
                signature = self._generate_signature(b_header, session.signing_key)

                # To save on unpacking and re-packing, manually adjust the signature and update the request object for
                # back-referencing.
                b_header = b_header[:48] + signature + b_header[64:]
                header['signature'] = signature
            else:
                b_header = header.pack() + padding

            send_data += b_header

            if message.COMMAND == Commands.SMB2_CANCEL:
                request = self.outstanding_requests[header['message_id'].get_value()]
            else:
                request = Request(header, type(message), self, session_id=session_id)
                self.outstanding_requests[header['message_id'].get_value()] = request

            requests.append(request)

        if related:
            requests[0].related_ids = [r.message['message_id'].get_value() for r in requests][1:]

        if session and session.encrypt_data or tree and tree.encrypt_data:
            send_data = self._encrypt(send_data, session)

        self.transport.send(send_data)
        return requests

    def _process_message_thread(self, msg_queue):
        while True:
            b_msg = msg_queue.get()

            # The socket will put None in the queue if it is closed, signalling the end of the connection.
            if b_msg is None:
                return

            try:
                is_encrypted = b_msg[:4] == b"\xfdSMB"
                if is_encrypted:
                    msg = SMB2TransformHeader()
                    msg.unpack(b_msg)
                    b_msg = self._decrypt(msg)

                next_command = -1
                while next_command != 0:
                    next_command = struct.unpack("<L", b_msg[20:24])[0]
                    header_length = next_command if next_command != 0 else len(b_msg)
                    b_header = b_msg[:header_length]
                    b_msg = b_msg[header_length:]

                    header = SMB2HeaderResponse()
                    header.unpack(b_header)

                    message_id = header['message_id'].get_value()
                    request = self.outstanding_requests[message_id]

                    # Typically you want to get the Session Id from the first message in a compound request but that is
                    # unreliable for async responses. Instead get the Session Id from the original request object if
                    # the Session Id is 0xFFFFFFFFFFFFFFFF.
                    # https://social.msdn.microsoft.com/Forums/en-US/a580f7bc-6746-4876-83db-6ac209b202c4/mssmb2-change-notify-response-sessionid?forum=os_fileservices
                    session_id = header['session_id'].get_value()
                    if session_id == 0xFFFFFFFFFFFFFFFF:
                        session_id = request.session_id

                    # No need to waste CPU cycles to verify the signature if we already decrypted the header.
                    if not is_encrypted:
                        self.verify_signature(header, session_id)

                    credit_response = header['credit_response'].get_value()
                    with self.sequence_lock:
                        self.sequence_window['high'] += credit_response if credit_response > 0 else 1

                    with request.response_event_lock:
                        if header['flags'].has_flag(Smb2Flags.SMB2_FLAGS_ASYNC_COMMAND):
                            request.async_id = b_header[32:40]

                        request.response = header
                        request.response_event.set()
            except Exception as exc:
                # The exception is raised in _check_worker_running by the main thread when send/receive is called next.
                self._t_exc = exc

                # Make sure we fire all the request events to ensure the main thread isn't waiting on a receive.
                for request in self.outstanding_requests.values():
                    request.response_event.set()

                # While a caller of send/receive could theoretically catch this exception, we consider any failures
                # here a fatal errors and the connection should be closed so we exit the worker thread.
                self.disconnect(False)
                return

    def _generate_signature(self, b_header, signing_key):
        b_header = b_header[:48] + (b"\x00" * 16) + b_header[64:]

        if self.dialect >= Dialects.SMB_3_0_0:
            c = cmac.CMAC(algorithms.AES(signing_key), backend=default_backend())
            c.update(b_header)
            signature = c.finalize()
        else:
            hmac_algo = hmac.new(signing_key, msg=b_header, digestmod=hashlib.sha256)
            signature = hmac_algo.digest()[:16]

        return signature

    def _encrypt(self, b_data, session):
        header = SMB2TransformHeader()
        header['original_message_size'] = len(b_data)
        header['session_id'] = session.session_id

        encryption_key = session.encryption_key
        if self.dialect >= Dialects.SMB_3_1_1:
            cipher = self.cipher_id
        else:
            cipher = Ciphers.get_cipher(Ciphers.AES_128_CCM)
        if cipher == aead.AESGCM:
            nonce = os.urandom(12)
            header['nonce'] = nonce + (b"\x00" * 4)
        else:
            nonce = os.urandom(11)
            header['nonce'] = nonce + (b"\x00" * 5)

        cipher_text = cipher(encryption_key).encrypt(nonce, b_data, header.pack()[20:])
        signature = cipher_text[-16:]
        enc_message = cipher_text[:-16]

        header['signature'] = signature
        header['data'] = enc_message

        return header

    def _decrypt(self, message):
        if message['flags'].get_value() != 0x0001:
            error_msg = "Expecting flag of 0x0001 but got %s in the SMB Transform Header Response" \
                        % format(message['flags'].get_value(), 'x')
            raise SMBException(error_msg)

        session_id = message['session_id'].get_value()
        session = self.session_table.get(session_id, None)
        if session is None:
            error_msg = "Failed to find valid session %s for message decryption" % session_id
            raise SMBException(error_msg)

        if self.dialect >= Dialects.SMB_3_1_1:
            cipher = self.cipher_id
        else:
            cipher = Ciphers.get_cipher(Ciphers.AES_128_CCM)

        nonce_length = 12 if cipher == aead.AESGCM else 11
        nonce = message['nonce'].get_value()[:nonce_length]

        signature = message['signature'].get_value()
        enc_message = message['data'].get_value() + signature

        c = cipher(session.decryption_key)
        dec_message = c.decrypt(nonce, enc_message, message.pack()[20:52])
        return dec_message

    def _send_smb2_negotiate(self, dialect, timeout):
        self.salt = os.urandom(32)

        if dialect is None:
            neg_req = SMB3NegotiateRequest()
            self.negotiated_dialects = [
                Dialects.SMB_2_0_2,
                Dialects.SMB_2_1_0,
                Dialects.SMB_3_0_0,
                Dialects.SMB_3_0_2,
                Dialects.SMB_3_1_1
            ]
            highest_dialect = Dialects.SMB_3_1_1
        else:
            if dialect >= Dialects.SMB_3_1_1:
                neg_req = SMB3NegotiateRequest()
            else:
                neg_req = SMB2NegotiateRequest()
            self.negotiated_dialects = [
                dialect
            ]
            highest_dialect = dialect
        neg_req['dialects'] = self.negotiated_dialects
        log.info("Negotiating with SMB2 protocol with highest client dialect "
                 "of: %s" % [dialect for dialect, v in vars(Dialects).items()
                             if v == highest_dialect][0])

        neg_req['security_mode'] = self.client_security_mode

        if highest_dialect >= Dialects.SMB_2_1_0:
            log.debug("Adding client guid %s to negotiate request"
                      % self.client_guid)
            neg_req['client_guid'] = self.client_guid

        if highest_dialect >= Dialects.SMB_3_0_0:
            log.debug("Adding client capabilities %d to negotiate request"
                      % self.client_capabilities)
            neg_req['capabilities'] = self.client_capabilities

        if highest_dialect >= Dialects.SMB_3_1_1:
            int_cap = SMB2NegotiateContextRequest()
            int_cap['context_type'] = \
                NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES
            int_cap['data'] = SMB2PreauthIntegrityCapabilities()
            int_cap['data']['hash_algorithms'] = [
                HashAlgorithms.SHA_512
            ]
            int_cap['data']['salt'] = self.salt
            log.debug("Adding preauth integrity capabilities of hash SHA512 "
                      "and salt %s to negotiate request" % self.salt)

            enc_cap = SMB2NegotiateContextRequest()
            enc_cap['context_type'] = \
                NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES
            enc_cap['data'] = SMB2EncryptionCapabilities()
            supported_ciphers = Ciphers.get_supported_ciphers()
            enc_cap['data']['ciphers'] = supported_ciphers
            # remove extra padding for last list entry
            enc_cap['padding'].size = 0
            enc_cap['padding'] = b""
            log.debug("Adding encryption capabilities of AES128 GCM and "
                      "AES128 CCM to negotiate request")

            neg_req['negotiate_context_list'] = [
                int_cap,
                enc_cap
            ]

        log.info("Sending SMB2 Negotiate message")
        log.debug(neg_req)
        request = self.send(neg_req)
        self.preauth_integrity_hash_value.append(request.message)

        response = self.receive(request, timeout=timeout)
        log.info("Receiving SMB2 Negotiate response")
        log.debug(response)
        self.preauth_integrity_hash_value.append(response)

        smb_response = SMB2NegotiateResponse()
        smb_response.unpack(response['data'].get_value())

        return smb_response

    def _calculate_credit_charge(self, message):
        """
        Calculates the credit charge for a request based on the command. If
        connection.supports_multi_credit is not True then the credit charge
        isn't valid so it returns 0.

        The credit charge is the number of credits that are required for
        sending/receiving data over 64 kilobytes, in the existing messages only
        the Read, Write, Query Directory or IOCTL commands will end in this
        scenario and each require their own calculation to get the proper
        value. The generic formula for calculating the credit charge is

        https://msdn.microsoft.com/en-us/library/dn529312.aspx
        (max(SendPayloadSize, Expected ResponsePayloadSize) - 1) / 65536 + 1

        :param message: The message being sent
        :return: The credit charge to set on the header
        """
        credit_size = MAX_PAYLOAD_SIZE

        if (not self.supports_multi_credit) or (message.COMMAND == Commands.SMB2_CANCEL):
            credit_charge = 0
        elif message.COMMAND == Commands.SMB2_READ:
            max_size = message['length'].get_value() + \
                       message['read_channel_info_length'].get_value() - 1
            credit_charge = math.ceil(max_size / credit_size)
        elif message.COMMAND == Commands.SMB2_WRITE:
            max_size = message['length'].get_value() + \
                       message['write_channel_info_length'].get_value() - 1
            credit_charge = math.ceil(max_size / credit_size)
        elif message.COMMAND == Commands.SMB2_IOCTL:
            max_in_size = len(message['buffer'])
            max_out_size = message['max_output_response'].get_value()
            max_size = max(max_in_size, max_out_size) - 1
            credit_charge = math.ceil(max_size / credit_size)
        elif message.COMMAND == Commands.SMB2_QUERY_DIRECTORY:
            max_in_size = len(message['buffer'])
            max_out_size = message['output_buffer_length'].get_value()
            max_size = max(max_in_size, max_out_size) - 1
            credit_charge = math.ceil(max_size / credit_size)
        else:
            credit_charge = 1

        # python 2 returns a float where we need an integer
        return int(credit_charge)


class Request(object):

    def __init__(self, message, message_type, connection, session_id=None):
        """
        [MS-SMB2] v53.0 2017-09-15

        3.2.1.7 Per Pending Request
        For each request that was sent to the server and is await a response
        :param message: The message to be sent in the request
        :param message_type: The type of message that is set in the header's data field.
        :param connection: The Connection the request was sent under.
        :param session_id: The Session Id the request was for.
        """
        self.async_id = None
        self.message = message
        self.timestamp = datetime.now()
        self.cancelled = False

        # Used to contain the corresponding response from the server as the receiving in done in a separate thread.
        self.response = None

        # Used by the recv processing thread to say the response has been received and is ready for consumption.
        self.response_event = threading.Event()

        # Used to lock the request when the main thread is processing the PENDING result in case the background thread
        # receives the final result and fires the event before main clears it.
        self.response_event_lock = threading.Lock()

        # Stores the message_ids of related messages that are sent in a compound request. This is only set on the 1st
        # message in the request. Used when STATUS_STOPPED_ON_SYMLINK is set and we need to send the whole compound
        # request again with the new path.
        self.related_ids = []

        # Cannot rely on the message values as it could be a related compound msg which does not set these values.
        self.session_id = session_id

        self._connection = connection
        self._message_type = message_type  # Used to rehydrate the message data in case it's needed again.

    def cancel(self):
        if self.cancelled is True:
            return

        message_id = self.message['message_id'].get_value()
        log.info("Cancelling message %s" % message_id)
        self._connection.send(SMB2CancelRequest(), sid=self.session_id, credit_request=0, message_id=message_id,
                              async_id=self.async_id)
        self.cancelled = True

    def get_message_data(self):
        message_obj = self._message_type()
        message_obj.unpack(self.message['data'].get_value())
        return message_obj

    def update_request(self, new_request):
        self.async_id = new_request.async_id
        self.message = new_request.message
        self.timestamp = new_request.timestamp
        self.response = new_request.response
        self.response_event = new_request.response_event
        self.response_event_lock = new_request.response_event_lock
        self.related_ids = new_request.related_ids