# -*- coding: utf-8 -*-
__author__ = 'fyrestone@outlook.com'
__version__ = '1.3'

import sys
import ast
import difflib
import operator
import argparse
import itertools
from collections import Counter

# avoid using six to keep dependency clean
if sys.version_info >= (3, 3):
    import collections.abc as collections
else:
    import collections

# Type(s) treated as "a single text string" elsewhere in this module;
# `basestring` only exists on Python 2 (it covers both str and unicode there).
if sys.version_info[0] == 3:
    string_types = str
else:
    string_types = basestring


class FuncNodeCollector(ast.NodeTransformer):
    """
    Clean node attributes, delete the attributes that are not helpful for recognition repetition.
    Then collect all function nodes.

    The tree is mutated in place. The following are erased so that mostly the
    structural shape of the code remains:
      - identifier names, argument names and annotations, attribute names
      - expression contexts and string literal values
      - docstring-like string expressions
      - ``print`` calls/statements and imports

    Every ``ast.FunctionDef`` visited is recorded with a qualified
    ``Class.method`` name and two extra attributes:
      - ``endlineno``: the largest line number seen up to the end of its visit
      - ``nsubnodes``: the number of ``generic_visit`` calls made while
        visiting it, including the def itself
    """

    # Replacement for string constant values on Python 3.8+.  Deleting
    # ``node.value`` would make the legacy ``node.s`` / ``node.n`` aliases
    # raise, so the value is overwritten instead.
    _DUMMY_STR_VALUE = '__pycode_similar_dummy_value__'

    def __init__(self):
        super(FuncNodeCollector, self).__init__()
        self._curr_class_names = []
        self._func_nodes = []
        self._last_node_lineno = -1
        self._node_count = 0

    @staticmethod
    def _is_str_node(node):
        """
        Return True if ``node`` is a string literal node.

        Python < 3.8 parses string literals as ``ast.Str``; Python 3.8+ uses
        ``ast.Constant`` with a ``str`` value.  ``ast.Str`` is deprecated since
        3.8 and removed in 3.14, so it is only looked up when needed.
        """
        constant_type = getattr(ast, 'Constant', None)
        if constant_type is not None and isinstance(node, constant_type):
            return isinstance(node.value, str)
        str_type = getattr(ast, 'Str', None)
        return str_type is not None and isinstance(node, str_type)

    @staticmethod
    def _mark_docstring_sub_nodes(node):
        """
        Inspired by ast.get_docstring, mark all docstring sub nodes.

        Case1:
        regular docstring of function/class/module

        Case2:
        def foo(self):
            '''pure string expression'''
            for x in self.contents:
                '''pure string expression'''
                print x
            if self.abc:
                '''pure string expression'''
                pass

        Case3:
        def foo(self):
            if self.abc:
                print('ok')
            else:
                '''pure string expression'''
                pass

        Case4:
        def foo(self):
            try:
                pass
            finally:
                '''pure string expression'''
                pass

        :param node: every ast node
        :return:
        """
        for field in ('body', 'orelse', 'finalbody'):
            statements = getattr(node, field, None)
            # Only statement lists qualify; e.g. IfExp/Lambda have a single
            # expression node in ``body``/``orelse``.
            if statements and isinstance(statements, collections.Sequence):
                for n in statements:
                    if isinstance(n, ast.Expr) and FuncNodeCollector._is_str_node(n.value):
                        n.is_docstring = True

    @staticmethod
    def _is_docstring(node):
        return getattr(node, 'is_docstring', False)

    def generic_visit(self, node):
        self._node_count = self._node_count + 1
        self._last_node_lineno = max(getattr(node, 'lineno', -1), self._last_node_lineno)
        # Children are marked before they are visited, so visit_Expr can
        # recognise docstring-like expressions.
        self._mark_docstring_sub_nodes(node)
        return super(FuncNodeCollector, self).generic_visit(node)

    def visit_Constant(self, node):
        # Python 3.8+: every literal is an ast.Constant.  Erase string values
        # but keep the node itself; returning None here would make
        # NodeTransformer drop the literal from the tree altogether.
        if isinstance(node.value, str):
            node.value = self._DUMMY_STR_VALUE
        self.generic_visit(node)
        return node

    def visit_Str(self, node):
        # Python < 3.8: string literals are ast.Str; drop the literal text.
        del node.s
        self.generic_visit(node)
        return node

    def visit_Expr(self, node):
        # Docstring-like expressions are removed (None is returned).  Other
        # expression statements are kept unless their value was removed
        # while visiting, e.g. a print() call.
        if not self._is_docstring(node):
            self.generic_visit(node)
            if hasattr(node, 'value'):
                return node
        return None

    def visit_arg(self, node):
        """
        remove arg name & annotation for python3
        :param node: ast.arg
        :return:
        """
        del node.arg
        del node.annotation
        self.generic_visit(node)
        return node

    def visit_Name(self, node):
        del node.id
        del node.ctx
        self.generic_visit(node)
        return node

    def visit_Attribute(self, node):
        del node.attr
        del node.ctx
        self.generic_visit(node)
        return node

    def visit_Call(self, node):
        func = getattr(node, 'func', None)
        if func and isinstance(func, ast.Name) and func.id == 'print':
            return None  # remove print call and its sub nodes for python3
        # Visit the callee and arguments too, so names used inside calls are
        # anonymised and counted like everywhere else.
        self.generic_visit(node)
        return node

    def visit_ClassDef(self, node):
        self._curr_class_names.append(node.name)
        self.generic_visit(node)
        self._curr_class_names.pop()
        return node

    def visit_FunctionDef(self, node):
        node.name = '.'.join(itertools.chain(self._curr_class_names, [node.name]))
        self._func_nodes.append(node)
        count = self._node_count
        self.generic_visit(node)
        node.endlineno = self._last_node_lineno
        node.nsubnodes = self._node_count - count
        return node

    def visit_Compare(self, node):
        """
        Canonicalise single-operator comparisons whose operands have different
        node types.  Operands are ordered by node type name and the operator is
        mirrored when they are swapped, so e.g. ``x > 1`` and ``1 < x`` end up
        with the same shape.
        """

        def _simple_normalize(*ops_type_names):
            # Swap the operands of a one-operator comparison whose operator
            # type name is in ``ops_type_names``; return True if swapped.
            if node.ops and len(node.ops) == 1 and type(node.ops[0]).__name__ in ops_type_names:
                if node.left and node.comparators and len(node.comparators) == 1:
                    left, right = node.left, node.comparators[0]
                    if type(left).__name__ > type(right).__name__:
                        node.left = right
                        node.comparators = [left]
                        return True
            return False

        # Symmetric operators: swapping the operands needs no operator change.
        _simple_normalize('Eq', 'NotEq', 'Is', 'IsNot')

        if _simple_normalize('Gt', 'Lt'):
            node.ops = [{ast.Lt: ast.Gt, ast.Gt: ast.Lt}[type(node.ops[0])]()]

        if _simple_normalize('GtE', 'LtE'):
            node.ops = [{ast.LtE: ast.GtE, ast.GtE: ast.LtE}[type(node.ops[0])]()]

        self.generic_visit(node)
        return node

    def visit_Print(self, node):
        # remove print statement for python2
        return None

    def visit_Import(self, node):
        # remove import ...
        return None

    def visit_ImportFrom(self, node):
        # remove from ... import ...
        return None

    def clear(self):
        self._func_nodes = []

    def get_function_nodes(self):
        return self._func_nodes


