import unittest
from contextlib import contextmanager
from functools import cached_property

from numba import njit
from numba.core import errors, cpu, typing
from numba.core.descriptors import TargetDescriptor
from numba.core.dispatcher import TargetConfigurationStack
from numba.core.retarget import BasicRetarget
from numba.core.extending import overload
from numba.core.target_extension import (
    dispatcher_registry,
    CPUDispatcher,
    CPU,
    target_registry,
    jit_registry,
)


# ------------ A custom target ------------

CUSTOM_TARGET = f"{__name__}.CustomCPU"  # fully-qualified name of the custom target


class CustomCPU(CPU):
    """Target class for the custom target.

    It inherits all of its behaviour from the CPU target unchanged.
    """


# Nested contexts to help with isolatings bits of compilations
class _NestedContext(object):
    _typing_context = None
    _target_context = None

    @contextmanager
    def nested(self, typing_context, target_context):
        old_nested = self._typing_context, self._target_context
        try:
            self._typing_context = typing_context
            self._target_context = target_context
            yield
        finally:
            self._typing_context, self._target_context = old_nested


# TargetDescriptor for the custom target.  It reuses the CPU machinery:
# CPU target options, a ``cpu.CPUContext`` target context and a plain
# ``typing.Context`` typing context.
class CustomTargetDescr(TargetDescriptor):
    options = cpu.CPUTargetOptions
    _nested = _NestedContext()

    @cached_property
    def _toplevel_typing_context(self):
        # Built on first access, then cached on the instance.
        return typing.Context()

    @cached_property
    def _toplevel_target_context(self):
        # Built on first access, then cached on the instance.  It is created
        # from ``self.typing_context``, so a nested typing context active at
        # that moment is the one it is paired with.
        return cpu.CPUContext(self.typing_context, self._target_name)

    @property
    def target_context(self):
        """
        The target context for this (CPU-based) custom target: the nested
        override when one is installed, otherwise the top-level context.
        """
        override = self._nested._target_context
        return self._toplevel_target_context if override is None else override

    @property
    def typing_context(self):
        """
        The typing context for this (CPU-based) custom target: the nested
        override when one is installed, otherwise the top-level context.
        """
        override = self._nested._typing_context
        return self._toplevel_typing_context if override is None else override

    def nested_context(self, typing_context, target_context):
        """
        Return a context manager that temporarily replaces the typing and
        target contexts with the given ones, restoring the previous pair on
        exit.

        NOTE(review): ``_nested`` is a class attribute holding a plain
        ``_NestedContext`` (not a thread-local), so the override is shared by
        all instances and threads rather than scoped to the current thread.
        """
        return self._nested.nested(typing_context, target_context)


# Single module-level descriptor instance for the custom target; the
# dispatcher class below points at it.
custom_target = CustomTargetDescr(CUSTOM_TARGET)


class CustomCPUDispatcher(CPUDispatcher):
    """A CPUDispatcher whose target descriptor is ``custom_target``."""
    targetdescr = custom_target


# Register the custom target.  First map its name to the CustomCPU target
# class, then map that class to its dispatcher.  The second line looks the
# class up by name, so it relies on the first line having already run.
target_registry[CUSTOM_TARGET] = CustomCPU
dispatcher_registry[target_registry[CUSTOM_TARGET]] = CustomCPUDispatcher


def custom_jit(*args, **kwargs):
    """Like ``njit``, but always compiles for the custom target.

    All positional and keyword arguments are forwarded to ``njit``, with
    ``_target`` fixed to ``CUSTOM_TARGET``.

    Raises:
        TypeError: if ``target`` or ``_target`` is passed; the target is
            fixed by this decorator and may not be overridden by the caller.
    """
    # Explicit checks instead of ``assert`` so the guard still applies when
    # Python runs with ``-O`` (which strips assert statements).
    for reserved in ('target', '_target'):
        if reserved in kwargs:
            raise TypeError(
                f"custom_jit() does not accept the {reserved!r} keyword "
                "argument; its target is fixed"
            )
    return njit(*args, _target=CUSTOM_TARGET, **kwargs)


# Register custom_jit as the jit decorator for the target class registered
# under CUSTOM_TARGET (looked up by name from target_registry).
jit_registry[target_registry[CUSTOM_TARGET]] = custom_jit

# ------------ For switching target ------------


class CustomCPURetarget(BasicRetarget):
    """Retargeting hook whose output target is ``CUSTOM_TARGET``."""

    @property
    def output_target(self):
        # Name of the target that retargeted functions are compiled for.
        return CUSTOM_TARGET

    def compile_retarget(self, cpu_disp):
        # Build a new dispatcher over the same Python function, pinned to
        # the custom target.
        return njit(_target=CUSTOM_TARGET)(cpu_disp.py_func)


