# Volatility
# Copyright (C) 2007-2013 Volatility Foundation
# Copyright (c) 2010, 2011, 2012 Michael Ligh <michael.ligh@mnin.org>
#
# This file is part of Volatility.
#
# Volatility is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Volatility is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Volatility.  If not, see <http://www.gnu.org/licenses/>.
#

# File-wide pylint filter for protected members, since we have three _BLAH structures
#pylint: disable-msg=W0212

import os
import volatility.utils as utils
import volatility.obj as obj
import volatility.debug as debug
import volatility.win32.tasks as tasks
import volatility.win32.modules as modules
import volatility.plugins.taskmods as taskmods
import volatility.plugins.vadinfo as vadinfo
import volatility.plugins.overlays.windows.windows as windows
import volatility.constants as constants

try:
    import yara
    has_yara = True
except ImportError:
    has_yara = False

try:
    import distorm3
    has_distorm3 = True
except ImportError:
    has_distorm3 = False

#--------------------------------------------------------------------------------
# functions 
#--------------------------------------------------------------------------------

def Disassemble(data, start, bits = '32bit', stoponret = False):
    """Disassemble code with distorm3.

    @param data: python byte str to decode
    @param start: address where `data` is found in memory
    @param bits: '32bit' selects 32-bit decoding; any other value
    selects 64-bit decoding
    @param stoponret: stop disasm when function end is reached

    @returns: a generator of (offset, instruction, hex bytes) tuples.
    Nothing is yielded when distorm3 is not installed.
    """

    # End the generator with a plain return. Raising StopIteration
    # inside a generator body is turned into a RuntimeError by
    # PEP 479 (Python 3.7+), while return works on every version.
    if not has_distorm3:
        return

    if bits == '32bit':
        mode = distorm3.Decode32Bits
    else:
        mode = distorm3.Decode64Bits

    for o, _, i, h in distorm3.DecodeGenerator(start, data, mode):
        # The RET instruction itself is not yielded
        if stoponret and i.startswith("RET"):
            return
        yield o, i, h

#--------------------------------------------------------------------------------
# scanners by scudette
#
# unfortunately the existing scanning framework (i.e. scan.BaseScanner) has 
# some shortcomings that don't allow us to integrate yara easily. 
#
# FIXME: these may need updating after resolving issue 310 which aims to 
# enhance the scan.BaseScanner to better support things like this
#--------------------------------------------------------------------------------

class BaseYaraScanner(object):
    """Run Yara rules over a range of an address space, block by block."""

    # Number of extra bytes read beyond each block. Matches that begin
    # inside this extra region are not reported for the current block;
    # they are reported when the following block is scanned.
    overlap = 1024

    def __init__(self, address_space = None, rules = None):
        self.address_space = address_space
        self.rules = rules

    def scan(self, offset, maxlen):
        """Yield (match, address) for every matched string in the range.

        @param offset: address at which scanning starts
        @param maxlen: number of bytes to scan from `offset`

        `self.rules` may be a single rules object or a list of them.
        """
        if isinstance(self.rules, list):
            rulesets = self.rules
        else:
            rulesets = [self.rules]

        end = offset + maxlen
        cursor = offset

        while cursor < end:
            block_size = constants.SCAN_BLOCKSIZE

            # Never read past the end of the requested range
            wanted = min(block_size + self.overlap, end - cursor)
            chunk = self.address_space.zread(cursor, wanted)

            # Remember where this block started, then advance
            base = cursor
            cursor += block_size

            if not chunk:
                continue

            for ruleset in rulesets:
                for hit in ruleset.match(data = chunk):
                    # The name and value of each matched string are
                    # currently unused; only the offset is needed.
                    for rel, _name, _value in hit.strings:
                        if rel < block_size:
                            yield hit, rel + base

class VadYaraScanner(BaseYaraScanner):
    """Yara scanner that walks a process's memory one VAD at a time."""

    def __init__(self, task = None, **kwargs):
        """Bind the scanner to a process.

        Args:
          task: The _EPROCESS object for this task.
          kwargs: Passed through to BaseYaraScanner (e.g. rules).
        """
        self.task = task
        process_space = task.get_process_address_space()
        BaseYaraScanner.__init__(self, address_space = process_space, **kwargs)

    def scan(self, offset = 0, maxlen = None):
        """Yield (match, address) for every VAD region of the task.

        The offset and maxlen arguments are accepted but not used;
        each VAD is scanned from its Start for its full Length.
        """
        for vad, vad_space in self.task.get_vads(skip_max_commit = True):
            # Keep the current region's address space on the instance,
            # since the base class scan reads through self.address_space
            self.address_space = vad_space
            for result in BaseYaraScanner.scan(self, vad.Start, vad.Length):
                yield result