class FuncInfo(object):
    """
    Lazily computed views (source lines, pretty-printed AST) of one function node.

    ``func_node`` must be an ``ast.FunctionDef``, typically one produced by
    ``FuncNodeCollector``.  Its ``name`` attribute is popped into ``func_name``
    so that the name does not take part in the AST dump.

    The ``_iter_node`` and ``_dump`` helpers are adapted from the astor
    library for Python AST manipulation.

        License: 3-clause BSD
        Copyright 2012 (c) Patrick Maupin
        Copyright 2013 (c) Berker Peksag
    """

    class NonExistent(object):
        # Sentinel for "field not present on the node".
        pass

    def __init__(self, func_node, code_lines):
        assert isinstance(func_node, ast.FunctionDef)
        self._func_node = func_node
        self._code_lines = code_lines
        # Pop (not just read) the name so that _dump does not include it.
        self._func_name = func_node.__dict__.pop('name', '')
        self._func_code = None
        self._func_code_lines = None
        self._func_ast = None
        self._func_ast_lines = None

    def __str__(self):
        return '<' + type(self).__name__ + ': ' + self.func_name + '>'

    @property
    def func_name(self):
        return self._func_name

    @property
    def func_node(self):
        return self._func_node

    @property
    def func_code(self):
        """Source text of the function (same type as the input lines: str or bytes)."""
        if self._func_code is None:
            lines = self.func_code_lines
            # The source may be bytes (main() reads files in binary mode), so
            # join with an empty value of the same type rather than ''.
            self._func_code = lines[0][:0].join(lines) if lines else ''
        return self._func_code

    @property
    def func_code_lines(self):
        if self._func_code_lines is None:
            self._func_code_lines = self._retrieve_func_code_lines(self._func_node, self._code_lines)
        return self._func_code_lines

    @property
    def func_ast(self):
        if self._func_ast is None:
            self._func_ast = self._dump(self._func_node)
        return self._func_ast

    @property
    def func_ast_lines(self):
        if self._func_ast_lines is None:
            self._func_ast_lines = self.func_ast.splitlines(True)
        return self._func_ast_lines

    @staticmethod
    def _retrieve_func_code_lines(func_node, code_lines):
        """
        Return the source lines spanning ``func_node``.

        The lines are stripped of the first line's leading indentation when
        every line starts with it; otherwise they are returned unchanged.
        Returns [] when the node, the lines, or the line range is unusable.
        """
        if not isinstance(func_node, ast.FunctionDef):
            return []
        if not isinstance(code_lines, collections.Sequence) or isinstance(code_lines, string_types):
            return []
        start_lineno = getattr(func_node, 'lineno', None)
        # Prefer the end line recorded by FuncNodeCollector; fall back to the
        # interpreter-provided end_lineno (Python 3.8+) for other nodes.
        end_lineno = getattr(func_node, 'endlineno', None)
        if end_lineno is None:
            end_lineno = getattr(func_node, 'end_lineno', None)
        if start_lineno is None or end_lineno is None or end_lineno < start_lineno:
            return []
        lines = code_lines[start_lineno - 1: end_lineno]
        if lines:
            first_line = lines[0]
            padding = first_line[:-len(first_line.lstrip())]
            stripped_lines = []
            for line in lines:
                if line.startswith(padding):
                    stripped_lines.append(line[len(padding):])
                else:
                    stripped_lines = []
                    break
            if stripped_lines:
                return stripped_lines
        return lines

    @staticmethod
    def _iter_node(node, name='', missing=NonExistent):
        """Iterates over an object:

           - If the object has a _fields attribute,
             it gets attributes in the order of this
             and returns value, name pairs.

           - Otherwise, if the object is a list instance,
             it returns value, name pairs for each item
             in the list, where the name is passed into
             this function (defaults to blank).

        """
        fields = getattr(node, '_fields', None)
        if fields is not None:
            for field in fields:
                value = getattr(node, field, missing)
                if value is not missing:
                    yield value, field
        elif isinstance(node, list):
            for value in node:
                yield value, name

    @staticmethod
    def _dump(node, name=None, initial_indent='', indentation='    ',
              maxline=120, maxmerged=80, special=ast.AST):
        """Dumps an AST or similar structure:

           - Pretty-prints with indentation
           - Doesn't print line/column/ctx info

        """

        def _inner_dump(node, name=None, indent=''):
            level = indent + indentation
            name = name + '=' if name else ''
            values = list(FuncInfo._iter_node(node))
            if isinstance(node, list):
                prefix, suffix = '%s[' % name, ']'
            elif values:
                prefix, suffix = '%s%s(' % (name, type(node).__name__), ')'
            elif isinstance(node, special):
                prefix, suffix = name + type(node).__name__, ''
            else:
                return '%s%s' % (name, repr(node))
            # 'ctx' (Load/Store/Del) carries no structural information.
            node = [_inner_dump(a, b, level) for a, b in values if b != 'ctx']
            oneline = '%s%s%s' % (prefix, ', '.join(node), suffix)
            if len(oneline) + len(indent) < maxline:
                return '%s' % oneline
            if node and len(prefix) + len(node[0]) < maxmerged:
                prefix = '%s%s,' % (prefix, node.pop(0))
            node = (',\n%s' % level).join(node).lstrip()
            return '%s\n%s%s%s' % (prefix, level, node, suffix)

        return _inner_dump(node, name, initial_indent)


