#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#  mayhem/exploit/windows.py
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are
#  met:
#
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above
#    copyright notice, this list of conditions and the following disclaimer
#    in the documentation and/or other materials provided with the
#    distribution.
#  * Neither the name of the project nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
#  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
#  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
#  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
#  OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
#  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
#  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
#  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
#  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import code
import ctypes
import os
import platform
import sqlite3

import mayhem.datatypes.windows as wintypes
import mayhem.utilities
import mayhem.windll.kernel32 as m_k32
import mayhem.windll.ntdll as m_ntdll
from mayhem.proc.windows import flags
from mayhem.proc.windows import process_is_wow64
from mayhem.proc.windows import WindowsProcess

# The all-bits-set handle value at the bit width reported for the running
# interpreter; CreateFileW results are compared against it to detect failure.
if platform.architecture()[0] == '64bit':
	INVALID_HANDLE = 0xffffffffffffffff
else:
	INVALID_HANDLE = 0xffffffff

class Driver(object):
	"""
	An object for conveniently opening a handle to a device driver and then
	communicating with it using NtDeviceIoControlFile.
	"""
	def __init__(self, handle):
		"""
		:param handle: An already opened handle to the device.
		"""
		self.handle = handle

	def close(self):
		"""
		Close the device handle by passing it to CloseHandle.
		"""
		m_k32.CloseHandle(self.handle)

	@classmethod
	def from_create_file(cls, file_name):
		"""
		Open the device with CreateFileW and wrap the resulting handle.

		:param str file_name: The name of the device to open.
		:return: A new instance which wraps the opened handle.
		:raises: The exception built by :py:func:`ctypes.WinError` when
			CreateFileW returns the invalid handle value.
		"""
		handle = m_k32.CreateFileW(
			file_name,      # lpFileName [in]
			3,              # dwDesiredAccess [in] (FILE_READ_DATA | FILE_WRITE_DATA)
			0,              # dwShareMode [in]
			None,           # lpSecurityAttributes [in-opt]
			3,              # dwCreationDisposition [in] (OPEN_EXISTING)
			0,              # dwFlagsAndAttributes [in]
			0               # hTemplateFile [in-opt]
		)
		if handle == INVALID_HANDLE:
			raise ctypes.WinError()
		return cls(handle)

	def io_control_file(self, io_control_code, input_buffer=None, output_buffer_length=None):
		"""
		Send an I/O control code to the device using NtDeviceIoControlFile.

		:param int io_control_code: The control code to send.
		:param input_buffer: The optional input data, its size is determined with :py:func:`len`.
		:param int output_buffer_length: The size in bytes of the output buffer
			to allocate, or None to not supply an output buffer at all.
		:return: The value returned by NtDeviceIoControlFile and the contents
			of the output buffer (empty when no output buffer was requested).
		:rtype: tuple
		"""
		io_status_block = wintypes.IO_STATUS_BLOCK()
		input_buffer_length = (0 if input_buffer is None else len(input_buffer))
		if output_buffer_length is None:
			output_buffer = None
			output_buffer_length = 0
		else:
			output_buffer = (ctypes.c_byte * output_buffer_length)()
		value = m_ntdll.NtDeviceIoControlFile(
			self.handle,                         # FileHandle [in]
			None,                                # Event [in-opt]
			None,                                # ApcRoutine [in-opt]
			None,                                # ApcContext [in-opt]
			ctypes.byref(io_status_block),       # IoStatusBlock [out]
			io_control_code,                     # IoControlCode [in]
			input_buffer,                        # InputBuffer [in-opt]
			input_buffer_length,                 # InputBufferLength [in]
			output_buffer,                       # OutputBuffer [out-opt]
			output_buffer_length                 # OutputBufferLength [in]
		)
		if output_buffer is None:
			# no output buffer was supplied so there is no ctypes array to
			# convert, report the output as empty instead of converting None
			return (value, b'')
		return (value, mayhem.utilities.ctarray_to_bytes(output_buffer))