class DiscontigYaraScanner(BaseYaraScanner):
    """Yara scanner for address spaces with gaps between valid ranges."""

    def scan(self, start_offset = 0, maxlen = None):
        """Yield (match, address) over the available address ranges.

        @param start_offset: ranges starting below this are ignored
        @param maxlen: when given, ranges starting above
        start_offset + maxlen are ignored

        Adjacent ranges are merged into one run before being scanned.
        """
        run_start = 0
        run_length = 0

        for (range_start, range_len) in self.address_space.get_available_addresses():
            # Ignore anything that begins before the requested start
            if self.address_space.address_compare(range_start, start_offset) == -1:
                continue

            # Ignore anything that begins beyond the requested end
            if maxlen != None and self.address_space.address_compare(
                    range_start, start_offset + maxlen) > 0:
                continue

            # A range that begins exactly where the current run ends
            # simply extends the run
            if range_start == run_start + run_length:
                run_length += range_len
                continue

            # Otherwise flush the run collected so far...
            for result in BaseYaraScanner.scan(self, run_start, run_length):
                yield result

            # ...and begin a new run at this range
            run_start, run_length = range_start, range_len

        # Flush whatever is left over
        if run_length > 0:
            for result in BaseYaraScanner.scan(self, run_start, run_length):
                yield result

#--------------------------------------------------------------------------------
# yarascan
#--------------------------------------------------------------------------------