class ArgParser(argparse.ArgumentParser):
    """
    ArgumentParser variant that shows the full help text before reporting
    a command-line error.
    """

    def error(self, message):
        """
        Print the help text, then call ``self.exit`` with status 2 and a
        formatted ``<prog>: error: <message>`` text.
        """
        from gettext import gettext as translate

        self.print_help()
        template = translate('\n%s: error: %s\n')
        self.exit(2, template % (self.prog, message))


class FuncDiffInfo(object):
    """
    An object stores the result of candidate python code compared to referenced python code.

    Attributes (class-level defaults, assigned per instance by ``detect``):
      info_ref         -- FuncInfo of the reference function
      info_candidate   -- FuncInfo of the best-matching candidate function, or
                          None when the candidate code has no functions
      plagiarism_count -- amount of reference structure reproduced by the candidate
      total_count      -- amount of reference structure (unit depends on the diff method)
    """

    info_ref = None
    info_candidate = None
    plagiarism_count = 0
    total_count = 0

    @property
    def plagiarism_percent(self):
        """Ratio ``plagiarism_count / total_count`` as a float; 0.0 when total_count is 0."""
        if self.total_count == 0:
            return 0.0
        return self.plagiarism_count / float(self.total_count)

    @staticmethod
    def _describe(info):
        """Return ``name<lineno:col_offset>`` for a FuncInfo-like object, else None."""
        func_node = getattr(info, 'func_node', None)
        if func_node is None:
            return None
        return '{}<{}:{}>'.format(info.func_name, func_node.lineno, func_node.col_offset)

    def __str__(self):
        # Always format the percentage as a float: the '{:<4.2}' spec raises
        # ValueError for ints ("precision not allowed in integer format").
        # Each side is described independently, so a missing candidate no
        # longer hides the reference function.
        return '{:<4.2}: ref {}, candidate {}'.format(float(self.plagiarism_percent),
                                                      self._describe(self.info_ref),
                                                      self._describe(self.info_candidate))