class WindowsSyscallFunction(object):
	"""
	A callable which pairs a syscall number with the function used to
	dispatch it, so that only the syscall's own arguments need to be supplied
	when it is invoked.
	"""
	__slots__ = ('__function', 'name', 'number')
	def __init__(self, function, name, number):
		"""
		:param function: The callable which receives the syscall number followed by the arguments.
		:param str name: The name of the syscall, used by the representation.
		:param int number: The number which is passed as the first argument to *function*.
		"""
		self.name = name
		self.number = number
		self.__function = function

	def __repr__(self):
		template = "<{0} (0x{1:04x}) >"
		return template.format(self.name, self.number)

	def __call__(self, *args):
		# the bound number is always placed in front of the caller's arguments
		dispatcher = self.__function
		return dispatcher(self.number, *args)

class SyscallPrototype(ctypes._CFuncPtr):
	"""
	A ctypes function pointer type declared directly on top of the private
	``ctypes._CFuncPtr`` base instead of through :py:func:`ctypes.CFUNCTYPE`.
	It declares an empty argument type list, a 32-bit unsigned return type and
	the cdecl function flag.
	"""
	# NOTE(review): nothing in the visible part of this module references this
	# class, confirm whether it is still used before relying on it
	_argtypes_ = []
	_restype_ = ctypes.c_uint32
	_flags_ = ctypes._FUNCFLAG_CDECL

class WindowsAsmFunctionBase(object):
	"""
	Base class for callables which are backed by a stub of raw machine code.
	On initialization the stub is written into freshly allocated memory which
	is then wrapped with the ctypes prototype in
	:py:attr:`._asm_function_prototype`. Subclasses provide the stub bytes and
	the prototype matching it.
	"""
	_asm_function_stub = None
	_asm_function_prototype = ctypes.CFUNCTYPE(ctypes.c_void_p)
	def __init__(self, stub=None):
		"""
		:param bytes stub: The machine code to use, when not specified the
			class's :py:attr:`._asm_function_stub` is used instead.
		:raises RuntimeError: When Python is running as a WOW64 process.
		"""
		if process_is_wow64():
			raise RuntimeError('python running in WOW64 is not supported')
		machine_code = stub if stub else self._asm_function_stub
		# NOTE(review): pid=-1 presumably refers to the current process, confirm against WindowsProcess
		current_process = WindowsProcess(pid=-1)
		region_size = mayhem.utilities.align_up(len(machine_code), 1024)
		self.address = current_process.allocate(size=region_size)
		current_process.write_memory(self.address, machine_code)
		# once the stub has been written the region is switched to execute / read
		current_process.protect(self.address, size=region_size, permissions='PAGE_EXECUTE_READ')
		self._asm_function = self._asm_function_prototype(self.address)

	def __repr__(self):
		template = "<{0} address=0x{1:08x} >"
		return template.format(type(self).__name__, self.address)

	def __call__(self, *args):
		# hand the arguments to the ctypes wrapper around the stub unmodified
		stub_function = self._asm_function
		return stub_function(*args)

class WindowsX64FlushReload(WindowsAsmFunctionBase):
	"""
	An x64 assembly stub which takes the time stamp counter before and after
	loading from an address, flushes that address from the cache and returns
	the difference between the two counter readings.
	"""
	# one 64-bit argument (the address to probe) and a 64-bit result
	_asm_function_prototype = ctypes.CFUNCTYPE(ctypes.c_uint64, ctypes.c_uint64)
	# this is the "plunger" variant that returns the elapsed time
	# see: https://github.com/DanGe42/flush-reload/releases/download/cis-700-submission/paper.pdf
	_asm_function_stub = (
		# prologue
		b'\x55'                  # push    rbp
		b'\x48\x89\xe5'          # mov     rbp, rsp
		b'\x53'                  # push    rbx
		# take the starting counter reading
		b'\x0f\xae\xf0'          # mfence
		b'\x0f\xae\xe8'          # lfence
		b'\x0f\x31'              # rdtsc
		b'\x0f\xae\xe8'          # lfence
		b'\x48\x89\xc3'          # mov     rbx, rax
		# load from the probed address and take the ending counter reading
		b'\x48\x8b\x01'          # mov     rax, QWORD PTR [rcx]
		b'\x0f\xae\xe8'          # lfence
		b'\x0f\x31'              # rdtsc
		b'\x48\x29\xd8'          # sub     rax, rbx
		# evict the probed address so the next call measures a fresh load
		b'\x0f\xae\x39'          # clflush BYTE PTR [rcx]
		# epilogue
		b'\x5b'                  # pop     rbx
		b'\x5d'                  # pop     rbp
		b'\xc3'                  # ret
	)
	def __call__(self, address):
		"""
		:param int address: The memory address to load from and then flush.
		:return: The value returned by the assembly stub.
		"""
		parent = super(WindowsX64FlushReload, self)
		return parent.__call__(address)

