import logging as log
import shlex
import os
import sys
import subprocess
from subprocess import PIPE, TimeoutExpired

from collections import namedtuple


from . import trailerfilter

# Base ssh invocation. When an ssh key file is configured, `Repo.git` appends
# further ssh options to this string and exports the result as the
# GIT_SSH_COMMAND environment variable for the git subprocess. The trailing
# space is harmless there because the pieces are joined with spaces anyway.
#
# Turning off StrictHostKeyChecking is a nasty hack to approximate
# just accepting the host key sight unseen the first time marge
# connects. The proper solution would be to pass in known_hosts as
# a command-line parameter, but in practice few people will bother anyway and
# in this case the threat of MiTM seems somewhat bogus.
GIT_SSH_COMMAND = "ssh -o StrictHostKeyChecking=no "


def _filter_branch_script(trailer_name, trailer_values):
    """Build the shell command passed to `git filter-branch --msg-filter`.

    The command runs the `trailerfilter` module's file with `python3`, with the
    `TRAILERS` environment variable set to one `<trailer_name>: <value>` line
    per entry of `trailer_values`. When `trailer_values` is empty or `None`, a
    single `<trailer_name>: ` line with an empty value is passed instead.

    Both the trailers and the script path are shell-quoted, so values or an
    install path containing spaces or shell metacharacters cannot break (or
    inject into) the command line.
    """
    trailers = '\n'.join(
        '{}: {}'.format(trailer_name, trailer_value)
        for trailer_value in trailer_values or ['']
    )
    return 'TRAILERS={trailers} python3 {script}'.format(
        trailers=shlex.quote(trailers),
        script=shlex.quote(trailerfilter.__file__),
    )


