#!/usr/bin/env python

"""
Module that extracts kernels and root filesystems from firmware images. For
programmatic usage, refer to the documentation of the 'Extractor' class. This
module can also be executed directly, e.g. 'extractor.py <input> <output>'.
"""

import argparse
import hashlib
import multiprocessing
import os
import shutil
import tempfile
import traceback

import magic
import binwalk

class Extractor(object):
    """
    Class that extracts kernels and filesystems from firmware images, given an
    input file or directory and output directory.
    """

    # Directories that define the root of a UNIX filesystem, and the
    # appropriate threshold condition
    UNIX_DIRS = ["bin", "etc", "dev", "home", "lib", "mnt", "opt", "root",
                 "run", "sbin", "tmp", "usr", "var"]
    UNIX_THRESHOLD = 4

    # Lock to prevent concurrent access to visited set. Unfortunately, must be
    # static because it cannot be pickled or passed as instance attribute.
    visited_lock = multiprocessing.Lock()

    def __init__(self, indir, outdir=None, rootfs=True, kernel=True,
                 numproc=True, server=None, brand=None):
        # Input firmware update file or directory
        self._input = os.path.abspath(indir)
        # Output firmware directory
        self.output_dir = os.path.abspath(outdir) if outdir else None

        # Whether to attempt to extract kernel
        self.do_kernel = kernel

        # Whether to attempt to extract root filesystem
        self.do_rootfs = rootfs

        # Brand of the firmware
        self.brand = brand

        # Hostname of SQL server
        self.database = server

        # Worker pool.
        self._pool = multiprocessing.Pool() if numproc else None

        # Set containing MD5 checksums of visited items
        self.visited = set()

        # List of input file paths queued for extraction
        self._list = list()

    def __getstate__(self):
        """
        Eliminate attributes that should not be pickled.
        """
        self_dict = self.__dict__.copy()
        del self_dict["_pool"]
        del self_dict["_list"]
        return self_dict

    @staticmethod
    def io_dd(indir, offset, size, outdir):
        """
        Extract 'size' bytes starting at 'offset' from the input file and write
        them to the given output file.
        """
        if not size:
            return

        with open(indir, "rb") as ifp:
            with open(outdir, "wb") as ofp:
                ifp.seek(offset, 0)
                ofp.write(ifp.read(size))

    @staticmethod
    def magic(indata, mime=False):
        """
        Performs file magic while maintaining compatibility with different
        libraries.
        """

        try:
            if mime:
                mymagic = magic.open(magic.MAGIC_MIME_TYPE)
            else:
                mymagic = magic.open(magic.MAGIC_NONE)
            mymagic.load()
        except AttributeError:
            mymagic = magic.Magic(mime)
            mymagic.file = mymagic.from_file
        return mymagic.file(indata)

    @staticmethod
    def io_md5(target):
        """
        Performs MD5 with a block size of 64kb.
        """
        blocksize = 65536
        hasher = hashlib.md5()

        with open(target, 'rb') as ifp:
            buf = ifp.read(blocksize)
            while buf:
                hasher.update(buf)
                buf = ifp.read(blocksize)
            return hasher.hexdigest()

    @staticmethod
    def io_rm(target):
        """
        Attempts to recursively delete a directory.
        """
        shutil.rmtree(target, ignore_errors=False, onerror=Extractor._io_err)

    @staticmethod
    def _io_err(function, path, excinfo):
        """
        Internal function used by '_rm' to print out errors.
        """
        print(("!! %s: Cannot delete %s!\n%s" % (function, path, excinfo)))

    @staticmethod
    def io_find_rootfs(start, recurse=True):
        """
        Attempts to find a Linux root directory.
        """

        # Recurse into single directory chains, e.g. jffs2-root/fs_1/.../
        path = start
        while (len(os.listdir(path)) == 1 and
               os.path.isdir(os.path.join(path, os.listdir(path)[0]))):
            path = os.path.join(path, os.listdir(path)[0])

        # count number of unix-like directories
        count = 0
        for subdir in os.listdir(path):
            if subdir in Extractor.UNIX_DIRS and \
                os.path.isdir(os.path.join(path, subdir)):
                count += 1

        # check for extracted filesystem, otherwise update queue
        if count >= Extractor.UNIX_THRESHOLD:
            return (True, path)

        # in some cases, multiple filesystems may be extracted, so recurse to
        # find best one
        if recurse:
            for subdir in os.listdir(path):
                if os.path.isdir(os.path.join(path, subdir)):
                    res = Extractor.io_find_rootfs(os.path.join(path, subdir),
                                                   False)
                    if res[0]:
                        return res

        return (False, start)

    def extract(self):
        """
        Perform extraction of firmware updates from input to tarballs in output
        directory using a thread pool.
        """
        if os.path.isdir(self._input):
            for path, _, files in os.walk(self._input):
                for item in files:
                    self._list.append(os.path.join(path, item))
        elif os.path.isfile(self._input):
            self._list.append(self._input)

        if self.output_dir and not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)

        if self._pool:
            self._pool.map(self._extract_item, self._list)
        else:
            for item in self._list:
                self._extract_item(item)

    def _extract_item(self, path):
        """
        Wrapper function that creates an ExtractionItem and calls the extract()
        method.
        """

        e = ExtractionItem(self, path, 0)
        e.extract()
        return (e.tag, e.repeated)

class ExtractionItem(object):
    """
    Class that encapsulates the state of a single item that is being extracted.
    """

    # Maximum recursion breadth and depth
    RECURSION_BREADTH = 5
    RECURSION_DEPTH = 2
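    # DEPTH bounds how many levels of nested containers a single item is
    # unpacked through; BREADTH bounds how many extracted files are recursed
    # into for a single archive before extraction of that item is terminated.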

    def __init__(self, extractor, path, depth, tag=None):
        # Temporary directory
        self.temp = None

        # Recursion depth counter
        self.depth = depth

        # Reference to parent extractor object
        self.extractor = extractor

        # File path
        self.item = path

        # Database connection
        if self.extractor.database:
            import psycopg2
            self.database = psycopg2.connect(database="firmware",
                                             user="firmadyne",
                                             password="firmadyne",
                                             host=self.extractor.database)
        else:
            self.database = None

        # Checksum
        self.checksum = Extractor.io_md5(path)

        # Tag
        self.tag = tag if tag else self.generate_tag()

        # Output file path and filename prefix
        self.output = os.path.join(self.extractor.output_dir, self.tag) if \
                                   self.extractor.output_dir else None

        # Status, with terminate indicating early termination for this item
        self.terminate = False
        self.status = None
        self.update_status()

        self.repeated = False

    def __del__(self):
        if self.database:
            self.database.close()

        if self.temp:
            self.printf(">> Cleaning up %s..." % self.temp)
            Extractor.io_rm(self.temp)

    def printf(self, fmt):
        """
        Prints output string with appropriate depth indentation.
        """
        print("\t" * self.depth + fmt)

    def generate_tag(self):
        """
        Generate the filename tag.
        """
        if not self.database:
            return os.path.basename(self.item) + "_" + self.checksum

        image_id = None
        cur = None
        try:
            cur = self.database.cursor()
            if self.extractor.brand:
                brand = self.extractor.brand
            else:
                brand = os.path.relpath(self.item).split(os.path.sep)[0]
            cur.execute("SELECT id FROM brand WHERE name=%s", (brand, ))
            brand_id = cur.fetchone()
            if not brand_id:
                cur.execute("INSERT INTO brand (name) VALUES (%s) RETURNING id",
                            (brand, ))
                brand_id = cur.fetchone()
            if brand_id:
                cur.execute("SELECT id FROM image WHERE hash=%s",
                            (self.checksum, ))
                image_id = cur.fetchone()
                if not image_id:
                    cur.execute("INSERT INTO image (filename, brand_id, hash) \
                                VALUES (%s, %s, %s) RETURNING id",
                                (os.path.basename(self.item), brand_id[0],
                                 self.checksum))
                    image_id = cur.fetchone()
            self.database.commit()
        except Exception:
            traceback.print_exc()
            self.database.rollback()
        finally:
            if cur:
                cur.close()

        if image_id:
            self.printf(">> Database Image ID: %s" % image_id[0])

        return str(image_id[0]) if \
               image_id else os.path.basename(self.item) + "_" + self.checksum

    def get_kernel_status(self):
        """
        Get the flag corresponding to the kernel status.
        """
        return self.status[0]

    def get_rootfs_status(self):
        """
        Get the flag corresponding to the root filesystem status.
        """
        return self.status[1]

    def update_status(self):
        """
        Updates the status flags using the tag to determine completion status.
        """
        kernel_done = os.path.isfile(self.get_kernel_path()) if \
            self.extractor.do_kernel and self.output else \
            not self.extractor.do_kernel
        rootfs_done = os.path.isfile(self.get_rootfs_path()) if \
            self.extractor.do_rootfs and self.output else \
            not self.extractor.do_rootfs
        self.status = (kernel_done, rootfs_done)

        if self.database and kernel_done and self.extractor.do_kernel:
            self.update_database("kernel_extracted", "True")

        if self.database and rootfs_done and self.extractor.do_rootfs:
            self.update_database("rootfs_extracted", "True")

        return self.get_status()

    def update_database(self, field, value):
        """
        Update a given field in the database.
        """
        ret = True
        if self.database:
            cur = None
            try:
                cur = self.database.cursor()
                cur.execute("UPDATE image SET " + field + "=%s WHERE id=%s",
                            (value, self.tag))
                self.database.commit()
            except Exception:
                ret = False
                traceback.print_exc()
                self.database.rollback()
            finally:
                if cur:
                    cur.close()
        return ret

    def get_status(self):
        """
        Returns True if early terminate signaled, extraction is complete,
        otherwise False.
        """
        return True if self.terminate or all(i for i in self.status) else False

    def get_kernel_path(self):
        """
        Return the full path (including filename) to the output kernel file.
        """
        return self.output + ".kernel" if self.output else None

    def get_rootfs_path(self):
        """
        Return the full path (including filename) to the output root filesystem
        file.
        """
        return self.output + ".tar.gz" if self.output else None

    def extract(self):
        """
        Perform the actual extraction of firmware updates, recursively. Returns
        True if extraction complete, otherwise False.
        """
        self.printf("\n" + self.item.encode("utf-8", "replace").decode("utf-8"))

        # check if item is complete
        if self.get_status():
            self.printf(">> Skipping: completed!")
            self.repeated = True
            return True

        # check if exceeding recursion depth
        if self.depth > ExtractionItem.RECURSION_DEPTH:
            self.printf(">> Skipping: recursion depth %d" % self.depth)
            return self.get_status()

        # check if checksum is in visited set
        self.printf(">> MD5: %s" % self.checksum)
        with Extractor.visited_lock:
            if self.checksum in self.extractor.visited:
                self.printf(">> Skipping: %s..." % self.checksum)
                return self.get_status()
            else:
                self.extractor.visited.add(self.checksum)

        # check if filetype is blacklisted
        if self._check_blacklist():
            return self.get_status()

        # create working directory
        self.temp = tempfile.mkdtemp()

        try:
            self.printf(">> Tag: %s" % self.tag)
            self.printf(">> Temp: %s" % self.temp)
            self.printf(">> Status: Kernel: %s, Rootfs: %s, Do_Kernel: %s, \
                Do_Rootfs: %s" % (self.get_kernel_status(),
                                  self.get_rootfs_status(),
                                  self.extractor.do_kernel,
                                  self.extractor.do_rootfs))

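            # Analyses are attempted in order: archives, known firmware
            # headers, kernel signatures, filesystem signatures, and finally
            # generic compressed data.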
            for analysis in [self._check_archive, self._check_firmware,
                             self._check_kernel, self._check_rootfs,
                             self._check_compressed]:
                # Move to temporary directory so binwalk does not write to input
                os.chdir(self.temp)

                # Update status only if analysis changed state
                if analysis():
                    if self.update_status():
                        self.printf(">> Skipping: completed!")
                        return True

        except Exception:
            traceback.print_exc()

        return False

    def _check_blacklist(self):
        """
        Check if this file is blacklisted for analysis based on file type.
        """
        # First, use MIME-type to exclude large categories of files
        filetype = Extractor.magic(self.item.encode("utf-8", "surrogateescape"),
                                   mime=True)
        if any(s in filetype for s in ["application/x-dosexec",
                                       "application/pdf",
                                       "application/msword",
                                       "image/", "video/"]):
            self.printf(">> Skipping: %s..." % filetype)
            return True

        # Next, check for specific file types that have MIME-type
        # 'application/octet-stream'
        filetype = Extractor.magic(self.item.encode("utf-8", "surrogateescape"))
        if any(s in filetype for s in ["applet"]):
            self.printf(">> Skipping: %s..." % filetype)
            return True

        # Finally, check for specific file extensions that would be incorrectly
        # identified
        if self.item.endswith(".dmg"):
            self.printf(">> Skipping: %s..." % (self.item))
            return True

        return False

    def _check_archive(self):
        """
        If this file is an archive, recurse over its contents, unless it matches
        an extracted root filesystem.
        """
        return self._check_recursive("archive")

    def _check_firmware(self):
        """
        If this file is of a known firmware type, directly attempt to extract
        the kernel and root filesystem.
        """
        for module in binwalk.scan(self.item, "-y", "header", signature=True,
                                   quiet=True):
            for entry in module.results:
                # uImage
                if "uImage header" in entry.description:
                    if not self.get_kernel_status() and \
                        "OS Kernel Image" in entry.description:
                        kernel_offset = entry.offset + 64
                        kernel_size = 0

                        for stmt in entry.description.split(','):
                            if "image size:" in stmt:
                                kernel_size = int(''.join(
                                    i for i in stmt if i.isdigit()), 10)

                        if kernel_size != 0 and kernel_offset + kernel_size \
                            <= os.path.getsize(self.item):
                            self.printf(">>>> %s" % entry.description)

                            tmp_fd, tmp_path = tempfile.mkstemp(dir=self.temp)
                            os.close(tmp_fd)
                            Extractor.io_dd(self.item, kernel_offset,
                                            kernel_size, tmp_path)
                            kernel = ExtractionItem(self.extractor, tmp_path,
                                                    self.depth, self.tag)

                            return kernel.extract()
                    # elif "RAMDisk Image" in entry.description:
                    #     self.printf(">>>> %s" % entry.description)
                    #     self.printf(">>>> Skipping: RAMDisk / initrd")
                    #     self.terminate = True
                    #     return True

                # TP-Link or TRX
                elif not self.get_kernel_status() and \
                    not self.get_rootfs_status() and \
                    "rootfs offset: " in entry.description and \
                    "kernel offset: " in entry.description:
                    kernel_offset = 0
                    kernel_size = 0
                    rootfs_offset = 0
                    rootfs_size = 0

                    for stmt in entry.description.split(','):
                        if "kernel offset:" in stmt:
                            kernel_offset = int(stmt.split(':')[1], 16)
                        elif "kernel length:" in stmt:
                            kernel_size = int(stmt.split(':')[1], 16)
                        elif "rootfs offset:" in stmt:
                            rootfs_offset = int(stmt.split(':')[1], 16)
                        elif "rootfs length:" in stmt:
                            rootfs_size = int(stmt.split(':')[1], 16)

                    # compute sizes if only offsets provided
                    if kernel_offset != 0 and kernel_size == 0 and \
                        rootfs_size == 0:
                        kernel_size = rootfs_offset - kernel_offset
                        rootfs_size = os.path.getsize(self.item) - rootfs_offset

                    # ensure that computed values are sensible
                    if (kernel_size > 0 and kernel_offset + kernel_size \
                        <= os.path.getsize(self.item)) and \
                        (rootfs_size != 0 and rootfs_offset + rootfs_size \
                            <= os.path.getsize(self.item)):
                        self.printf(">>>> %s" % entry.description)

                        tmp_fd, tmp_path = tempfile.mkstemp(dir=self.temp)
                        os.close(tmp_fd)
                        Extractor.io_dd(self.item, kernel_offset, kernel_size,
                                        tmp_path)
                        kernel = ExtractionItem(self.extractor, tmp_path,
                                                self.depth, self.tag)
                        kernel.extract()

                        tmp_fd, tmp_path = tempfile.mkstemp(dir=self.temp)
                        os.close(tmp_fd)
                        Extractor.io_dd(self.item, rootfs_offset, rootfs_size,
                                        tmp_path)
                        rootfs = ExtractionItem(self.extractor, tmp_path,
                                                self.depth, self.tag)
                        rootfs.extract()

                        return self.update_status()
        return False

    def _check_kernel(self):
        """
        If this file contains a kernel version string, assume it is a kernel.
        Only Linux kernels are currently extracted.
        """
        if not self.get_kernel_status():
            for module in binwalk.scan(self.item, "-y", "kernel",
                                       signature=True, quiet=True):
                for entry in module.results:
                    if "kernel version" in entry.description:
                        self.update_database("kernel_version",
                                             entry.description)
                        if "Linux" in entry.description:
                            if self.get_kernel_path():
                                shutil.copy(self.item, self.get_kernel_path())
                            else:
                                self.extractor.do_kernel = False
                            self.printf(">>>> %s" % entry.description)
                            return True
                        # VxWorks, etc
                        else:
                            self.printf(">>>> Ignoring: %s" % entry.description)
                            return False
                return False
        return False

    def _check_rootfs(self):
        """
        If this file contains a known filesystem type, extract it.
        """

        if not self.get_rootfs_status():
            for module in binwalk.scan(self.item, "-e", "-r", "-y",
                                       "filesystem", signature=True,
                                       quiet=True):
                for entry in module.results:
                    self.printf(">>>> %s" % entry.description)
                    break

                if module.extractor.directory:
                    unix = Extractor.io_find_rootfs(module.extractor.directory)

                    if not unix[0]:
                        self.printf(">>>> Extraction failed!")
                        return False

                    self.printf(">>>> Found Linux filesystem in %s!" % unix[1])
                    if self.output:
                        shutil.make_archive(self.output, "gztar",
                                            root_dir=unix[1])
                    else:
                        self.extractor.do_rootfs = False
                    return True
        return False

    def _check_compressed(self):
        """
        If this file appears to be compressed, decompress it and recurse over
        its contents.
        """
        return self._check_recursive("compressed")

    # treat both archived and compressed files using the same pathway. this is
    # because certain files may appear as e.g. "xz compressed data" but still
    # extract into a root filesystem.
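    # For example, a gzip-wrapped update may be reported as "gzip compressed
    # data" yet unpack directly into a directory tree that io_find_rootfs
    # recognizes as a root filesystem.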
    def _check_recursive(self, fmt):
        """
        Unified implementation for checking both "archive" and "compressed"
        items.
        """
        desc = None
        # perform extraction
        for module in binwalk.scan(self.item, "-e", "-r", "-y", fmt,
                                   signature=True, quiet=True):
            for entry in module.results:
                # skip cpio/initrd files since they should be included with
                # kernel
                # if "cpio archive" in entry.description:
                #     self.printf(">> Skipping: cpio: %s" % entry.description)
                #     self.terminate = True
                #     return True
                desc = entry.description
                self.printf(">>>> %s" % entry.description)
                break

            if module.extractor.directory:
                unix = Extractor.io_find_rootfs(module.extractor.directory)

                # check for extracted filesystem, otherwise update queue
                if unix[0]:
                    self.printf(">>>> Found Linux filesystem in %s!" % unix[1])
                    if self.output:
                        shutil.make_archive(self.output, "gztar",
                                            root_dir=unix[1])
                    else:
                        self.extractor.do_rootfs = False
                    return True
                else:
                    count = 0
                    self.printf(">> Recursing into %s ..." % fmt)
                    for root, _, files in os.walk(module.extractor.directory):
                        # sort alphabetically, then by increasing length; the
                        # stable sort preserves alphabetical order within each
                        # length
                        files.sort()
                        files.sort(key=len)

                        # handle case where original file name is restored; put
                        # it to front of queue
                        if desc and "original file name:" in desc:
                            orig = None
                            for stmt in desc.split(","):
                                if "original file name:" in stmt:
                                    orig = stmt.split("\"")[1]
                            if orig and orig in files:
                                files.remove(orig)
                                files.insert(0, orig)

                        for filename in files:
                            if count > ExtractionItem.RECURSION_BREADTH:
                                self.printf(">> Skipping: recursion breadth %d"\
                                    % ExtractionItem.RECURSION_BREADTH)
                                self.terminate = True
                                return True
                            else:
                                new_item = ExtractionItem(self.extractor,
                                                          os.path.join(root,
                                                                       filename),
                                                          self.depth + 1,
                                                          self.tag)
                                if new_item.extract():
                                    # check that we are actually done before
                                    # performing early termination. for example,
                                    # we might decide to skip on one subitem,
                                    # but we still haven't finished
                                    if self.update_status():
                                        return True
                            count += 1
        return False

def main():
    parser = argparse.ArgumentParser(description="Extracts filesystem and "
                                     "kernel from Linux-based firmware images")
    parser.add_argument("input", action="store", help="Input file or directory")
    parser.add_argument("output", action="store", nargs="?", default="images",
                        help="Output directory for extracted firmware")
    parser.add_argument("-sql", dest="sql", action="store", default=None,
                        help="Hostname of SQL server")
    parser.add_argument("-nf", dest="rootfs", action="store_false",
                        default=True, help="Disable extraction of root "
                        "filesystem (may decrease extraction time)")
    parser.add_argument("-nk", dest="kernel", action="store_false",
                        default=True, help="Disable extraction of kernel "
                        "(may decrease extraction time)")
    parser.add_argument("-np", dest="parallel", action="store_false",
                        default=True, help="Disable parallel operation "
                        "(may increase extraction time)")
    parser.add_argument("-b", dest="brand", action="store", default=None,
                        help="Brand of the firmware image")
    result = parser.parse_args()

    extract = Extractor(result.input, result.output, result.rootfs,
                        result.kernel, result.parallel, result.sql,
                        result.brand)
    extract.extract()

if __name__ == "__main__":
    main()