class WindowsSyscallBase(WindowsAsmFunctionBase):
	"""
	Base class for the architecture specific syscall stubs. Syscall names are
	resolved to numbers using a SQLite database and the resulting number is
	passed as the first argument to the assembly stub. Subclasses provide the
	``_syscall_arch`` and ``_asm_function_stub`` attributes.
	"""
	# maps the value of platform.platform() to the OS name used in the database
	__name_map = {
		'Windows-7-6.1.7601-SP1': '7 SP1'
	}
	_syscall_db_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'data', 'windows_syscalls.db'))
	_asm_function_prototype = ctypes.CFUNCTYPE(ctypes.c_uint32)
	def __init__(self, os_name=None):
		"""
		:param str os_name: The name of the Windows operating system for which to resolve syscall names to numbers.
		:raises ValueError: When the database has no syscalls for *os_name* and this architecture.
		"""
		super(WindowsSyscallBase, self).__init__()
		self._syscall_db_con = None
		if os_name is None:
			os_name = self.__name_map.get(platform.platform())
		if os_name is not None:
			connection = sqlite3.connect(self._syscall_db_path, check_same_thread=False)
			row = connection.execute(
				'SELECT COUNT(os_name) FROM syscalls WHERE arch = ? AND os_name = ?',
				(self._syscall_arch, os_name,)
			).fetchone()
			if not row[0]:
				# do not leave the connection open when initialization fails
				connection.close()
				raise ValueError('no syscall numbers available in the database for ' + os_name)
			self._syscall_db_con = connection
		self.os_name = os_name

	def __call__(self, syscall, *args):
		"""
		Execute a syscall which is identified by either its name or number.

		:param syscall: The name (str) or number (int) of the syscall to execute.
		:param args: The arguments to pass to the syscall.
		:return: The value returned by the assembly stub.
		:raises LookupError: When *syscall* is a name which can not be resolved.
		:raises TypeError: When *syscall* is neither a str nor an int.
		"""
		if isinstance(syscall, str):
			syscall_number = self.lookup(syscall)
			if syscall_number is None:
				raise LookupError('failed to look up the specified syscall')
		elif isinstance(syscall, int):
			syscall_number = syscall
		else:
			raise TypeError('argument 1 must be str or int')
		return self._asm_function(syscall_number, *args)

	def __getattr__(self, syscall_name):
		"""
		Resolve attributes which are not otherwise defined as syscall names.

		:param str syscall_name: The name of the syscall to resolve.
		:return: A callable bound to the resolved syscall number.
		:rtype: :py:class:`.WindowsSyscallFunction`
		:raises AttributeError: When the name is private or no database is connected.
		:raises LookupError: When the name is not in the database.
		"""
		if syscall_name.startswith('_'):
			# Private and dunder names are never syscalls. Looking them up
			# would recurse endlessly while _syscall_db_con is not yet set and
			# would answer protocol probes (hasattr, copy, pickle) with a
			# LookupError instead of the AttributeError they expect.
			raise AttributeError("'{0}' object has no attribute '{1}'".format(self.__class__.__name__, syscall_name))
		syscall_number = self.lookup(syscall_name)
		if syscall_number is None:
			# LookupError is kept for unknown syscall names so that existing
			# callers which catch it continue to work
			raise LookupError('failed to look up the specified syscall')
		return WindowsSyscallFunction(self._asm_function, syscall_name, syscall_number)

	def __repr__(self):
		return "<{0} address=0x{1:08x} os_name='{2}' >".format(self.__class__.__name__, self.address, self.os_name)

	def lookup(self, syscall_name):
		"""
		Lookup the number for a syscall by its name.

		:param str syscall_name: The name of the syscall to lookup.
		:return: The syscall's number specific to the environment or None when the name is not in the database.
		:rtype: int
		:raises AttributeError: When no syscall database is connected.
		"""
		if not self._syscall_db_con:
			raise AttributeError('no syscall database is connected')
		cur = self._syscall_db_con.execute(
			'SELECT number FROM syscalls WHERE name = ? AND arch = ? AND os_name = ?',
			(syscall_name, self._syscall_arch, self.os_name)
		)
		result = cur.fetchone()
		if result:
			result = result[0]
		return result

