"""
Utilities for manipulating the Abstract Syntax Tree of Python constructs
"""
import re
import copy
import inspect
import ast
import textwrap


class NameVisitor(ast.NodeVisitor):
    """
    NodeVisitor that builds a set of all of the named identifiers in an AST
    """
    def __init__(self, *args, **kwargs):
        super(NameVisitor, self).__init__(*args, **kwargs)
        self.names = set()

    def visit_Name(self, node):
        self.names.add(node.id)

    def visit_arg(self, node):
        if hasattr(node, 'arg'):
            self.names.add(node.arg)
        elif hasattr(node, 'id'):
            self.names.add(node.id)

    def get_new_names(self, num_names):
        """
        Returns a list of new names that are not already present in the AST.

        New names will have the form _N, for N a non-negative integer. If the
        AST has no existing identifiers of this form, then the returned names
        will start at 0 ('_0', '_1', '_2'). If the AST already has identifiers
        of this form, then the names returned will not include the existing
        identifiers.

        Parameters
        ----------
        num_names: int
            The number of new names to return

        Returns
        -------
        list of str
        """
        prop_re = re.compile(r"^_(\d+)$")
        matching_names = [n for n in self.names if prop_re.match(n)]
        if matching_names:
            start_number = max([int(n[1:]) for n in matching_names]) + 1
        else:
            start_number = 0

        return ["_" + str(n) for n in
                range(start_number, start_number + num_names)]


class ExpandVarargTransformer(ast.NodeTransformer):
    """
    Node transformer that replaces the starred use of a variable in an AST
    with a collection of unstarred named variables.
    """
    def __init__(self, starred_name, expand_names, *args, **kwargs):
        """
        Parameters
        ----------
        starred_name: str
            The name of the starred variable to replace
        expand_names: list of stf
            List of the new names that should be used to replace the starred
            variable

        """
        super(ExpandVarargTransformer, self).__init__(*args, **kwargs)
        self.starred_name = starred_name
        self.expand_names = expand_names


class ExpandVarargTransformerStarred(ExpandVarargTransformer):
    # Python 3
    def visit_Starred(self, node):
        if node.value.id == self.starred_name:
            return [ast.Name(id=name, ctx=node.ctx) for name in
                    self.expand_names]
        else:
            return node


def function_to_ast(fn):
    """
    Get the AST representation of a function
    """
    # Get source code for function
    # Dedent is needed if this is a nested function
    fn_source = textwrap.dedent(inspect.getsource(fn))

    # Parse function source code into an AST
    fn_ast = ast.parse(fn_source)

    # # The function will be the fist element of the module body
    # fn_ast = module_ast.body[0]

    return fn_ast


def ast_to_source(ast):
    """Convert AST to source code string using the astor package"""
    import astor
    return astor.to_source(ast)


def compile_function_ast(fn_ast):
    """
    Compile function AST into a code object suitable for use in eval/exec
    """
    assert isinstance(fn_ast, ast.Module)
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)
    return compile(fn_ast, "<%s>" % fndef_ast.name, mode='exec')


def function_ast_to_function(fn_ast, stacklevel=1):
    # Validate
    assert isinstance(fn_ast, ast.Module)
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)

    # Compile AST to code object
    code = compile_function_ast(fn_ast)

    # Evaluate the function in a scope that includes the globals and
    # locals of desired frame.
    current_frame = inspect.currentframe()
    eval_frame = current_frame
    for _ in range(stacklevel):
        eval_frame = eval_frame.f_back

    eval_locals = eval_frame.f_locals
    eval_globals = eval_frame.f_globals
    del current_frame
    scope = copy.copy(eval_globals)
    scope.update(eval_locals)

    # Evaluate function in scope
    eval(code, scope)

    # Return the newly evaluated function from the scope
    return scope[fndef_ast.name]


def _build_arg(name):
    return ast.arg(arg=name)


def expand_function_ast_varargs(fn_ast, expand_number):
    """
    Given a function AST that use a variable length positional argument
    (e.g. *args), return a function that replaces the use of this argument
    with one or more fixed arguments.

    To be supported, a function must have a starred argument in the function
    signature, and it may only use this argument in starred form as the
    input to other functions.

    For example, suppose expand_number is 3 and fn_ast is an AST
    representing this function...

    def my_fn1(a, b, *args):
        print(a, b)
        other_fn(a, b, *args)

    Then this function will return the AST of a function equivalent to...

    def my_fn1(a, b, _0, _1, _2):
        print(a, b)
        other_fn(a, b, _0, _1, _2)

    If the input function uses `args` for anything other than passing it to
    other functions in starred form, an error will be raised.

    Parameters
    ----------
    fn_ast: ast.FunctionDef
    expand_number: int

    Returns
    -------
    ast.FunctionDef
    """
    assert isinstance(fn_ast, ast.Module)

    # Copy ast so we don't modify the input
    fn_ast = copy.deepcopy(fn_ast)

    # Extract function definition
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)

    # Get function args
    fn_args = fndef_ast.args

    # Function variable arity argument
    fn_vararg = fn_args.vararg

    # Require vararg
    if not fn_vararg:
        raise ValueError("""\
Input function AST does not have a variable length positional argument
(e.g. *args) in the function signature""")
    assert fn_vararg

    # Get vararg name
    if isinstance(fn_vararg, str):
        vararg_name = fn_vararg
    else:
        vararg_name = fn_vararg.arg

    # Compute new unique names to use in place of the variable argument
    before_name_visitor = NameVisitor()
    before_name_visitor.visit(fn_ast)
    expand_names = before_name_visitor.get_new_names(expand_number)

    # Replace use of *args in function body
    expand_transformer = ExpandVarargTransformerStarred

    new_fn_ast = expand_transformer(
        vararg_name, expand_names
    ).visit(fn_ast)

    new_fndef_ast = new_fn_ast.body[0]

    # Replace vararg with additional args in function signature
    new_fndef_ast.args.args.extend(
        [_build_arg(name=name) for name in expand_names]
    )
    new_fndef_ast.args.vararg = None

    # Run a new NameVistor an see if there were any other non-starred uses
    # of the variable length argument. If so, raise an exception
    after_name_visitor = NameVisitor()
    after_name_visitor.visit(new_fn_ast)
    if vararg_name in after_name_visitor.names:
        raise ValueError("""\
The variable length positional argument {n} is used in an unsupported context
""".format(n=vararg_name))

    # Remove decorators if present to avoid recursion
    fndef_ast.decorator_list = []

    # Add missing source code locations
    ast.fix_missing_locations(new_fn_ast)

    # Return result
    return new_fn_ast


def expand_varargs(expand_number):
    """
    Decorator to expand the variable length (starred) argument in a function
    signature with a fixed number of arguments.

    Parameters
    ----------
    expand_number: int
        The number of fixed arguments that should replace the variable length
        argument

    Returns
    -------
    function
        Decorator Function
    """
    if not isinstance(expand_number, int) or expand_number < 0:
        raise ValueError("expand_number must be a non-negative integer")

    def _expand_varargs(fn):
        fn_ast = function_to_ast(fn)
        fn_expanded_ast = expand_function_ast_varargs(fn_ast, expand_number)
        return function_ast_to_function(fn_expanded_ast, stacklevel=2)
    return _expand_varargs