class UnifiedDiff(object):
    """
    Diff strategy working on the pretty-printed AST lines of two functions.

    A plain line diff: naive but efficient, and the result is good enough.
    """

    @staticmethod
    def diff(a, b):
        """
        Count the reference AST lines that difflib marks as replaced or deleted.

        Equivalent to counting the '-' lines of a zero-context unified diff,
        without building the diff text.

        :param a: reference function info (must not be None)
        :param b: candidate function info (must not be None)
        :return: number of lines of ``a.func_ast_lines`` with no match in ``b``
        """
        assert a is not None
        assert b is not None
        ref_lines, cand_lines = a.func_ast_lines, b.func_ast_lines
        matcher = difflib.SequenceMatcher(None, ref_lines, cand_lines)
        return sum(
            ref_end - ref_start
            for group in matcher.get_grouped_opcodes(0)
            for tag, ref_start, ref_end, _, _ in group
            if tag in ('replace', 'delete')
        )

    @staticmethod
    def total(a, b):
        """Size of the reference: its number of AST lines (``b`` is unused and may be None)."""
        assert a is not None
        return len(a.func_ast_lines)


class TreeDiff(object):
    """
    Tree edit distance between function ASTs, computed with the third-party
    ``zss`` package. Very slow, and the result is not good for small functions.
    """

    @staticmethod
    def diff(a, b):
        """
        Edit distance from ``a.func_node`` to ``b.func_node``.

        Costs: insertions are free, removals cost 1, and relabelling costs 1
        when the two node type names differ (0 otherwise).
        """
        assert a is not None
        assert b is not None

        import zss

        def label_of(tree_node):
            return type(tree_node).__name__

        def mismatch(x, y):
            return 1 if x != y else 0

        def children_of(tree_node):
            # The child list is cached on the node itself for reuse across comparisons.
            if hasattr(tree_node, 'children'):
                return tree_node.children
            tree_node.children = list(ast.iter_child_nodes(tree_node))
            return tree_node.children

        def insert_cost(tree_node):
            return 0

        def remove_cost(tree_node):
            return mismatch(label_of(tree_node), '')

        def update_cost(left, right):
            return mismatch(label_of(left), label_of(right))

        return zss.distance(a.func_node, b.func_node, children_of,
                            insert_cost, remove_cost, update_cost)

    @staticmethod
    def total(a, b):
        """Size of the reference: ``a.func_node.nsubnodes`` (``b`` may be None)."""
        assert a is not None
        return a.func_node.nsubnodes