class WindowsX64Syscall(WindowsSyscallBase):
	"""
	The x64 syscall stub. The first argument is moved into rax as the syscall
	number and the remaining arguments are each shifted down by one position
	before the syscall instruction is executed.
	"""
	_syscall_arch = 'x64'
	_asm_function_stub = (
		# prologue, the argument registers are saved so they can be restored
		b'\x55'                  # push  rbp
		b'\x48\x89\xe5'          # mov   rbp, rsp
		b'\x41\x51'              # push  r9
		b'\x41\x50'              # push  r8
		b'\x52'                  # push  rdx
		b'\x51'                  # push  rcx
		# copy four of the caller's stack arguments and reserve 0x28 more bytes
		b'\xff\x75\x50'          # push  QWORD PTR [rbp+0x50]
		b'\xff\x75\x48'          # push  QWORD PTR [rbp+0x48]
		b'\xff\x75\x40'          # push  QWORD PTR [rbp+0x40]
		b'\xff\x75\x38'          # push  QWORD PTR [rbp+0x38]
		b'\x48\x83\xec\x28'      # sub   rsp, 0x28
		# the syscall number goes to rax, every other argument moves down one
		b'\x48\x89\xc8'          # mov   rax, rcx
		b'\x48\x89\xd1'          # mov   rcx, rdx
		b'\x4c\x89\xc2'          # mov   rdx, r8
		b'\x4d\x89\xc8'          # mov   r8, r9
		b'\x4c\x8b\x4d\x30'      # mov   r9, QWORD PTR [rbp+0x30]
		b'\x49\x89\xca'          # mov   r10, rcx
		b'\x0f\x05'              # syscall
		# epilogue, release the 0x48 bytes and restore the saved registers
		b'\x48\x83\xc4\x48'      # add   rsp, 0x48
		b'\x59'                  # pop   rcx
		b'\x5a'                  # pop   rdx
		b'\x41\x58'              # pop   r8
		b'\x41\x59'              # pop   r9
		b'\x5d'                  # pop   rbp
		b'\xc3'                  # ret
	)

class WindowsX86Syscall(WindowsSyscallBase):
	"""
	The x86 syscall stub. The first stack argument is popped into eax as the
	syscall number before calling through the pointer stored at 0x7ffe0300.
	"""
	_syscall_arch = 'x86'
	_asm_function_stub = (
		# take the return address and the syscall number off of the stack
		b'\x5a'                  # pop   edx  ; ret -> edx
		b'\x58'                  # pop   eax  ; arg0 -> eax
		b'\x6a\x00'              # push  0
		b'\x50'                  # push  eax
		b'\x52'                  # push  edx
		b'\x83\xc4\x08'          # add   esp, 0x8
		# NOTE(review): 0x7ffe0300 is presumably the system call pointer in the shared user data page, confirm
		b'\xba\x00\x03\xfe\x7f'  # mov   edx, 0x7ffe0300
		b'\xff\x12'              # call  DWORD PTR [edx]
		# recover the saved return address and return to the caller
		b'\x83\xec\x08'          # sub   esp, 0x8
		b'\x5a'                  # pop   edx  ; ret -> edx
		b'\x83\xc4\x04'          # add   esp, 0x4
		b'\x52'                  # push  edx
		b'\xc3'                  # ret
	)

# select the stub implementation which matches the bit width reported for the
# running interpreter
_WindowsSyscall = (WindowsX64Syscall if platform.architecture()[0] == '64bit' else WindowsX86Syscall)