class YaraScan(taskmods.DllList):
    "Scan process or kernel memory with Yara signatures"

    def __init__(self, config, *args, **kwargs):
        """Register the plugin's command-line options on `config`."""
        taskmods.DllList.__init__(self, config, *args, **kwargs)
        config.add_option("KERNEL", short_option = 'K', default = False, action = 'store_true',
                        help = 'Scan kernel modules')
        config.add_option("WIDE", short_option = 'W', default = False, action = 'store_true',
                        help = 'Match wide (unicode) strings')
        config.add_option('YARA-RULES', short_option = 'Y', default = None,
                        help = 'Yara rules (as a string)')
        config.add_option('YARA-FILE', short_option = 'y', default = None,
                        help = 'Yara rules (rules file)')
        config.add_option('DUMP-DIR', short_option = 'D', default = None,
                        help = 'Directory in which to dump the files')
        config.add_option('SIZE', short_option = 's', default = 256,
                          help = 'Size of preview hexdump (in bytes)',
                          action = 'store', type = 'int')
        config.add_option('REVERSE', short_option = 'R', default = 0,
                          help = 'Reverse this number of bytes',
                          action = 'store', type = 'int')

    def _compile_rules(self):
        """Compile the YARA rules from command-line parameters. 
        
        @returns: a YARA object on which you can call 'match'
        
        This function causes the plugin to exit if the YARA 
        rules have syntax errors or are not supplied correctly. 
        """
    
        rules = None
    
        try:
            if self._config.YARA_RULES:
                s = self._config.YARA_RULES
                # Don't wrap hex or regex rules in quotes 
                if s[0] not in ("{", "/"): s = '"' + s + '"'
                # Scan for unicode strings 
                if self._config.WIDE: s += "wide"
                rules = yara.compile(sources = {
                            'n' : 'rule r1 {strings: $a = ' + s + ' condition: $a}'
                            })
            elif self._config.YARA_FILE:
                rules = yara.compile(self._config.YARA_FILE)
            else:
                debug.error("You must specify a string (-Y) or a rules file (-y)")
        # "except X as name" is accepted by Python 2.6+ and Python 3,
        # unlike the comma form which only parses on Python 2.
        except yara.SyntaxError as why:
            debug.error("Cannot compile rules: {0}".format(str(why)))
            
        return rules

    def calculate(self):
        """Scan kernel or process memory and yield every hit.

        @returns: a generator of (owner, address, hit, content) tuples.
        In kernel mode the owner is the result of tasks.find_module for
        the hit address; otherwise it is the process being scanned.
        `content` is SIZE bytes read starting REVERSE bytes before the
        hit address.
        """

        if not has_yara:
            debug.error("Please install Yara from https://plusvic.github.io/yara/")

        addr_space = utils.load_as(self._config)

        rules = self._compile_rules()

        if self._config.KERNEL:

            # Find KDBG so we know where kernel memory begins. Do not assume
            # the starting range is 0x80000000 because we may be dealing with
            # an image with the /3GB boot switch. 
            kdbg = tasks.get_kdbg(addr_space)

            start = kdbg.MmSystemRangeStart.dereference_as("Pointer")

            # Modules so we can map addresses to owners
            mods = dict((addr_space.address_mask(mod.DllBase), mod)
                        for mod in modules.lsmod(addr_space))
            mod_addrs = sorted(mods.keys())

            # There are multiple views (GUI sessions) of kernel memory.
            # Since we're scanning virtual memory and not physical, 
            # all sessions must be scanned for full coverage. This 
            # really only has a positive effect if the data you're
            # searching for is in GUI memory. 
            sessions = []

            for proc in tasks.pslist(addr_space):
                sid = proc.SessionId
                # Skip sessions we've already seen 
                if sid == None or sid in sessions:
                    continue

                session_space = proc.get_process_address_space()
                if session_space == None:
                    continue

                sessions.append(sid)
                scanner = DiscontigYaraScanner(address_space = session_space,
                                               rules = rules)

                for hit, address in scanner.scan(start_offset = start):
                    module = tasks.find_module(mods, mod_addrs, addr_space.address_mask(address))
                    yield (module, address, hit, session_space.zread(address - self._config.REVERSE, self._config.SIZE))

        else:
            for task in self.filter_tasks(tasks.pslist(addr_space)):
                scanner = VadYaraScanner(task = task, rules = rules)
                for hit, address in scanner.scan():
                    # scanner.address_space is the space of the VAD
                    # that produced this hit
                    yield (task, address, hit, scanner.address_space.zread(address - self._config.REVERSE, self._config.SIZE))

    def render_text(self, outfd, data):
        """Write each hit's rule, owner and a hexdump preview to outfd.

        @param outfd: file-like object to write the report to
        @param data: the (owner, address, hit, content) tuples
        produced by calculate

        When DUMP_DIR is set, the preview bytes of every hit are also
        written to a file in that directory.
        """

        if self._config.DUMP_DIR and not os.path.isdir(self._config.DUMP_DIR):
            debug.error(self._config.DUMP_DIR + " is not a directory")

        for o, addr, hit, content in data:
            outfd.write("Rule: {0}\n".format(hit.rule))

            # Find out if the hit is from user or kernel mode.
            # NOTE(review): "== None" is kept instead of "is None" in
            # case the owner is a framework object that overrides
            # equality -- confirm before changing.
            if o == None:
                outfd.write("Owner: (Unknown Kernel Memory)\n")
                filename = "kernel.{0:#x}.dmp".format(addr)
            elif o.obj_name == "_EPROCESS":
                outfd.write("Owner: Process {0} Pid {1}\n".format(o.ImageFileName,
                    o.UniqueProcessId))
                filename = "process.{0:#x}.{1:#x}.dmp".format(o.obj_offset, addr)
            else:
                outfd.write("Owner: {0}\n".format(o.BaseDllName))
                filename = "kernel.{0:#x}.{1:#x}.dmp".format(o.obj_offset, addr)

            # Dump the data if --dump-dir was supplied
            if self._config.DUMP_DIR:
                path = os.path.join(self._config.DUMP_DIR, filename)
                # The with block closes the file even if write fails
                with open(path, "wb") as fh:
                    fh.write(content)

            # calculate reads the preview starting REVERSE bytes before
            # the hit, so the hexdump must be labelled from that
            # address rather than from the hit address itself.
            preview_start = addr - self._config.REVERSE

            # The loop variable is not called "o": on Python 2 a list
            # comprehension variable leaks and would overwrite the owner.
            outfd.write("".join(
                ["{0:#010x}  {1:<48}  {2}\n".format(preview_start + offset, h, ''.join(c))
                for offset, h, c in utils.Hexdump(content)
                ]))

#--------------------------------------------------------------------------------
# malfind
#--------------------------------------------------------------------------------

