from __future__ import annotations

from dask.rewrite import VAR, RewriteRule, RuleSet, Traverser, args, head
from dask.utils_test import add, inc


def double(x):
    """Return ``x`` multiplied by two (``x * 2``)."""
    doubled = x * 2
    return doubled


def test_head():
    """Check ``head`` on task tuples (first element) and on a plain list."""
    cases = [
        ((inc, 1), inc),
        ((add, 1, 2), add),
        ((add, (inc, 1), (inc, 1)), add),
        # A bare list reports the ``list`` type as its head.
        ([1, 2, 3], list),
    ]
    for term, expected in cases:
        assert head(term) == expected


def test_args():
    """Check ``args`` on task tuples, a bare scalar, and a plain list."""
    cases = [
        ((inc, 1), (1,)),
        ((add, 1, 2), (1, 2)),
        # A non-task scalar has no arguments.
        (1, ()),
        # A list is returned as its own arguments.
        ([1, 2, 3], [1, 2, 3]),
    ]
    for term, expected in cases:
        assert args(term) == expected


def test_traverser():
    """Exercise ``Traverser``: ``current``, ``next``, ``skip``, ``copy``, iteration."""
    term = (add, (inc, 1), (double, (inc, 1), 2))
    walker = Traverser(term)
    # Taken before any stepping, so it should stay at the start of the term.
    untouched = walker.copy()

    assert walker.current == add
    walker.next()
    assert walker.current == inc
    # Advancing the original must not advance the copy.
    assert untouched.current == add

    # From ``inc``, skip() lands on ``double`` (the argument 1 is passed over).
    walker.skip()
    assert walker.current == double
    walker.next()
    assert walker.current == inc

    # Iterating the untouched copy yields every element of the term in order.
    full_walk = [add, inc, 1, double, inc, 1, 2]
    assert list(untouched) == full_walk


# Shared rule fixtures used by the tests below.
#
# Every rule is given the same tuple of candidate variable names; a rule only
# keeps the ones that actually occur in its pattern (see test_RewriteRule).
# NOTE: this module-level name shadows the builtin `vars` inside this module.
vars = ("a", "b", "c")
# add(a, 1) -> inc(a)
rule1 = RewriteRule((add, "a", 1), (inc, "a"), vars)
# add(a, a) -> double(a)   (non-linear: both arguments must be equal)
rule2 = RewriteRule((add, "a", "a"), (double, "a"), vars)
# add(inc(a), inc(a)) -> add(double(a), 2)
rule3 = RewriteRule((add, (inc, "a"), (inc, "a")), (add, (double, "a"), 2), vars)
# add(inc(b), inc(a)) -> add(add(a, b), 2)
rule4 = RewriteRule((add, (inc, "b"), (inc, "a")), (add, (add, "a", "b"), 2), vars)
# sum([c, b, a]) -> add(add(a, b), c)
rule5 = RewriteRule((sum, ["c", "b", "a"]), (add, (add, "a", "b"), "c"), vars)
# rule6 (defined after its helper `repl_list` below): list(x) -> x if x is a list


def repl_list(sd):
    """Callable right-hand side for the ``list(x)`` rule.

    ``sd`` is a mapping with an ``"x"`` entry. If that value is already a
    list it is returned as is; otherwise the task ``(list, x)`` is returned.
    """
    value = sd["x"]
    return value if isinstance(value, list) else (list, value)


# list(x) -> x if x is a list, otherwise list(x) is kept.
# The right-hand side is the callable `repl_list` rather than a template term.
rule6 = RewriteRule((list, "x"), repl_list, ("x",))


def test_RewriteRule():
    """Extraneous variables are dropped and ``_varlist`` follows pattern order."""
    # (rule, expected ``vars``, expected ``_varlist``)
    expectations = [
        (rule1, ("a",), ["a"]),
        (rule2, ("a",), ["a", "a"]),
        (rule3, ("a",), ["a", "a"]),
        (rule4, ("a", "b"), ["b", "a"]),
        (rule5, ("a", "b", "c"), ["c", "b", "a"]),
    ]
    for rule, expected_vars, expected_varlist in expectations:
        assert rule.vars == expected_vars
        assert rule._varlist == expected_varlist