class WindowsSyscall(_WindowsSyscall):
	"""
	An object which facilitates the dynamic execution of raw syscalls through
	an assembly stub. This allows syscalls to be executed like other
	functions using the Python ctypes library. The implementation is
	inherited from the stub class selected for the running interpreter.

	Example Usage:

	.. code-block:: python

	  # initialize the object for Windows 7 SP1
	  syscall = WindowsSyscall('7 SP1')
	  syscall.NtQuerySystemInformation(5, 0x010000, 1024**2, None)
	"""

def allocate_null_page(size=0x1000):
	"""
	Request an allocation with NtAllocateVirtualMemory using a base address of
	1 and read / write / execute permissions.

	:param int size: The size in bytes of the region to request.
	:return: Whether or not NtAllocateVirtualMemory returned 0.
	:rtype: bool
	"""
	# the region size is passed by reference using an integer of the native width
	size_type = (ctypes.c_uint64 if platform.architecture()[0] == '64bit' else ctypes.c_uint32)
	region_size = size_type(size)
	# NOTE(review): a base of 1 is presumably rounded down to the page at address 0, confirm
	base_address = ctypes.c_void_p(1)
	allocation_type = flags('MEM_RESERVE | MEM_COMMIT | MEM_TOP_DOWN')
	protection = flags('PAGE_EXECUTE_READWRITE')
	# NOTE(review): -1 is presumably the pseudo handle for the current process, confirm
	status = m_ntdll.NtAllocateVirtualMemory(
		-1,
		ctypes.byref(base_address),
		0,
		ctypes.byref(region_size),
		allocation_type,
		protection
	)
	return status == 0

def error_on_null(value):
	"""
	Check value and raise an appropriate error, built by
	:py:func:`ctypes.WinError` from the last Windows error code, when it is
	NULL.

	Any false value is treated as NULL. This covers the integer 0 as well as
	the forms in which ctypes reports a NULL pointer: None (for results
	declared as ``c_void_p``) and NULL pointer instances, neither of which
	compares equal to 0.

	:param value: The result of a Windows API call to check.
	:return: The original value when it is not NULL.
	:raises: The exception built by :py:func:`ctypes.WinError` when value is NULL.
	"""
	if not value:
		raise ctypes.WinError()
	return value

def find_driver_base(driver=None):
	"""
	Get the base address of the specified driver or the NT Kernel if none is
	specified.

	:param str driver: The name of the driver to get the base address of, it is compared case-insensitively.
	:return: The base address and the driver name or None when no loaded driver matched.
	:rtype: tuple
	:raises RuntimeError: When Python is running as a 32-bit WOW64 process.
	"""
	psapi = ctypes.windll.psapi
	if platform.architecture()[0] == '64bit':
		lpImageBase = (ctypes.c_ulonglong * 1024)()
		# the prototype is declared explicitly for 64-bit so the base address is not passed as a default C int
		psapi.GetDeviceDriverBaseNameA.argtypes = [ctypes.c_longlong, ctypes.POINTER(ctypes.c_char), ctypes.c_uint32]
	else:
		if process_is_wow64():
			raise RuntimeError('python running in WOW64 is not supported')
		lpImageBase = (ctypes.c_ulong * 1024)()
	lpcbNeeded = ctypes.c_uint32()
	driver_name_size = 48
	if driver is not None:
		# normalize once so the comparison below is case-insensitive on both sides
		driver = driver.lower()
	# the second argument is the size of the array in bytes, not its element count
	psapi.EnumDeviceDrivers(ctypes.byref(lpImageBase), ctypes.sizeof(lpImageBase), ctypes.byref(lpcbNeeded))
	for base_addr in lpImageBase:
		if not base_addr:
			continue
		# the name is written by the API so a mutable buffer is required, a
		# c_char_p wrapping a bytes object points at immutable memory
		driver_name = ctypes.create_string_buffer(driver_name_size)
		psapi.GetDeviceDriverBaseNameA(base_addr, driver_name, driver_name_size)
		driver_name_value = driver_name.value.decode('utf-8')
		if driver is None:
			# without an explicit name the first driver which looks like the kernel is used
			if 'krnl' in driver_name_value.lower():
				return base_addr, driver_name_value
		elif driver_name_value.lower() == driver:
			return base_addr, driver_name_value
	return None

