#!/usr/bin/env python

# run_echo.py
#   A script to measure remote command execution speed with ansible
#
# usage:
#   python2 run_ssh_cmd.py
#
# examples:
#   python2 run_ssh_cmd.py \
#    --debug \
#    --use_plugin \
#    --username=vagrant \
#    --hostname=vagrant \
#    --iterations=100 \
#    --command="umask 77 && mkdir -p ~/.ansible_test && rm -rf ~/.ansible_test"

import fcntl
import json
import mock
import os
import re
import subprocess
import time

from datetime import datetime
from optparse import OptionParser

from ansible.compat import selectors
from ansible.errors import AnsibleError
from ansible.plugins.loader import connection_loader


# logzero is an optional dependency: record whether the import succeeded so
# the script can fall back to a print-based logger when it is missing.
HAS_LOGZERO = False
try:
    from logzero import logger
    HAS_LOGZERO = True
except ImportError:
    # Not installed; HAS_LOGZERO stays False and no `logger` name is bound yet.
    pass


class MockLogger(object):
    '''Print-based stand-in for a logger and for ansible's display object.

    The class is used in two ways in this script:

    * as an *instance*, replacing the logzero ``logger`` (``setLevel``,
      ``info``, ``error``, ``debug``);
    * as the *class itself*, patched in place of
      ``ansible.plugins.connection.ssh.display``, which is why the
      verbosity helpers are static and accept a ``host`` keyword.

    ``debug`` has to serve both uses. On the class it is a static method
    that always prints; on an instance it is replaced by a bound method that
    honours the level set through ``setLevel``.
    '''

    # Default level; debug messages are printed until setLevel() changes it.
    level = 'DEBUG'

    def __init__(self):
        # Shadow the static ``debug`` so instance calls respect ``self.level``.
        # Previously a second ``debug`` definition silently overrode the
        # level-aware one, so setLevel('INFO') had no effect.
        self.debug = self._leveled_debug

    def setLevel(self, level):
        '''Store the level name; only 'DEBUG' enables instance debug output.'''
        self.level = level

    def info(self, msg):
        '''Print msg unconditionally.'''
        print(msg)

    def error(self, msg):
        '''Print msg unconditionally.'''
        print(msg)

    def _leveled_debug(self, msg, host=None):
        '''Print msg only while this instance's level is 'DEBUG'.'''
        if self.level == 'DEBUG':
            print(msg)

    @staticmethod
    def debug(msg, host=None):
        '''Class-level debug (display replacement): always print msg.'''
        print(msg)

    @staticmethod
    def v(msg, host=None):
        '''Print msg; host is accepted and ignored.'''
        print(msg)

    @staticmethod
    def vv(msg, host=None):
        '''Print msg; host is accepted and ignored.'''
        print(msg)

    @staticmethod
    def vvv(msg, host=None):
        '''Print msg; host is accepted and ignored.'''
        print(msg)

    @staticmethod
    def vvvv(msg, host=None):
        '''Print msg; host is accepted and ignored.'''
        print(msg)

    @staticmethod
    def vvvvv(msg, host=None):
        '''Print msg; host is accepted and ignored.'''
        print(msg)


class MockPlayContext(object):
    '''Bare attribute holder passed to the ssh connection plugin as its
    play context; every value is a class-level default that callers may
    override on the instance.'''

    # plugin identity
    name = 'ssh'
    _load_name = 'ssh'

    # local ssh client and remote shell
    ssh_executable = 'ssh'
    ssh_args = None
    executable = '/bin/sh'
    shell = 'sh'

    # connection target and credentials
    remote_user = 'vagrant'
    port = 22
    password = None
    private_key_file = None

    # behaviour knobs
    timeout = 10
    verbosity = 5
    become = False
    prompt = None


# Without logzero, bind `logger` to a MockLogger instance so the rest of the
# script can call logger.info/debug/error/setLevel either way.
if not HAS_LOGZERO:
    print('PLEASE INSTALL LOGZERO FOR BEST EXPERIENCE')
    logger = MockLogger()