class Repo(namedtuple('Repo', 'remote_url local_path ssh_key_file timeout reference')):
    """A local git clone of `remote_url`, driven through the `git` command line.

    Fields:
      remote_url: URL that `clone` clones from; it becomes the `origin` remote.
      local_path: directory of the local clone; passed to git via `-C`.
      ssh_key_file: optional path of an ssh identity file. When set, every git
        invocation runs with a `GIT_SSH_COMMAND` restricted to that identity.
      timeout: optional object with a `total_seconds()` method bounding each
        git invocation; `None` means no limit.
      reference: optional value for `git clone --reference=`.
    """

    def clone(self):
        """Clone `remote_url` into `local_path`, naming the remote `origin`."""
        # An empty flag is fine: `git()` drops empty arguments.
        reference_flag = '--reference=' + self.reference if self.reference else ''
        self.git('clone', '--origin=origin', reference_flag, self.remote_url,
                 self.local_path, from_repo=False)

    def config_user_info(self, user_name, user_email):
        """Set `user.email` and `user.name` in the repository's git config."""
        self.git('config', 'user.email', user_email)
        self.git('config', 'user.name', user_name)

    def fetch(self, remote_name, remote_url=None):
        """Fetch (with `--prune`) from `remote_name`.

        For any remote other than `origin`, `remote_url` is required and the
        remote is (re)created to point at it before fetching.
        """
        if remote_name != 'origin':
            assert remote_url is not None
            # Upsert the remote: removing it fails if it does not exist yet,
            # which is expected and deliberately ignored.
            try:
                self.git('remote', 'rm', remote_name)
            except GitError:
                pass
            self.git('remote', 'add', remote_name, remote_url)
        self.git('fetch', '--prune', remote_name)

    def tag_with_trailer(self, trailer_name, trailer_values, branch, start_commit):
        """Replace `trailer_name` in commit messages with `trailer_values` in `branch` from `start_commit`.

        Rewrites the commits in `start_commit..branch` with `git filter-branch`
        and returns the resulting HEAD commit id. If filter-branch fails, the
        branch is reset to the backup ref filter-branch left under
        `refs/original/` (when there is one) and the `GitError` is re-raised.
        """
        # The message filter strips all `$trailer_name:` lines and trailing
        # newlines, adds an empty line and appends one
        # `$trailer_name: $trailer_value` line per value in `trailer_values`.
        filter_script = _filter_branch_script(trailer_name, trailer_values)
        commit_range = start_commit + '..' + branch
        try:
            # --force = overwrite backup of last filter-branch
            self.git('filter-branch', '--force', '--msg-filter', filter_script, commit_range)
        except GitError:
            log.warning('filter-branch failed, will try to restore')
            backup_ref = 'refs/original/refs/heads/' + branch
            try:
                # The backup ref only exists if filter-branch got far enough to
                # rewrite something. It must be probed with the branch name
                # included: the bare `refs/original/refs/heads/` prefix never
                # resolves, which would skip the restore unconditionally.
                self.get_commit_hash(backup_ref)
            except GitError:
                log.warning('No changes have been effected by filter-branch')
            else:
                self.git('reset', '--hard', backup_ref)
            raise
        return self.get_commit_hash()

    def merge(self, source_branch, target_branch, *merge_args, source_repo_url=None, local=False):
        """Merge `target_branch` into `source_branch` and return the new HEAD commit id.

        By default `source_branch` and `target_branch` are assumed to reside in the same
        repo as `self`. However, if `source_repo_url` is passed and not `None`,
        `source_branch` is taken from there.

        Extra `merge_args` are passed on to `git merge`. With `local=True`
        nothing is fetched and the local branches are used as they are.

        Throws a `GitError` if the merge fails. Will also try to --abort it.
        """
        return self._fuse_branch(
            'merge', source_branch, target_branch, *merge_args, source_repo_url=source_repo_url, local=local,
        )

    def fast_forward(self, source, target, source_repo_url=None, local=False):
        """Like `merge`, but only succeeds if it can be done as a fast-forward."""
        return self.merge(source, target, '--ff', '--ff-only', source_repo_url=source_repo_url, local=local)

    def rebase(self, branch, new_base, source_repo_url=None, local=False):
        """Rebase `branch` onto `new_base` and return the new HEAD commit id.

        By default `branch` and `new_base` are assumed to reside in the same
        repo as `self`. However, if `source_repo_url` is passed and not `None`,
        `branch` is taken from there.

        Throws a `GitError` if the rebase fails. Will also try to --abort it.
        """
        return self._fuse_branch('rebase', branch, new_base, source_repo_url=source_repo_url, local=local)

    def _fuse_branch(self, strategy, branch, target_branch, *fuse_args, source_repo_url=None, local=False):
        """Check out `branch` and run `git <strategy> <target> <fuse_args...>` on it.

        `strategy` is the git subcommand (`merge` or `rebase` in this class).
        Unless `local` is true, `origin` (and `source_repo_url` as remote
        `source`, if given) is fetched first, `branch` is reset to its remote
        counterpart and the target is `origin/<target_branch>`.

        Returns the new HEAD commit id. On failure a `git <strategy> --abort`
        is attempted and the original `GitError` is re-raised.
        """
        assert source_repo_url or branch != target_branch, branch

        if not local:
            self.fetch('origin')
            target = 'origin/' + target_branch
            if source_repo_url:
                self.fetch('source', source_repo_url)
                self.checkout_branch(branch, 'source/' + branch)
            else:
                self.checkout_branch(branch, 'origin/' + branch)
        else:
            self.checkout_branch(branch)
            target = target_branch

        try:
            self.git(strategy, target, *fuse_args)
        except GitError:
            log.warning('%s failed, doing an --abort', strategy)
            try:
                self.git(strategy, '--abort')
            except GitError:
                # The abort is best-effort (presumably it fails when nothing
                # is in progress to abort); do not let it replace the error
                # that describes the actual failure.
                log.warning('%s --abort failed as well', strategy)
            raise
        return self.get_commit_hash()

    def remove_branch(self, branch, *, new_current_branch='master'):
        """Delete local `branch` after switching to `new_current_branch`.

        Switching first matters because git refuses to delete the branch that
        is currently checked out.
        """
        assert branch != new_current_branch
        self.git('checkout', new_current_branch, '--')
        self.git('branch', '-D', branch)

    def checkout_branch(self, branch, start_point=''):
        """Create or reset `branch` (at `start_point`, if given) and check it out."""
        self.git('checkout', '-B', branch, start_point, '--')

    def push(self, branch, *, source_repo_url=None, force=False):
        """Push `branch` to the same-named branch on `source` or `origin`.

        The push goes to the `source` remote when `source_repo_url` is given
        (which must match that remote's configured URL) and to `origin`
        otherwise. Raises `GitError` if the work tree has uncommitted changes
        or untracked files, or if the push itself fails.
        """
        self.git('checkout', branch, '--')

        self.git('diff-index', '--quiet', 'HEAD')  # check it is not dirty

        untracked_files = self.git('ls-files', '--others').stdout  # check no untracked files
        if untracked_files:
            raise GitError('There are untracked files', untracked_files)

        if source_repo_url:
            assert self.get_remote_url('source') == source_repo_url
            source = 'source'
        else:
            source = 'origin'
        force_flag = '--force' if force else ''
        self.git('push', force_flag, source, '%s:%s' % (branch, branch))

    def get_commit_hash(self, rev='HEAD'):
        """Return commit hash for `rev` (default "HEAD")."""
        result = self.git('rev-parse', rev)
        return result.stdout.decode('ascii').strip()

    def get_remote_url(self, name):
        """Return the URL configured for the remote called `name`."""
        return self.git('config', '--get', 'remote.{}.url'.format(name)).stdout.decode('utf-8').strip()

    def git(self, *args, from_repo=True):
        """Run `git` with `args` and return the completed process.

        Arguments whose string form is empty are dropped, so callers can pass
        optional flags as `''`. Unless `from_repo` is false, git is run inside
        `local_path` (via `-C`). The call is bounded by `timeout`, if set.

        Raises `GitError` (chained to the underlying `CalledProcessError`) if
        git exits with a non-zero status.
        """
        env = None
        if self.ssh_key_file:
            env = os.environ.copy()
            # ssh's handling of identity files is infuriatingly dumb, to get it
            # to actually really use the IdentityFile we pass in via -i we also
            # need to tell it to ignore ssh-agent (IdentitiesOnly=true) and not
            # read in any identities from ~/.ssh/config etc (-F /dev/null),
            # because they append and it tries them in order, starting with config file
            env['GIT_SSH_COMMAND'] = " ".join([
                GIT_SSH_COMMAND,
                "-F", "/dev/null",
                "-o", "IdentitiesOnly=yes",
                # GIT_SSH_COMMAND is interpreted by the shell, so the path has
                # to be quoted to survive spaces and metacharacters.
                "-i", shlex.quote(self.ssh_key_file),
            ])

        command = ['git']
        if from_repo:
            command.extend(['-C', self.local_path])
        command.extend(arg for arg in args if str(arg))

        log.info('Running %s', ' '.join(shlex.quote(w) for w in command))
        timeout_seconds = self.timeout.total_seconds() if self.timeout is not None else None
        try:
            return _run(*command, env=env, check=True, timeout=timeout_seconds)
        except subprocess.CalledProcessError as err:
            log.warning('git returned %s', err.returncode)
            log.warning('stdout: %r', err.stdout)
            log.warning('stderr: %r', err.stderr)
            raise GitError(err) from err