def get_haldispatchtable():
	"""
	Get the address of the halDispatchTable.

	The kernel image is loaded into the current process to find the offset of
	the ``HalDispatchTable`` export, which is then rebased onto the base
	address reported for the kernel by :py:func:`.find_driver_base`.

	:return: The address of the halDispatchTable.
	:rtype: int
	:raises RuntimeError: When Python is running in WOW64 or the kernel's base address can not be found.
	:raises: The exception built by :py:func:`ctypes.WinError` when the
		kernel image can not be loaded or the export can not be resolved.
	"""
	if process_is_wow64():
		raise RuntimeError('python running in WOW64 is not supported')
	kernel_info = find_driver_base()
	if kernel_info is None:
		# fail with a clear message instead of a TypeError from unpacking None
		raise RuntimeError('failed to find the base address of the NT kernel')
	(krnlbase, kernelver) = kernel_info
	# 1 is DONT_RESOLVE_DLL_REFERENCES, the image is only mapped to read its export
	hKernel = m_k32.LoadLibraryExA(kernelver, 0, 1)
	if not hKernel:
		raise ctypes.WinError()
	halDispatchTable = m_k32.GetProcAddress(hKernel, 'HalDispatchTable')
	if not halDispatchTable:
		raise ctypes.WinError()
	# translate the address within the local mapping to the kernel's own base
	halDispatchTable -= hKernel
	halDispatchTable += krnlbase
	return halDispatchTable

def interact(banner=None, local=None):
	"""
	Start an interactive Python console with commonly used Windows objects
	(ctypes, a few system libraries and the WindowsProcess class) available in
	its namespace.

	:param str banner: The banner to pass to :py:func:`code.interact`.
	:param dict local: Additional names for the console's namespace, these take precedence over the defaults.
	"""
	new_local = {
		'ctypes': ctypes,
		'gdi32': ctypes.windll.gdi32,
		'kernel32': ctypes.windll.kernel32,
		'ntdll': ctypes.windll.ntdll,
		'user32': ctypes.windll.user32,
		# expose the class itself, not a string containing its name
		'WindowsProcess': WindowsProcess
	}
	if local is not None:
		new_local.update(local)

	code.interact(banner=banner, local=new_local)

def print_handle(handle):
	"""
	Print the fields of the handle table entry which corresponds to the
	specified handle. The entry is read from the ``aheList`` table of the
	shared information obtained from user32, using the low 16 bits of the
	handle as the index.

	:param int handle: The handle whose entry is to be printed.
	"""
	type_names = (
		'Free',
		'Window',
		'Menu',
		'Cursor',
		'SetWindowPos',
		'Hook',
		'Clipboard Data',
		'CallProcData',
		'Accelerator',
		'DDE Access',
		'DDE Conversation',
		'DDE Transaction',
		'Monitor',
		'Keyboard Layout',
		'Keyboard File',
		'Event Hook',
		'Timer',
		'Input Context',
		'Hid Data',
		'Device Info',
		'Touch',
		'Gesture'
	)
	# map each type number to its label, unknown numbers are printed as None
	user_object_types = {number: label for number, label in enumerate(type_names)}
	index = handle & 0xffff
	shared = wintypes.SHARED_INFO.from_user32()
	entry_size = ctypes.sizeof(wintypes.HANDLE_ENTRY)
	entry_address = shared.aheList + (entry_size * index)
	entry = wintypes.HANDLE_ENTRY.from_address(entry_address)
	# the kernel address of the entry is its user address offset by the shared delta
	kernel_address = entry_address + shared.ulSharedDelta
	print("wintypes.HANDLE_ENTRY[0x{0:04x}] (kernel: 0x{1:016x} user: 0x{2:016x})".format(index, kernel_address, entry_address))
	print("  phead:  0x{0:08x}".format(entry.phead or 0))
	print("  pOwner: 0x{0:08x}".format(entry.pOwner or 0))
	type_label = user_object_types.get(entry.bType)
	print("  bType:  0x{0:02x} ({1})".format(entry.bType, type_label))
	print("  bFlags: 0x{0:02x}".format(entry.bFlags))
	print("  wUniq:  0x{0:04x}".format(entry.wUniq))