class NoFuncException(Exception):
    """Signals that one of the compared source strings contains no function definitions."""

    def __init__(self, source):
        message = 'Can not find any functions from code, index = {}'.format(source)
        super(NoFuncException, self).__init__(message)
        # Index of the offending source string in the input list.
        self.source = source


def _collect_func_infos(code_str):
    """Parse one source (str or bytes) and wrap every collected function node in a FuncInfo."""
    root_node = ast.parse(code_str)
    collector = FuncNodeCollector()
    collector.visit(root_node)
    code_lines = code_str.splitlines(True)
    return [FuncInfo(n, code_lines) for n in collector.get_function_nodes()]


def _find_best_match(func_info_ref, func_info_candidates, diff_method):
    """Return ``(min_diff_value, best_candidate)``, or ``(None, None)`` if there are no candidates."""
    best_value, best_info = None, None
    for candidate in func_info_candidates:
        value = diff_method.diff(func_info_ref, candidate)
        if best_value is None or value < best_value:
            best_value, best_info = value, candidate
        if value == 0:  # entire function structure is plagiarized by candidate
            break
    return best_value, best_info


def detect(pycode_string_list, diff_method=UnifiedDiff):
    """
    Compare every candidate source against the reference source.

    :param pycode_string_list: sources accepted by ``ast.parse`` (str or bytes);
        index 0 is the reference, the rest are candidates
    :param diff_method: class with static ``diff(ref, cand)`` and
        ``total(ref, cand_or_None)`` methods, e.g. UnifiedDiff or TreeDiff
    :return: list of ``(candidate_index, [FuncDiffInfo, ...])``; each inner list
        holds one entry per reference function, sorted by plagiarism_percent
        (highest first).  Empty when fewer than two sources are given.
    :raises NoFuncException: if the reference source defines no functions
    :raises SyntaxError: propagated from ``ast.parse`` for unparsable sources
    """
    if len(pycode_string_list) < 2:
        return []

    func_info_list = [(index, _collect_func_infos(code_str))
                      for index, code_str in enumerate(pycode_string_list)]

    ast_diff_result = []
    index_ref, func_info_ref = func_info_list[0]
    if len(func_info_ref) == 0:
        raise NoFuncException(index_ref)

    for index_candidate, func_info_candidate in func_info_list[1:]:
        func_ast_diff_list = []
        for fi1 in func_info_ref:
            min_diff_value, min_diff_func_info = _find_best_match(fi1, func_info_candidate, diff_method)

            func_diff_info = FuncDiffInfo()
            func_diff_info.info_ref = fi1
            func_diff_info.info_candidate = min_diff_func_info
            func_diff_info.total_count = diff_method.total(fi1, min_diff_func_info)
            if min_diff_func_info is None:
                func_diff_info.plagiarism_count = 0
            else:
                # Clamp at 0: the total and the diff need not be measured in
                # exactly the same way (TreeDiff counts collector-visited nodes
                # but the distance walks ast.iter_child_nodes), so the diff can
                # exceed the total.
                func_diff_info.plagiarism_count = max(0, func_diff_info.total_count - min_diff_value)
            func_ast_diff_list.append(func_diff_info)
        func_ast_diff_list.sort(key=operator.attrgetter('plagiarism_percent'), reverse=True)
        ast_diff_result.append((index_candidate, func_ast_diff_list))

    return ast_diff_result