class Malfind(vadinfo.VADDump):
    "Find hidden and injected code"

    def __init__(self, config, *args, **kwargs):
        vadinfo.VADDump.__init__(self, config, *args, **kwargs)
        # This plugin does not use the BASE option of the parent
        config.remove_option("BASE")

    def _is_vad_empty(self, vad, address_space):
        """
        Tell whether a VAD region holds no reportable data: every
        page is either not valid in the address space, filled with
        zeros, or a mix of both. Such regions are skipped so that
        they are not reported as false positives.

        @param vad: an MMVAD object in kernel AS
        @param address_space: the process address space 

        @returns: False as soon as one valid, non-zero page is
        found; True otherwise
        """

        page_size = 0x1000
        zero_page = "\x00" * page_size

        scanned = 0
        while scanned < vad.Length:
            page_addr = vad.Start + scanned
            scanned += page_size

            # Pages that cannot be read count as empty
            if not address_space.is_valid_address(page_addr):
                continue

            if address_space.read(page_addr, page_size) != zero_page:
                return False

        return True

    def render_text(self, outfd, data):
        """Report every non-empty VAD that passes the injection filter.

        @param outfd: file-like object to write the report to
        @param data: iterable of process objects

        For each region a header, a hexdump and a disassembly of its
        first 64 bytes are written. When DUMP_DIR is set the region
        is also dumped to a file in that directory.
        """

        if not has_distorm3:
            debug.warning("For best results please install distorm3")

        if self._config.DUMP_DIR and not os.path.isdir(self._config.DUMP_DIR):
            debug.error(self._config.DUMP_DIR + " is not a directory")

        for task in data:
            for vad, address_space in task.get_vads(vad_filter = task._injection_filter):

                if self._is_vad_empty(vad, address_space):
                    continue

                # Only the start of the region is previewed
                content = address_space.zread(vad.Start, 64)

                outfd.write("Process: {0} Pid: {1} Address: {2:#x}\n".format(
                    task.ImageFileName, task.UniqueProcessId, vad.Start))

                protection = vadinfo.PROTECT_FLAGS.get(
                    vad.u.VadFlags.Protection.v(), "")
                outfd.write("Vad Tag: {0} Protection: {1}\n".format(
                    vad.Tag, protection))

                outfd.write("Flags: {0}\n".format(str(vad.u.VadFlags)))
                outfd.write("\n")

                # Hexdump of the preview bytes
                hex_lines = []
                for rel, hexbytes, chars in utils.Hexdump(content):
                    hex_lines.append("{0:#010x}  {1:<48}  {2}".format(
                        vad.Start + rel, hexbytes, ''.join(chars)))
                outfd.write("{0}\n".format("\n".join(hex_lines)))

                outfd.write("\n")

                # Disassembly of the same bytes
                asm_lines = []
                for address, instruction, hexbytes in Disassemble(content, vad.Start):
                    asm_lines.append("{0:#x} {1:<16} {2}".format(
                        address, hexbytes, instruction))
                outfd.write("\n".join(asm_lines))

                # Dump the data if --dump-dir was supplied
                if self._config.DUMP_DIR:
                    dump_name = "process.{0:#x}.{1:#x}.dmp".format(
                        task.obj_offset, vad.Start)
                    filename = os.path.join(self._config.DUMP_DIR, dump_name)

                    self.dump_vad(filename, vad, address_space)

                outfd.write("\n\n")

#--------------------------------------------------------------------------------
# ldrmodules 
#--------------------------------------------------------------------------------

class LdrModules(taskmods.DllList):
    "Detect unlinked DLLs"

    def render_text(self, outfd, data):
        """Print, per mapped image, whether each PEB list contains it.

        @param outfd: file-like object to write the table to
        @param data: iterable of process objects
        """

        columns = [("Pid", "8"),
                   ("Process", "20"),
                   ("Base", "[addrpad]"),
                   ("InLoad", "5"),
                   ("InInit", "5"),
                   ("InMem", "5"),
                   ("MappedPath", "")
                  ]
        self.table_header(outfd, columns)

        def by_base(mods):
            # Index module objects by their base address
            return dict((mod.DllBase.v(), mod) for mod in mods)

        for task in data:
            # One lookup table for each of the three PEB lists
            inloadorder = by_base(task.get_load_modules())
            ininitorder = by_base(task.get_init_modules())
            inmemorder = by_base(task.get_mem_modules())

            # Base address -> file name for every mapped file whose
            # first bytes carry the DOS header magic
            mapped_files = {}
            for vad, address_space in task.get_vads(vad_filter = task._mapped_file_filter):
                # Reading just the header field is a lot faster than
                # acquiring the full vad region.
                dos_header = obj.Object("_IMAGE_DOS_HEADER",
                                        offset = vad.Start,
                                        vm = address_space)
                if dos_header.e_magic != 0x5A4D:
                    continue
                mapped_files[int(vad.Start)] = str(vad.FileObject.FileName or '')

            # Compare every mapped file against the three PEB lists
            # so that discrepancies stand out.
            for base, mapped_path in mapped_files.items():
                load_mod = inloadorder.get(base)
                init_mod = ininitorder.get(base)
                mem_mod = inmemorder.get(base)

                self.table_row(outfd,
                        task.UniqueProcessId,
                        task.ImageFileName,
                        base,
                        str(load_mod != None),
                        str(init_mod != None),
                        str(mem_mod != None),
                        mapped_path
                        )

                if not self._config.verbose:
                    continue

                # Verbose mode adds the full path and base name
                # recorded in each list that has the module
                details = (
                    ("  Load Path: {0} : {1}\n", load_mod),
                    ("  Init Path: {0} : {1}\n", init_mod),
                    ("  Mem Path:  {0} : {1}\n", mem_mod),
                )
                for template, mod in details:
                    if mod:
                        outfd.write(template.format(mod.FullDllName, mod.BaseDllName))