# Template ssh invocation that the helpers below rewrite in place.
# Layout the helpers rely on: the last element is the remote command and the
# one before it is the target host; every option value follows its own "-o".
SSHCMD = [
    '/usr/bin/ssh',
    '-vvvvvv',
    '-C',
    '-o', 'ControlMaster=auto',
    '-o', 'ControlPersist=60s',
    '-o', 'IdentityFile="~/.ssh/id_rsa"',
    '-o', 'KbdInteractiveAuthentication=no',
    '-o', 'PreferredAuthentications=gssapi-with-mic,gssapi-keyex,hostbased,publickey',
    '-o', 'PasswordAuthentication=no',
    '-o', 'User=vagrant',
    '-o', 'ConnectTimeout=10',
    '-o', 'ControlPath=~/.ansible/cp/testcp',
    'el6host',
    "/bin/sh -c 'echo ~vagrant && sleep 0'",
]


def validate_control_socket(SSHCMD):
    '''Check any existing ControlPath socket with ``ssh -O check``.

    For every ``ControlPath=...`` option in SSHCMD the socket path is
    expanded (``~`` included) and, if it exists on disk, the command is
    re-run without its trailing remote command and with ``-O check``
    inserted before the host, e.g.::

        $ ssh -O check -o ControlPath=... vagrant@el6host
        Master running (pid=24779)

    The outcome is only logged; nothing is returned and SSHCMD is not
    modified.
    '''
    for arg in SSHCMD:
        if not arg.startswith('ControlPath='):
            continue

        # The template uses "~/..."; os.path.exists does not expand "~",
        # so expand it here or the socket is never found.
        cppath = os.path.expanduser(arg.split('=', 1)[1])
        if not os.path.exists(cppath):
            logger.info('%s does not exist' % cppath)
            continue

        # Drop the remote command, then slot "-O check" in front of the host.
        checkcmd = SSHCMD[:-1]
        checkcmd.insert(-1, '-O')
        checkcmd.insert(-1, 'check')
        print('# %s' % ' '.join(checkcmd))

        # run_ssh_cmd joins its argument itself, so pass the list; passing
        # an already-joined string would be split character by character.
        (rc, so, se) = run_ssh_cmd(checkcmd, use_selectors=False)
        logger.debug('rc: %s' % rc)
        logger.debug('so: %s' % so)
        logger.debug('se: %s' % se)

        if rc != 0 or so.strip():
            logger.info('checkcmd rc != 0 or has stdout')
            logger.info(so)
            logger.info(se)


def set_vcount(SSHCMD, count=None):
    '''Set the ssh verbosity flag in SSHCMD to ``count`` v's.

    SSHCMD is modified in place and also returned.

    * ``count is None``: leave the command untouched.
    * ``count <= 0``: remove every ``-v...`` argument (a bare ``-`` is not a
      valid ssh flag).
    * otherwise: rewrite every existing ``-v...`` argument, or insert one
      right after the executable if there is none.
    '''
    if count is None:
        return SSHCMD

    if count <= 0:
        # Slice assignment keeps the caller's list object.
        SSHCMD[:] = [arg for arg in SSHCMD if not arg.startswith('-v')]
        return SSHCMD

    flag = '-' + 'v' * count
    isset = False
    for idx, arg in enumerate(SSHCMD):
        if arg.startswith('-v'):
            SSHCMD[idx] = flag
            isset = True

    if not isset:
        SSHCMD.insert(1, flag)

    return SSHCMD


def set_hostname(SSHCMD, hostname):
    '''Overwrite the host argument of SSHCMD in place and return the list.

    The host is taken to be the second-to-last element, directly before the
    remote command.
    '''
    host_slot = -2
    SSHCMD[host_slot] = hostname
    return SSHCMD


def set_username(SSHCMD, username):
    '''Point SSHCMD at ``username``, in place, and return the list.

    Two kinds of element are rewritten:

    * a ``User=...`` option becomes ``User=<username>``;
    * an element containing ``echo ~`` has its first ``~name`` token
      replaced with ``~<username>``. An element with a bare ``~`` (no name
      after it) is left as is instead of raising.
    '''
    for idx, arg in enumerate(SSHCMD):
        if arg.startswith('User='):
            SSHCMD[idx] = 'User=%s' % username
        if 'echo ~' in arg:
            found = re.search(r'~\w+', arg)
            # No "~name" token at all, e.g. "echo ~": nothing to substitute.
            if found is not None:
                SSHCMD[idx] = arg.replace(found.group(), '~%s' % username, 1)
    return SSHCMD