def _run(*args, env=None, check=False, timeout=None):
    """Run the command `args` and capture its output.

    Args:
      *args: the program and its arguments, as strings. They are encoded as
        UTF-8 before being handed to the OS, except on Windows.
      env: environment for the child process; `None` inherits the current one.
      check: if true, a non-zero exit status raises `CalledProcessError`.
      timeout: seconds to wait for the process, or `None` to wait forever.

    Returns:
      A `subprocess.CompletedProcess` whose `stdout`/`stderr` are bytes.

    Raises:
      subprocess.TimeoutExpired: if the process does not finish in time; it is
        killed first and whatever output it produced is attached.
      subprocess.CalledProcessError: if `check` is true and the exit status is
        non-zero.
    """
    encoded_args = [a.encode('utf-8') for a in args] if sys.platform != 'win32' else args
    with subprocess.Popen(encoded_args, env=env, stdout=PIPE, stderr=PIPE) as process:
        try:
            # No stdin pipe is opened, so there is nothing to feed the child.
            stdout, stderr = process.communicate(timeout=timeout)
        except TimeoutExpired:
            process.kill()
            # Collect whatever was written before the kill.
            stdout, stderr = process.communicate()
            raise TimeoutExpired(
                process.args, timeout, output=stdout, stderr=stderr,
            )
        except Exception:
            # Do not leave a stray child behind on unexpected errors.
            process.kill()
            process.wait()
            raise
        retcode = process.poll()
        if check and retcode:
            raise subprocess.CalledProcessError(
                retcode, process.args, output=stdout, stderr=stderr,
            )
        return subprocess.CompletedProcess(process.args, retcode, stdout, stderr)


class GitError(Exception):
    """Raised when a git operation fails."""