def test_RewriteRuleSubs():
    """``subs`` works for a template right-hand side and for a callable one."""
    # Template rhs: (inc, "a") with a -> 1
    from_template = rule1.subs({"a": 1})
    assert from_template == (inc, 1)

    # Callable rhs: repl_list hands back the list itself
    from_callable = rule6.subs({"x": [1, 2, 3]})
    assert from_callable == [1, 2, 3]


# All rules, in order. The position of a rule in this list matters: the
# integers in the expected net of test_RuleSet are positions in this list.
rules = [rule1, rule2, rule3, rule4, rule5, rule6]
# Rule set shared by the tests below.
rs = RuleSet(*rules)


def test_RuleSet():
    """The rule set's internal net and rule list match the expected structure.

    Each node of the expected net is written as an ``(edges, matches)`` pair:
    a dict keyed by the next symbol (or ``VAR``) and a list of integers that
    line up with positions in ``rules``.
    """
    # add(VAR, VAR) -> rule at index 1; add(VAR, 1) -> rule at index 0
    add_var_branch = ({VAR: ({}, [1]), 1: ({}, [0])}, [])
    # add(inc(VAR), inc(VAR)) -> rules at indices 2 and 3
    add_inc_branch = ({VAR: ({inc: ({VAR: ({}, [2, 3])}, [])}, [])}, [])
    add_branch = ({VAR: add_var_branch, inc: add_inc_branch}, [])
    # list(VAR) -> rule at index 5
    list_branch = ({VAR: ({}, [5])}, [])
    # sum([VAR, VAR, VAR]) -> rule at index 4
    sum_branch = ({list: ({VAR: ({VAR: ({VAR: ({}, [4])}, [])}, [])}, [])}, [])

    expected_net = ({add: add_branch, list: list_branch, sum: sum_branch}, [])
    assert rs._net == expected_net
    assert rs.rules == rules


def test_matches():
    """``iter_matches`` yields the expected (rule, substitution) pairs in order."""

    def check(term, expected):
        # Same checks as spelling them out: the count first, then each match.
        found = list(rs.iter_matches(term))
        assert len(found) == len(expected)
        for actual, wanted in zip(found, expected):
            assert actual == wanted

    check((add, 2, 1), [(rule1, {"a": 2})])

    # Specific patterns are matched before general ones
    check((add, 1, 1), [(rule1, {"a": 1}), (rule2, {"a": 1})])

    # Unhashable arguments can be matched. What this gets rewritten to doesn't
    # make sense; the point is only that matching works.
    check((add, [1], [1]), [(rule2, {"a": [1]})])

    # Matching at depth
    check(
        (add, (inc, 1), (inc, 1)),
        [
            (rule3, {"a": 1}),
            (rule4, {"a": 1, "b": 1}),
            (rule2, {"a": (inc, 1)}),
        ],
    )

    # Non-linear pattern checking: add(a, a) must not match unequal arguments
    check((add, 2, 3), [])


def test_rewrite():
    """``RuleSet.rewrite`` on nested terms, lists, and a callable right-hand side."""
    # Rewriting reaches terms that sit inside a list
    summed = (sum, [(add, 1, 1), (add, 1, 1), (add, 1, 1)])
    first_pass = rs.rewrite(summed)
    assert first_pass == (add, (add, (inc, 1), (inc, 1)), (inc, 1))

    # One call does not apply rules to exhaustion; a second call simplifies more
    second_pass = rs.rewrite(first_pass)
    assert second_pass == (add, (add, (double, 1), 2), (inc, 1))

    # Deeply nested additions
    nested = (
        add,
        (add, (add, (add, 1, 2), (add, 1, 2)), (add, (add, 1, 2), (add, 1, 2))),
        1,
    )
    assert rs.rewrite(nested) == (inc, (double, (double, (add, 1, 2))))

    # Callable RewriteRule rhs: list(<list>) becomes the list itself ...
    list_of_list = (list, [1, 2, 3])
    assert rs.rewrite(list_of_list) == [1, 2, 3]
    # ... while list(<non-list>) comes back unchanged
    list_of_task = (list, (map, inc, [1, 2, 3]))
    assert rs.rewrite(list_of_task) == list_of_task