class TestRetargeting(unittest.TestCase):
    """Tests for switching compilation to CUSTOM_TARGET at call time through
    ``TargetConfigurationStack.switch_target`` and ``CustomCPURetarget``.
    """

    def setUp(self):
        # Generate fresh functions for each test method to avoid caching

        @njit(_target="cpu")
        def fixed_target(x):
            """
            This has a fixed target to "cpu".
            Cannot be used in CUSTOM_TARGET target.
            """
            return x + 10

        @njit
        def flex_call_fixed(x):
            """
            This has a flexible target, but uses a fixed target function.
            Cannot be used in CUSTOM_TARGET target.
            """
            return fixed_target(x) + 100

        @njit
        def flex_target(x):
            """
            This has a flexible target.
            Can be used in CUSTOM_TARGET target.
            """
            return x + 1000

        # Save these functions for use.  The mapping is spelled out rather
        # than taken from ``locals()``, which would also capture ``self`` and
        # make the test case reference itself.
        self.functions = {
            "fixed_target": fixed_target,
            "flex_call_fixed": flex_call_fixed,
            "flex_target": flex_target,
        }
        # A new retarget object per test, so cache statistics start fresh.
        self.retarget = CustomCPURetarget()

    def switch_target(self):
        """Return a context manager that activates ``self.retarget``."""
        return TargetConfigurationStack.switch_target(self.retarget)

    @contextmanager
    def check_retarget_error(self):
        """Assert that the ``with`` body raises a NumbaError reporting the
        CUSTOM_TARGET vs. cpu target mismatch."""
        with self.assertRaises(errors.NumbaError) as raises:
            yield
        self.assertIn(f"{CUSTOM_TARGET} != cpu", str(raises.exception))

    def check_non_empty_cache(self):
        # Retargeting occurred. The cache must NOT be empty
        stats = self.retarget.cache.stats()
        # Because multiple function compilations are triggered, we don't know
        # precisely how many cache hit/miss there are.
        self.assertGreater(stats['hit'] + stats['miss'], 0)

    def test_case0(self):
        """
        Without switching target, mixing fixed- and flexible-target calls
        works and the retarget cache is never consulted.
        """
        fixed_target = self.functions["fixed_target"]
        flex_target = self.functions["flex_target"]

        @njit
        def foo(x):
            x = fixed_target(x)
            x = flex_target(x)
            return x

        r = foo(123)
        self.assertEqual(r, 123 + 10 + 1000)
        # No retargeting occurred. The cache must be empty
        stats = self.retarget.cache.stats()
        self.assertEqual(stats, dict(hit=0, miss=0))

    def test_case1(self):
        """
        A function that only calls flexible-target code can be retargeted.
        """
        flex_target = self.functions["flex_target"]

        @njit
        def foo(x):
            x = flex_target(x)
            return x

        with self.switch_target():
            r = foo(123)
        self.assertEqual(r, 123 + 1000)
        self.check_non_empty_cache()

    def test_case2(self):
        """
        The non-nested call into fixed_target should raise error.
        """
        fixed_target = self.functions["fixed_target"]
        flex_target = self.functions["flex_target"]

        @njit
        def foo(x):
            x = fixed_target(x)
            x = flex_target(x)
            return x

        with self.check_retarget_error():
            with self.switch_target():
                foo(123)

    def test_case3(self):
        """
        The nested call into fixed_target should raise error
        """
        flex_target = self.functions["flex_target"]
        flex_call_fixed = self.functions["flex_call_fixed"]

        @njit
        def foo(x):
            x = flex_call_fixed(x)  # calls fixed_target indirectly
            x = flex_target(x)
            return x

        with self.check_retarget_error():
            with self.switch_target():
                foo(123)

    def test_case4(self):
        """
        Same as case3 but flex_call_fixed() is invoked outside of CUSTOM_TARGET
        target before the switch_target.
        """
        flex_target = self.functions["flex_target"]
        flex_call_fixed = self.functions["flex_call_fixed"]

        r = flex_call_fixed(123)
        self.assertEqual(r, 123 + 100 + 10)

        @njit
        def foo(x):
            x = flex_call_fixed(x)  # calls fixed_target indirectly
            x = flex_target(x)
            return x

        with self.check_retarget_error():
            with self.switch_target():
                foo(123)

    def test_case5(self):
        """
        Tests overload resolution with target switching
        """

        def overloaded_func(x):
            pass

        @overload(overloaded_func, target=CUSTOM_TARGET)
        def ol_overloaded_func_custom_target(x):
            def impl(x):
                return 62830
            return impl

        @overload(overloaded_func, target='cpu')
        def ol_overloaded_func_cpu(x):
            def impl(x):
                return 31415
            return impl

        @njit
        def foo(x):
            return x + overloaded_func(x)

        # Default (cpu) target picks the cpu overload...
        r = foo(123)
        self.assertEqual(r, 123 + 31415)

        # ...and after switching, the CUSTOM_TARGET overload is chosen.
        with self.switch_target():
            r = foo(123)
            self.assertEqual(r, 123 + 62830)

        self.check_non_empty_cache()