def _profile(fn):
    """
    A simple profile decorator
    :param fn: target function to be profiled
    :return: The wrapper function
    """
    import functools
    import cProfile

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        res = fn(*args, **kwargs)
        pr.disable()
        pr.print_stats('cumulative')
        return res

    return _wrapper


# @_profile
def main():
    """
    The console_scripts Entry Point in setup.py

    Parses the command line (two input files plus the -l / -p thresholds).
    It then compares the second file (candidate) against the first
    (reference) with ``detect`` and prints a summary followed by
    per-function details.
    """

    def check_line_limit(value):
        ivalue = int(value)
        if ivalue < 0:
            raise argparse.ArgumentTypeError("%s is an invalid line limit" % value)
        return ivalue

    def check_percentage_limit(value):
        ivalue = float(value)
        if ivalue < 0:
            raise argparse.ArgumentTypeError("%s is an invalid percentage limit" % value)
        return ivalue

    def get_file(value):
        # Report an unreadable path as a regular argparse error (help text +
        # message) instead of an uncaught traceback.
        try:
            return open(value, 'rb')
        except (IOError, OSError) as e:
            raise argparse.ArgumentTypeError("can't open '%s': %s" % (value, e))

    parser = ArgParser(description='A simple plagiarism detection tool for python code')
    parser.add_argument('files', type=get_file, nargs=2,
                        help='the input files')
    parser.add_argument('-l', type=check_line_limit, default=4,
                        help='if AST line of the function >= value then output detail (default: 4)')
    parser.add_argument('-p', type=check_percentage_limit, default=0.5,
                        help='if plagiarism percentage of the function >= value then output detail (default: 0.5)')
    args = parser.parse_args()
    pycode_list = []
    for f in args.files:
        # argparse hands back open file objects; close each one once read.
        with f:
            pycode_list.append((f.name, f.read()))
    try:
        results = detect([c[1] for c in pycode_list])
    except NoFuncException as ex:
        print('error: can not find functions from {}.'.format(pycode_list[ex.source][0]))
        return

    for index, func_ast_diff_list in results:
        print('ref: {}'.format(pycode_list[0][0]))
        print('candidate: {}'.format(pycode_list[index][0]))
        sum_total_count = sum(func_diff_info.total_count for func_diff_info in func_ast_diff_list)
        sum_plagiarism_count = sum(func_diff_info.plagiarism_count for func_diff_info in func_ast_diff_list)
        print('{:.2f} % ({}/{}) of ref code structure is plagiarized by candidate.'.format(
                sum_plagiarism_count / float(sum_total_count) * 100,
                sum_plagiarism_count,
                sum_total_count))
        print('candidate function plagiarism details (AST lines >= {} and plagiarism percentage >= {}):'.format(
                args.l,
                args.p,
        ))
        output_count = 0
        for func_diff_info in func_ast_diff_list:
            if len(func_diff_info.info_ref.func_ast_lines) >= args.l and func_diff_info.plagiarism_percent >= args.p:
                output_count = output_count + 1
                print(func_diff_info)
        if output_count == 0:
            print('<empty results>')


# Allow running this module directly as a script.
if __name__ == '__main__':
    main()