# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Comparison functions for `astropy.cosmology.Cosmology`.

This module is **NOT** public API. To use these functions, import them from
the top-level namespace -- :mod:`astropy.cosmology`.
This module will be moved.
"""

from __future__ import annotations

import functools
import inspect
from typing import Any, Callable, Union

import numpy as np
from numpy import False_, True_, ndarray

from astropy import table
from astropy.cosmology.core import Cosmology

__all__ = []  # Nothing is scoped here


##############################################################################
# PARAMETERS

_FormatType = Union[bool, None, str]
_FormatsT = Union[_FormatType, tuple[_FormatType, ...]]
_CompFnT = Callable[[Any, _FormatType], Cosmology]

_COSMO_AOK: set[Any] = {None, True_, False_, "astropy.cosmology"}
# The numpy bool also catches real bool for ops "==" and "in"


##############################################################################
# UTILITIES


class _CosmologyWrapper:
    """
    A private wrapper class to hide things from :mod:`numpy`.
    This should never be exposed to the user.
    """

    __slots__ = ("wrapped",)
    # Use less memory and speed up initialization.

    _cantbroadcast: tuple[type, ...] = (table.Row, table.Table)
    """
    Have to deal with things that do not broadcast well. e.g.
    `~astropy.table.Row` cannot be used in an array, even if ``dtype=object``
    and will raise a segfault when used in a `numpy.ufunc`.
    """

    wrapped: Any

    def __init__(self, wrapped: Any) -> None:
        self.wrapped = wrapped


# TODO! when py3.9+ use @functools.partial(np.frompyfunc, nin=2, nout=1)
# TODO! https://github.com/numpy/numpy/issues/9477 segfaults on astropy.row
# and np.vectorize can't coerce table to dtypes
def _wrap_to_ufunc(nin: int, nout: int) -> Callable[[_CompFnT], np.ufunc]:
    def wrapper(pyfunc: _CompFnT) -> np.ufunc:
        ufunc = np.frompyfunc(pyfunc, 2, 1)
        return ufunc

    return wrapper


@_wrap_to_ufunc(2, 1)
def _parse_format(cosmo: Any, format: _FormatType, /) -> Cosmology:
    """Parse Cosmology-like input into Cosmologies, given a format hint.

    Parameters
    ----------
    cosmo : |Cosmology|-like, positional-only
        |Cosmology| to parse.
    format : bool or None or str, positional-only
        Whether to allow, before equivalence is checked, the object to be
        converted to a |Cosmology|. This allows, e.g. a |Table| to be equivalent
        to a |Cosmology|. `False` (default) will not allow conversion. `True` or
        `None` will, and will use the auto-identification to try to infer the
        correct format. A `str` is assumed to be the correct format to use when
        converting.

    Returns
    -------
    |Cosmology| or generator thereof

    Raises
    ------
    TypeError
        If ``cosmo`` is not a |Cosmology| and ``format`` equals `False`.
    TypeError
        If ``cosmo`` is a |Cosmology| and ``format`` is not `None` or equal to
        `True`.
    """
    # Deal with private wrapper
    if isinstance(cosmo, _CosmologyWrapper):
        cosmo = cosmo.wrapped

    # Shortcut if already a cosmology
    if isinstance(cosmo, Cosmology):
        if format not in _COSMO_AOK:
            allowed = "/".join(map(str, _COSMO_AOK))
            raise ValueError(
                f"for parsing a Cosmology, 'format' must be {allowed}, not {format}"
            )
        return cosmo
    # Convert, if allowed.
    elif format == False_:  # catches False and False_
        raise TypeError(
            f"if 'format' is False, arguments must be a Cosmology, not {cosmo}"
        )
    else:
        format = None if format == True_ else format  # str->str, None/True/True_->None
        out = Cosmology.from_format(cosmo, format=format)  # this can error!

    return out


def _parse_formats(*cosmos: object, format: _FormatsT) -> ndarray:
    """Parse Cosmology-like to |Cosmology|, using provided formats.

    ``format`` is broadcast to match the shape of the cosmology arguments. Note
    that the cosmology arguments are not broadcast against ``format``, so it
    cannot determine the output shape.

    Parameters
    ----------
    *cosmos : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by the corresponding ``format``.

    format : bool or None or str or array-like thereof, positional-only
        Whether to allow, before equivalence is checked, the object to be
        converted to a |Cosmology|. This allows, e.g. a |Table| to be equivalent
        to a |Cosmology|. `False` (default) will not allow conversion. `True` or
        `None` will, and will use the auto-identification to try to infer the
        correct format. A `str` is assumed to be the correct format to use when
        converting. Note ``format`` is broadcast as an object array to match the
        shape of ``cosmos`` so ``format`` cannot determine the output shape.

    Raises
    ------
    TypeError
        If any in ``cosmos`` is not a |Cosmology| and the corresponding
        ``format`` equals `False`.
    """
    formats = np.broadcast_to(np.array(format, dtype=object), len(cosmos))
    # parse each cosmo & format

    # Have to deal with things that do not broadcast well.
    # astropy.row cannot be used in an array, even if dtype=object
    # and will raise a segfault when used in a ufunc.
    towrap = (isinstance(cosmo, _CosmologyWrapper._cantbroadcast) for cosmo in cosmos)
    wcosmos = [
        c if not wrap else _CosmologyWrapper(c) for c, wrap in zip(cosmos, towrap)
    ]

    return _parse_format(wcosmos, formats)


def _comparison_decorator(pyfunc: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator to make wrapper function that parses |Cosmology|-like inputs.

    Parameters
    ----------
    pyfunc : Python function object
        An arbitrary Python function.

    Returns
    -------
    callable[..., Any]
        Wrapped `pyfunc`, as described above.

    Notes
    -----
    All decorated functions should add the following to 'Parameters'.

    format : bool or None or str or array-like thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.
    """
    sig = inspect.signature(pyfunc)
    nin = sum(p.kind == 0 for p in sig.parameters.values())

    # Make wrapper function that parses cosmology-like inputs
    @functools.wraps(pyfunc)
    def wrapper(*cosmos: Any, format: _FormatsT = False, **kwargs: Any) -> bool:
        if len(cosmos) > nin:
            raise TypeError(
                f"{wrapper.__wrapped__.__name__} takes {nin} positional"
                f" arguments but {len(cosmos)} were given"
            )
        # Parse cosmologies to format. Only do specified number.
        cosmos = _parse_formats(*cosmos, format=format)
        # Evaluate pyfunc, erroring if didn't match specified number.
        result = wrapper.__wrapped__(*cosmos, **kwargs)
        # Return, casting to correct type casting is possible.
        return result

    return wrapper


