from importlib import import_module

from .logging import get_logger


logger = get_logger(__name__)


class _PatchedModuleObj:
    """Set all the modules components as attributes of the _PatchedModuleObj object."""

    def __init__(self, module, attrs=None):
        attrs = attrs or []
        if module is not None:
            for key in module.__dict__:
                if key in attrs or not key.startswith("__"):
                    setattr(self, key, getattr(module, key))
        self._original_module = module._original_module if isinstance(module, _PatchedModuleObj) else module


class patch_submodule:
    """
    Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.

    Example::

        >>> import importlib
        >>> from datasets.load import dataset_module_factory
        >>> from datasets.streaming import patch_submodule, xjoin
        >>>
        >>> dataset_module = dataset_module_factory("snli")
        >>> snli_module = importlib.import_module(dataset_module.module_path)
        >>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
        >>> patcher.start()
        >>> assert snli_module.os.path.join is xjoin
    """

    _active_patches = []

    def __init__(self, obj, target: str, new, attrs=None):
        self.obj = obj
        self.target = target
        self.new = new
        self.key = target.split(".")[0]
        self.original = {}
        self.attrs = attrs or []

    def __enter__(self):
        *submodules, target_attr = self.target.split(".")

        # Patch modules:
        # it's used to patch attributes of submodules like "os.path.join";
        # in this case we need to patch "os" and "os.path"

        for i in range(len(submodules)):
            try:
                submodule = import_module(".".join(submodules[: i + 1]))
            except ModuleNotFoundError:
                continue
            # We iterate over all the globals in self.obj in case we find "os" or "os.path"
            for attr in self.obj.__dir__():
                obj_attr = getattr(self.obj, attr)
                # We don't check for the name of the global, but rather if its value *is* "os" or "os.path".
                # This allows to patch renamed modules like "from os import path as ospath".
                if obj_attr is submodule or (
                    (isinstance(obj_attr, _PatchedModuleObj) and obj_attr._original_module is submodule)
                ):
                    self.original[attr] = obj_attr
                    # patch at top level
                    setattr(self.obj, attr, _PatchedModuleObj(obj_attr, attrs=self.attrs))
                    patched = getattr(self.obj, attr)
                    # construct lower levels patches
                    for key in submodules[i + 1 :]:
                        setattr(patched, key, _PatchedModuleObj(getattr(patched, key, None), attrs=self.attrs))
                        patched = getattr(patched, key)
                    # finally set the target attribute
                    setattr(patched, target_attr, self.new)

        # Patch attribute itself:
        # it's used for builtins like "open",
        # and also to patch "os.path.join" we may also need to patch "join"
        # itself if it was imported as "from os.path import join".

        if submodules:  # if it's an attribute of a submodule like "os.path.join"
            try:
                attr_value = getattr(import_module(".".join(submodules)), target_attr)
            except (AttributeError, ModuleNotFoundError):
                return
            # We iterate over all the globals in self.obj in case we find "os.path.join"
            for attr in self.obj.__dir__():
                # We don't check for the name of the global, but rather if its value *is* "os.path.join".
                # This allows to patch renamed attributes like "from os.path import join as pjoin".
                if getattr(self.obj, attr) is attr_value:
                    self.original[attr] = getattr(self.obj, attr)
                    setattr(self.obj, attr, self.new)
        elif target_attr in globals()["__builtins__"]:  # if it'a s builtin like "open"
            self.original[target_attr] = globals()["__builtins__"][target_attr]
            setattr(self.obj, target_attr, self.new)
        else:
            raise RuntimeError(f"Tried to patch attribute {target_attr} instead of a submodule.")

    def __exit__(self, *exc_info):
        for attr in list(self.original):
            setattr(self.obj, attr, self.original.pop(attr))

    def start(self):
        """Activate a patch."""
        self.__enter__()
        self._active_patches.append(self)

    def stop(self):
        """Stop an active patch."""
        try:
            self._active_patches.remove(self)
        except ValueError:
            # If the patch hasn't been started this will fail
            return None

        return self.__exit__()