def set_keyfile(SSHCMD, keyfile):
    '''Replace the first IdentityFile option in SSHCMD with ``keyfile``.

    The new value is written as ``IdentityFile="<keyfile>"`` (double quotes
    included). SSHCMD is changed in place and returned; a command without an
    IdentityFile option comes back untouched.
    '''
    position = next(
        (i for i, arg in enumerate(SSHCMD) if arg.startswith('IdentityFile')),
        None,
    )
    if position is not None:
        SSHCMD[position] = 'IdentityFile="%s"' % keyfile
    return SSHCMD


def remove_control_persist(SSHCMD):
    '''Strip every ``Control*`` option (ControlMaster, ControlPersist,
    ControlPath, ...) from SSHCMD, in place, and return the list.

    Each matching element is removed together with the ``-o`` that
    introduces it. The preceding element is only dropped when it really is
    ``-o``, so an option at index 0 or one without its ``-o`` no longer takes
    an unrelated argument with it. Progress is printed after every removal.
    '''
    while True:
        idx = next(
            (i for i, arg in enumerate(SSHCMD) if arg.startswith('Control')),
            None,
        )
        if idx is None:
            break

        print('popping %s' % SSHCMD[idx])
        SSHCMD.pop(idx)
        # Guard: idx - 1 == -1 would otherwise pop the *last* element.
        if idx > 0 and SSHCMD[idx - 1] == '-o':
            SSHCMD.pop(idx - 1)
        print(' '.join(SSHCMD))

    return SSHCMD


def extract_speeed_from_stdtout(so):
    '''Pull ssh's transfer statistics out of its verbose output.

    Looks for the two summary lines ssh prints, for example::

        Transferred: sent 3192, received 2816 bytes, in 1.6 seconds
        Bytes per second: sent 1960.0, received 1729.1

    Returns a dict with up to two entries, each a dict of floats:
    ``'transfered'`` (``sent``, ``received``, ``duration``) and ``'speeds'``
    (``sent``, ``received``). A summary line that is absent, or that lacks
    one of its numbers, contributes no entry; the result is ``{}`` when
    nothing is found.
    '''
    # Integer or decimal; the decimal part matters for the per-second rates.
    number = r'(\d+(?:\.\d+)?)'

    def _value(label, line):
        # Float following "<label> " in line, or None if it is not there.
        found = re.search(r'\b%s %s' % (label, number), line)
        return float(found.group(1)) if found else None

    data = {}
    for line in so.split('\n'):
        if 'Transferred' in line:
            # NOTE: the misspelled key is what callers print; keep it stable.
            key = 'transfered'
            fields = ('sent', 'received', 'duration')
            values = (
                _value('sent', line),
                _value('received', line),
                _value('in', line),
            )
        elif 'Bytes per second' in line:
            key = 'speeds'
            fields = ('sent', 'received')
            values = (_value('sent', line), _value('received', line))
        else:
            continue

        # Skip a malformed summary line rather than crash on it.
        if any(value is None for value in values):
            continue
        data[key] = dict(zip(fields, values))

    return data


def run_ssh_exec(command=None, hostname=None, username=None, keyfile=None):

    '''Execute ``command`` through ansible's ssh connection plugin.

    A MockPlayContext is built and, for each truthy argument, the matching
    attribute is overridden (hostname -> remote_addr, username ->
    remote_user, keyfile -> private_key_file). While the plugin runs, its
    module-level ``display`` is patched with MockLogger.

    Returns (rc, stdout, stderr) with both streams decoded as UTF-8.
    '''

    overrides = (
        ('remote_addr', hostname),
        ('remote_user', username),
        ('private_key_file', keyfile),
    )

    with mock.patch('ansible.plugins.connection.ssh.display', MockLogger):
        play_context = MockPlayContext()
        for attribute, value in overrides:
            if value:
                setattr(play_context, attribute, value)

        connection = connection_loader.get('ssh', play_context, None)
        (rc, b_out, b_err) = connection.exec_command(command)

    return (rc, b_out.decode('utf-8'), b_err.decode('utf-8'))