##############################################################################
# COMPARISON FUNCTIONS


@_comparison_decorator
def cosmology_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise equality check on the cosmologies.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.

    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.

    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    Examples
    --------
    Assuming the following imports

        >>> import astropy.units as u
        >>> from astropy.cosmology import FlatLambdaCDM

    Two identical cosmologies are equal.

        >>> cosmo1 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmo2 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmology_equal(cosmo1, cosmo2)
        True

    And cosmologies with different parameters are not.

        >>> cosmo3 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.4)
        >>> cosmology_equal(cosmo1, cosmo3)
        False

    Two cosmologies may be equivalent even if not of the same class. In these
    examples the |LambdaCDM| has :attr:`~astropy.cosmology.LambdaCDM.Ode0` set
    to the same value calculated in |FlatLambdaCDM|.

        >>> from astropy.cosmology import LambdaCDM
        >>> cosmo3 = LambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, 0.7)
        >>> cosmology_equal(cosmo1, cosmo3)
        False
        >>> cosmology_equal(cosmo1, cosmo3, allow_equivalent=True)
        True

    While in this example, the cosmologies are not equivalent.

        >>> cosmo4 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, Tcmb0=3 * u.K)
        >>> cosmology_equal(cosmo3, cosmo4, allow_equivalent=True)
        False

    Also, using the keyword argument, the notion of equality is extended to any
    Python object that can be converted to a |Cosmology|.

        >>> mapping = cosmo2.to_format("mapping")
        >>> cosmology_equal(cosmo1, mapping, format=True)
        True

    Either (or both) arguments can be |Cosmology|-like.

        >>> cosmology_equal(mapping, cosmo2, format=True)
        True

    The list of valid formats, e.g. the |Table| in this example, may be checked
    with ``Cosmology.from_format.list_formats()``.

    As can be seen in the list of formats, not all formats can be
    auto-identified by ``Cosmology.from_format.registry``. Objects of these
    kinds can still be checked for equality, but the correct format string must
    be used.

        >>> yml = cosmo2.to_format("yaml")
        >>> cosmology_equal(cosmo1, yml, format=(None, "yaml"))
        True

    This also works with an array of ``format`` matching the number of
    cosmologies.

        >>> cosmology_equal(mapping, yml, format=[True, "yaml"])
        True
    """
    # Check parameter equality
    if not allow_equivalent:
        eq = cosmo1 == cosmo2

    else:
        # Check parameter equivalence
        # The options are: 1) same class & parameters; 2) same class, different
        # parameters; 3) different classes, equivalent parameters; 4) different
        # classes, different parameters. (1) & (3) => True, (2) & (4) => False.
        eq = cosmo1.__equiv__(cosmo2)
        if eq is NotImplemented:
            eq = cosmo2.__equiv__(cosmo1)  # that failed, try from 'other'

        eq = False if eq is NotImplemented else eq

    # TODO! include equality check of metadata

    return eq


@_comparison_decorator
def _cosmology_not_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise cosmology non-equality check.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.

    out : ndarray, None, optional
        A location into which the result is stored. If provided, it must have a
        shape that the inputs broadcast to. If not provided or None, a
        freshly-allocated array is returned.

    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead a Cosmology. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. ``format`` is
        broadcast to match the shape of the cosmology arguments. Note that the
        cosmology arguments are not broadcast against ``format``, so it cannot
        determine the output shape.

    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    See Also
    --------
    astropy.cosmology.cosmology_equal
        Element-wise equality check, with argument conversion to Cosmology.
    """
    neq = not cosmology_equal(cosmo1, cosmo2, allow_equivalent=allow_equivalent)
    # TODO! it might eventually be worth the speed boost to implement some of
    #       the internals of cosmology_equal here, but for now it's a hassle.

    return neq
