################################################################################
# Name   : SockPuppet.py
# Author : Tyson Smith
#
# Copyright 2014 BlackBerry Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
################################################################################
import logging as log
import argparse
import pickle
import marshal
import os
import socket
import struct
import subprocess
import sys
import tempfile
import time
import traceback
import types
import zlib

class SockPuppetBase(object):
    """Shared transport for the SockPuppet controller/target protocol.

    Every message is a pickled dict carrying at least a ``"cmd"`` key (one of
    the command constants below). On the wire each message is preceded by an
    unsigned length prefix packed with ``struct`` format ``_LEN_FMT`` (native
    byte order and size, so both peers must agree on them). Every message
    other than an ACK is acknowledged by the receiver with an ACK message.

    NOTE(review): messages are decoded with ``pickle.loads``, which can run
    arbitrary code -- only talk to trusted peers.
    """
    # DEFAULTS
    CHUNK_BUF = 4 * 1024 * 1024  # bytes read from disk per file CHUNK message
    SOCK_BUF = 64 * 1024         # max bytes requested per socket recv() call
    # COMMANDS
    ACK =    0
    CHUNK =  1
    CODE =   2
    DEBUG =  3
    EXCEPT = 4
    FILE =   5
    QUIT =   6
    RESULT = 7
    RETURN = 8
    RUN =    9
    # WIRE FORMAT (length prefix)
    _LEN_FMT = "I"
    _LEN_SIZE = struct.calcsize(_LEN_FMT)
    _LEN_MAX = 2 ** (8 * _LEN_SIZE) - 1

    def __init__(self, ip=None, is_server=False, port=1701, timeout=60, debug=False):
        """Store connection settings; no socket is opened until connect().

        ip        -- server address (client mode only)
        is_server -- listen for a peer instead of connecting to one
        port      -- TCP port to listen on / connect to
        timeout   -- seconds; <= 0 means no timeout (see connect())
        debug     -- start with DEBUG-level logging enabled
        """
        self.conn = None
        self.ip = ip
        self.is_server = is_server
        self.port = port
        self.timeout = timeout
        self.debugging = debug
        log.basicConfig(level=log.INFO)
        if self.debugging and log.getLogger().level == log.INFO:
            self.toggle_debug()

    def connect(self):
        """Establish ``self.conn``.

        Server mode listens on all interfaces at ``self.port`` and accepts a
        single peer, giving up after ``self.timeout`` seconds (<= 0 waits
        forever); the accepted socket is switched to blocking mode. Client mode
        connects to ``(self.ip, self.port)`` using ``self.timeout`` as the
        socket timeout (<= 0 means blocking with no timeout).

        Raises:
            RuntimeError: server mode timed out waiting for a peer.
        """
        limited = bool(self.timeout) and self.timeout > 0
        if self.is_server:
            log.debug("waiting for client to connect...")
            listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                listener.bind(('', self.port))
                # Poll accept() in short slices so the overall timeout is honoured.
                listener.settimeout(0.1)
                start_time = time.time()
                listener.listen(0)
                while True:
                    try:
                        self.conn, _ = listener.accept()
                        break
                    except socket.timeout:
                        pass
                    if limited and time.time() - start_time >= self.timeout:
                        raise RuntimeError("Timeout exceeded (%ds)" % self.timeout)
            finally:
                # Only one peer is ever served; release the listening port
                # (the accepted connection stays open).
                listener.close()
            self.conn.setblocking(True)
        else:
            log.debug("connecting to server (%s:%d)...", self.ip, self.port)
            # A timeout of 0 would make the socket non-blocking and connect()
            # would fail immediately, so <= 0 means "no timeout" here too.
            self.conn = socket.create_connection((self.ip, self.port),
                                                 self.timeout if limited else None)

    def disconnect(self):
        """Close the connection, if any; safe to call more than once."""
        log.debug("disconnecting")
        if self.conn:
            self.conn.close()
            self.conn = None

    def _recv_exact(self, size):
        """Read exactly *size* bytes from the connection.

        Raises:
            RuntimeError: the peer closed the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.conn.recv(min(remaining, self.SOCK_BUF))
            if not chunk:
                raise RuntimeError("Connection closed by peer (%d bytes short)" % remaining)
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def recv_data(self):
        """Receive one length-prefixed pickled message.

        Returns the unpickled dict, or None if the peer sent an empty frame.
        Every message other than an ACK is acknowledged before returning.

        Raises:
            RuntimeError: the connection closed before the message was complete.
        """
        # recv() may return fewer bytes than requested, even for the header.
        (data_remaining,) = struct.unpack(self._LEN_FMT, self._recv_exact(self._LEN_SIZE))
        if not data_remaining:
            log.debug("no data?!")
            return None
        log.debug("<- recving %d bytes", data_remaining)
        # NOTE(review): unpickling runs code embedded in the stream; trusted peers only.
        data = pickle.loads(self._recv_exact(data_remaining))
        if data["cmd"] != self.ACK:
            self.send_ack()
        return data

    def recv_file(self, data):
        """Receive the file announced by the FILE message *data*.

        Writes to ``os.path.join(data["path"], data["name"])`` from the CHUNK
        messages that follow, stopping once ``data["size"]`` bytes arrived.

        NOTE(review): the destination path comes from the peer unchecked.

        Raises:
            RuntimeError: a non-CHUNK message arrived, or the adler32 checksum
                does not match ``data["chksum"]``.
        """
        name = os.path.join(data["path"], data["name"])
        data_remaining = data["size"]
        expected_chksum = data["chksum"]
        log.debug("receiving file: %s (%0.02fKB)", name, data_remaining/1024.0)
        chksum = 0
        with open(name, "wb") as fp:
            # "> 0" rather than truthiness: an oversized final chunk must end
            # the loop instead of driving the count negative forever.
            while data_remaining > 0:
                chunk = self.recv_data()
                if chunk is None or chunk["cmd"] != self.CHUNK:
                    raise RuntimeError("Expecting data chunk.")
                chksum = zlib.adler32(chunk["data"], chksum)
                fp.write(chunk["data"])
                data_remaining -= len(chunk["data"])
        if expected_chksum != chksum:
            raise RuntimeError("Checksum mismatch!")

    def send_ack(self):
        """Acknowledge the last received message."""
        log.debug("ACK'ing")
        self.send_data({"cmd":self.ACK})

    def send_data(self, data):
        """Pickle and send one message; for non-ACK messages, wait for the ACK.

        Raises:
            RuntimeError: the pickled message is too large for the length
                prefix, or the peer replied with something other than an ACK.
        """
        is_ack = (data["cmd"] == self.ACK)
        payload = pickle.dumps(data, pickle.HIGHEST_PROTOCOL)
        payload_len = len(payload)
        if payload_len > self._LEN_MAX:
            raise RuntimeError("Transfer too large!")
        log.debug("-> sending %d bytes", payload_len)
        self.conn.sendall(struct.pack(self._LEN_FMT, payload_len))
        self.conn.sendall(payload)
        if not is_ack:
            # Must not live inside an assert: under "python -O" the ACK would
            # never be read and the stream would fall out of sync.
            reply = self.recv_data()
            if reply is None or reply["cmd"] != self.ACK:
                raise RuntimeError("Expected ACK, received: %r" % (reply,))
            log.debug("ACK received")

    def send_file(self, src, dst=None):
        """Send local file *src* to the peer, to be written at *dst*.

        *dst* defaults to the basename of *src*. A FILE message carrying the
        size and adler32 checksum is sent first, then the contents in CHUNK
        messages of at most CHUNK_BUF bytes.

        Raises:
            RuntimeError: *src* is not a regular file.
        """
        if not os.path.isfile(src):
            raise RuntimeError("%s does not exist!" % src)
        if dst is None:
            dst = os.path.basename(src)
        log.debug("sending file (%s) -> (%s)", src, dst)
        file_size = int(os.stat(src).st_size)
        chksum = 0
        with open(src, "rb") as fp:
            while fp.tell() < file_size:
                block = fp.read(self.CHUNK_BUF)
                if not block:
                    break  # file shrank since stat(); don't spin forever
                chksum = zlib.adler32(block, chksum)
        data = {
                "cmd":self.FILE,
                "name":os.path.basename(dst),
                "path":os.path.dirname(dst),
                "size":file_size,
                "chksum":chksum
               }
        self.send_data(data)
        with open(src, "rb") as fp:
            data = {"cmd":self.CHUNK}
            while fp.tell() < file_size:
                data["data"] = fp.read(self.CHUNK_BUF)
                self.send_data(data)
                if len(data["data"]) < self.CHUNK_BUF:
                    break

    def toggle_debug(self):
        """Flip between DEBUG and INFO level on the root logger."""
        self.debugging = not self.debugging
        if self.debugging:
            log.getLogger().setLevel(level=log.DEBUG)
            log.debug("debugging enabled")
        else:
            log.debug("debugging disabled")
            log.getLogger().setLevel(level=log.INFO)

class Target(SockPuppetBase):
    """Client side of SockPuppet: connects to a controller and serves commands.

    Handles QUIT, DEBUG, FILE, RUN and CODE messages until QUIT arrives; the
    connection is closed however the loop ends.

    NOTE(review): RUN executes arbitrary commands and CODE arbitrary bytecode
    sent by the controller -- only connect to a trusted controller.
    """

    def _exception_reply(self):
        """Return an EXCEPT message for the exception currently being handled.

        The message text is sent as a string rather than the exception object,
        so exceptions that cannot be pickled here (or unpickled on the
        controller, e.g. classes only importable on the target) still arrive.
        """
        e_type, e_value, e_tb = sys.exc_info()
        log.debug("except - %s: %s", e_type.__name__, e_value)
        return {"cmd":self.EXCEPT,
                "msg":str(e_value),
                "name":e_type.__name__,
                "tb":"".join(traceback.format_tb(e_tb))}

    def _run_command(self, data):
        """Run ``data["cmd_to_run"]``; return a RESULT or EXCEPT message.

        RESULT carries the exit code and the combined stdout/stderr output.

        NOTE(review): ``data["timeout"]`` sent by the controller is not
        enforced here; the command may run indefinitely.
        """
        with tempfile.TemporaryFile() as fp:
            try:
                # Inside the try: a malformed argv must become an EXCEPT reply
                # instead of killing the target while the controller waits.
                log.debug("running cmd: %s", " ".join(data["cmd_to_run"]))
                proc = subprocess.Popen(data["cmd_to_run"],
                                        shell=False,
                                        stdout=fp,
                                        stderr=fp)
                code = proc.wait()
                fp.seek(0)
                reply = {"cmd":self.RESULT, "code":code, "output":fp.read()}
                log.debug("command returned: %d", code)
            except Exception:
                reply = self._exception_reply()
        return reply

    def _run_code(self, data):
        """Rebuild the marshalled function in *data*, call it, and return a
        RETURN (with the call's value) or EXCEPT message."""
        try:
            func = types.FunctionType(marshal.loads(data["code"]), globals(),
                                      data["name"], data["defaults"], data["closure"])
            log.debug("%s() args:%s kwargs:%s", data["name"], data["args"], data["kwargs"])
            return {"cmd":self.RETURN, "value":func(*data["args"], **data["kwargs"])}
        except Exception:
            return self._exception_reply()

    def run(self):
        """Connect, then process commands until QUIT.

        Raises:
            RuntimeError: an unknown (or empty) command was received.
        """
        self.connect()
        try:
            while True:
                log.debug("waiting for command...")
                data = self.recv_data()
                cmd = data["cmd"] if data is not None else None
                if cmd == self.QUIT:
                    log.debug("QUIT (%d)", self.QUIT)
                    break
                elif cmd == self.DEBUG:
                    log.debug("DEBUG")
                    self.toggle_debug()
                elif cmd == self.FILE:
                    log.debug("FILE (%d)", self.FILE)
                    self.recv_file(data)
                elif cmd == self.RUN:
                    log.debug("RUN (%d)", self.RUN)
                    reply = self._run_command(data)
                    log.debug("sending results...")
                    self.send_data(reply)
                elif cmd == self.CODE:
                    log.debug("CODE (%d)", self.CODE)
                    self.send_data(self._run_code(data))
                else:
                    log.debug("UNKNOWN (%s)", data)
                    # %r: the command may not be an int (or may be None).
                    raise RuntimeError("Unknown command: %r" % (cmd,))
        finally:
            self.disconnect()

class Controller(SockPuppetBase):
    """Server side of SockPuppet: waits for a target and drives it remotely.

    Always constructed in server mode; remaining arguments are forwarded to
    SockPuppetBase.
    """

    def __init__(self, *args, **kwargs):
        SockPuppetBase.__init__(self, is_server=True, *args, **kwargs)

    @staticmethod
    def _process_target_except(e_data):
        """Turn an EXCEPT message into a RuntimeError carrying the remote traceback."""
        msg = "Client side exception.\n\n%s%s: %s" % (e_data["tb"], e_data["name"], e_data["msg"])
        return RuntimeError(msg)

    def _recv_reply(self, expected):
        """Receive the target's reply and return it if its command is *expected*.

        Raises:
            RuntimeError: the target reported an exception (EXCEPT), or the
                reply was empty or carried a different command. Checked
                explicitly rather than with assert so it survives ``python -O``.
        """
        data = self.recv_data()
        if data is None:
            raise RuntimeError("Empty reply from target")
        if data["cmd"] == self.EXCEPT:
            log.debug("received exception")
            raise self._process_target_except(data)
        if data["cmd"] != expected:
            raise RuntimeError("Unexpected reply from target: %r" % (data["cmd"],))
        return data

    def run_cmd(self, cmd_to_run, cmd_timeout=120):
        """Run *cmd_to_run* (an argv list) on the target.

        Returns (exit_code, output), where output is the command's combined
        stdout/stderr. *cmd_timeout* is forwarded in the request.
        NOTE(review): the Target in this file does not enforce cmd_timeout.

        Raises:
            RuntimeError: the command could not be run on the target.
        """
        log.debug("run cmd on target: %s", " ".join(cmd_to_run))
        data = {"cmd":self.RUN,
                "cmd_to_run":cmd_to_run,
                "timeout":cmd_timeout}
        self.send_data(data)
        log.debug("waiting for cmd results...")
        data = self._recv_reply(self.RESULT)
        return (data["code"], data["output"])

    def run_code(self, function, *args, **kwargs):
        """Call ``function(*args, **kwargs)`` on the target; return its value.

        The function's code object is sent with ``marshal``, whose format is
        Python-version specific, so both ends should run the same interpreter
        version. The function runs with the target module's globals. The
        arguments and the return value must be picklable.
        NOTE(review): functions with closures probably fail to pickle.

        Raises:
            RuntimeError: the function raised on the target.
        """
        # __name__/__code__ work on Python 2.6+ and 3 (func_name/func_code are 2-only).
        log.debug("%s() args:%s kwargs:%s on target", function.__name__, args, kwargs)
        data = {"cmd":self.CODE,
                "code":marshal.dumps(function.__code__),
                "name":function.__name__,
                "args":args,
                "kwargs":kwargs,
                "defaults":function.__defaults__,
                "closure":function.__closure__}
        self.send_data(data)
        log.debug("waiting for code to execute...")
        return self._recv_reply(self.RETURN)["value"]

    def send_quit(self):
        """Tell the target to stop its command loop."""
        log.debug("sending QUIT")
        self.send_data({"cmd":self.QUIT})

    def debug_client(self):
        """Toggle debug logging on the target."""
        log.debug("sending DEBUG")
        self.send_data({"cmd":self.DEBUG})

def main(argv=None):
    """Command-line entry point (*argv* defaults to ``sys.argv[1:]``).

    ``--mode server`` starts a Controller, waits for a target and tells it to
    quit; ``--mode client`` runs a Target against ``--ip``/``--port``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ip', default="localhost", help='server ip')
    parser.add_argument('--debug', action="store_true", help='enable debug logging')
    # choices: an unknown mode used to fall through both branches silently.
    parser.add_argument('--mode', default="server", choices=("server", "client"),
                        help='mode: server or client')
    parser.add_argument('--port', type=int, default=1701, help='port')
    parser.add_argument('--timeout', type=int, default=0,
                        help='seconds the server waits for a client (0 = forever)')

    args = parser.parse_args(argv)

    if args.mode == "server":
        c = Controller(port=args.port, timeout=args.timeout, debug=args.debug)
        c.connect()
        try:
            c.send_quit()
        finally:
            c.disconnect()
        log.debug("SERVER EXIT")
    else:
        t = Target(ip=args.ip, port=args.port, debug=args.debug)
        t.run()
        log.debug("CLIENT EXIT")


if __name__ == "__main__":
    main()