def run_ssh_cmd(SSHCMD, command=None, hostname=None, username=None, use_selectors=False):

    '''Run the ssh command in a subprocess and collect its result.

    SSHCMD is joined with single spaces and handed to a shell, so it is
    expected to be a list of argument strings (a plain string would be
    joined character by character). The command, hostname and username
    parameters are accepted but not read anywhere in this body.

    With use_selectors false, output is gathered with communicate().
    Otherwise both pipes are switched to non-blocking mode and drained
    through a selector loop.

    Returns a (returncode, stdout, stderr) tuple with both streams decoded
    as UTF-8. The selector path can raise AnsibleError('timeout').
    '''

    if not use_selectors:
        # Simple path: let communicate() read both pipes to EOF.
        p = subprocess.Popen(
            ' '.join(SSHCMD),
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        (so, se) = p.communicate()
        return (p.returncode, so.decode('utf-8'), se.decode('utf-8'))

    else:

        # This is kinda how ansible runs ssh commands ...

        logger.info('using selectors ...')
        p = subprocess.Popen(
            ' '.join(SSHCMD),
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        # Make both pipes non-blocking so read() returns whatever is
        # currently available instead of waiting for EOF.
        for fd in (p.stdout, p.stderr):
            fcntl.fcntl(
                fd,
                fcntl.F_SETFL,
                fcntl.fcntl(
                    fd,
                    fcntl.F_GETFL
                ) | os.O_NONBLOCK
            )

        states = [
            'awaiting_prompt',
            'awaiting_escalation',
            'ready_to_send',
            'awaiting_exit'
        ]
        # Start one past 'ready_to_send', i.e. at 'awaiting_exit'. Since the
        # state only ever increases, the prompt/escalation/ready_to_send
        # branches further down are never taken in this script.
        state = states.index('ready_to_send')
        state += 1

        selector = selectors.DefaultSelector()
        selector.register(p.stdout, selectors.EVENT_READ)
        selector.register(p.stderr, selectors.EVENT_READ)

        # select() timeout in seconds; 0 means poll without waiting.
        timeout = 0
        events = None

        # b_stdout/b_stderr accumulate everything read; the b_tmp_* buffers
        # hold the chunks read during the current loop iteration.
        b_stdout = b_stderr = b''
        b_tmp_stdout = b_tmp_stderr = b''

        try:

            counter = 0
            while True:
                counter += 1

                # Fixed 2 second pause before the first poll only.
                if counter == 1:
                    time.sleep(2)

                poll = p.poll()
                events = selector.select(timeout)

                if not events:
                    # Only reachable while state is 'awaiting_prompt' or
                    # 'awaiting_escalation'.
                    if state <= states.index('awaiting_escalation'):
                        if poll is not None:
                            break
                        p.terminate()
                        raise AnsibleError('timeout')

                for key, event in events:

                    if key.fileobj == p.stdout:
                        b_chunk = p.stdout.read()
                        logger.debug('b_chunk %s' % b_chunk)
                        if b_chunk == b'':
                            # Empty read: stop watching stdout and let the
                            # next select() wait up to 1 second.
                            selector.unregister(p.stdout)
                            timeout = 1
                        b_tmp_stdout += b_chunk
                    elif key.fileobj == p.stderr:
                        b_chunk = p.stderr.read()
                        logger.debug('b_chunk %s' % b_chunk)
                        if b_chunk == b'':
                            # Empty read: stop watching stderr.
                            selector.unregister(p.stderr)
                        b_tmp_stderr += b_chunk

                if state < states.index('ready_to_send'):
                    # Temp buffers are appended but not cleared here.
                    if b_tmp_stdout:
                        b_stdout += b_tmp_stdout
                    if b_tmp_stderr:
                        b_stderr += b_tmp_stderr
                else:
                    # Move this iteration's chunks into the totals and reset.
                    b_stdout += b_tmp_stdout
                    b_stderr += b_tmp_stderr
                    b_tmp_stdout = b_tmp_stderr = b''

                # Each of these steps just advances to the next state.
                if states[state] == 'awaiting_prompt':
                    state += 1

                if states[state] == 'awaiting_escalation':
                    state += 1

                if states[state] == 'ready_to_send':
                    state += 1

                if poll is not None:
                    # Process had already exited when polled: finish once
                    # nothing is registered or nothing was readable,
                    # otherwise keep draining without waiting.
                    if not selector.get_map() or not events:
                        break
                    timeout = 0
                    continue

                elif not selector.get_map():
                    # Both pipes unregistered but the process was still
                    # running at poll time: wait for it to exit.
                    p.wait()
                    break

                logger.debug(counter)
                logger.debug(state)
                logger.debug(states[state])
                logger.debug(poll)
                logger.debug(selector.get_map())
                logger.debug(events)

        finally:
            selector.close()

        return (
            p.returncode,
            b_stdout.decode('utf-8'),
            b_stderr.decode('utf-8')
        )


##########################################
#   MAIN
##########################################

def main():

    '''Parse the command line, run the ssh command N times, report timings.

    Without --use_plugin the module-level SSHCMD template is checked for a
    live control socket, adjusted according to the options (control options
    removed unless --controlpersist, host, user, key file, verbosity and
    remote command replaced) and executed with run_ssh_cmd. With
    --use_plugin the command goes through run_ssh_exec instead.

    Each iteration logs the transfer statistics parsed from stderr, the
    return code and stdout; afterwards the per-iteration durations and their
    min/max/average are logged. Exits through parser.error() when
    --iterations is below 1, because there would be nothing to summarise.
    '''

    global SSHCMD

    parser = OptionParser()
    parser.add_option('--iterations', type=int, default=10)
    parser.add_option('--controlpersist', action='store_true')
    parser.add_option('--selectors', action='store_true')
    parser.add_option('--use_plugin', action='store_true')
    parser.add_option('--vcount', type=int, default=None)
    parser.add_option('--debug', action='store_true')
    parser.add_option('--hostname', default=None)
    parser.add_option('--username', default=None)
    parser.add_option('--keyfile', default=None)
    parser.add_option('--command', default=None)
    (options, args) = parser.parse_args()

    # min()/max()/average below need at least one sample.
    if options.iterations < 1:
        parser.error('--iterations must be at least 1')

    if not options.debug:
        logger.setLevel('INFO')

    # munge the example ssh command if not using the connection plugin
    if not options.use_plugin:
        validate_control_socket(SSHCMD)
        if not options.controlpersist:
            SSHCMD = remove_control_persist(SSHCMD)

        if options.hostname:
            SSHCMD = set_hostname(SSHCMD, options.hostname)

        if options.username:
            SSHCMD = set_username(SSHCMD, options.username)

        if options.keyfile:
            SSHCMD = set_keyfile(SSHCMD, options.keyfile)

        if options.vcount is not None:
            SSHCMD = set_vcount(SSHCMD, count=options.vcount)

        if options.command is not None:
            SSHCMD[-1] = '/bin/sh -c "%s"' % options.command

        logger.info(SSHCMD)

    # run the command X times and record the durations + speeds
    durations = []
    for x in range(0, options.iterations):
        logger.info('iteration %s' % x)
        start = datetime.now()
        if options.use_plugin:
            (rc, so, se) = run_ssh_exec(
                command=options.command,
                hostname=options.hostname,
                username=options.username,
                keyfile=options.keyfile,
            )
        else:
            (rc, so, se) = run_ssh_cmd(
                SSHCMD,
                hostname=options.hostname,
                username=options.username,
                use_selectors=options.selectors
            )
        stop = datetime.now()
        durations.append(stop - start)

        # ssh writes its transfer summary to stderr.
        stats = extract_speeed_from_stdtout(se)
        logger.info('transfer stats ...')
        for k, v in stats.items():
            for k2, v2 in v.items():
                logger.info('%s.%s = %s' % (k, k2, v2))
        logger.info('rc: %s' % rc)
        logger.info('so:%s' % so.strip())
        if rc != 0:
            logger.error(se)
            # SSHCMD is only what actually ran on the non-plugin path;
            # logging the untouched template in plugin mode would mislead.
            if not options.use_plugin:
                logger.error('sshcmd: %s' % ' '.join(SSHCMD))

    durations = [x.total_seconds() for x in durations]
    logger.info('durations ...')
    for idx, x in enumerate(durations):
        logger.info('%s. %s' % (idx, x))
    logger.info('duration min: %s' % min(durations))
    logger.info('duration max: %s' % max(durations))
    avg = sum(durations) / float(len(durations))
    logger.info('duration avg: %s' % avg)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()