# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Helpers for letting numpy functions interact with distributions.

The module supplies helper routines for numpy functions that propagate
distributions appropriately, for use in the ``__array_function__``
implementation of `~astropy.uncertainty.core.Distribution`.  They are not
very useful on their own, but the ones with docstrings are included in
the documentation so that there is a place to find out how the distributions
are interpreted.

"""

import numpy as np

from astropy.units.quantity_helper.function_helpers import FunctionAssigner

# This module should not really be imported, but we define __all__
# such that sphinx can typeset the functions with docstrings.
# Only the registries are listed here; the names of the helper functions
# that have a docstring are appended to __all__ at the end of the module.
__all__ = [
    "DISTRIBUTION_SAFE_FUNCTIONS",
    "DISPATCHED_FUNCTIONS",
    "UNSUPPORTED_FUNCTIONS",
]


DISTRIBUTION_SAFE_FUNCTIONS = set()
"""Set of functions that work fine on Distribution classes already.

Most of these internally use `numpy.ufunc` or other functions that
are already covered.
"""

DISPATCHED_FUNCTIONS = {}
"""Dict of functions that provide the numpy function's functionality.

These are for more complicated versions where the numpy function itself
cannot easily be used.  Each helper should return the result of the function.

Each helper should raise `NotImplementedError` if one of the arguments is a
distribution when it should not be or vice versa.
"""


FUNCTION_HELPERS = {}
"""Dict of functions for which Distribution can be used after some conversions.

The `dict` is keyed by the numpy function and the values are functions
that take the input arguments of the numpy function and organize these
for passing the distribution data to the numpy function, by returning
``args, kwargs, out``. Here, the former two are passed on, while ``out``
is used to indicate whether there was an output argument.  If ``out`` is
set to `True`, then no further processing should be done; otherwise, it
is assumed that the function operates on unwrapped distributions and
that the results need to be rewrapped as |Distribution|.

The function should raise `NotImplementedError` if one of the arguments is a
distribution when it should not be or vice versa.

"""


UNSUPPORTED_FUNCTIONS = set()
"""Set of numpy functions that are not supported for distributions.

For most, distributions simply make no sense, but for others it may have
been lack of time.  Issues or PRs for support for functions are welcome.
"""


# Decorators used below to register a helper in ``FUNCTION_HELPERS`` or
# ``DISPATCHED_FUNCTIONS``, respectively.
# NOTE(review): presumably the numpy function a helper belongs to is looked up
# from the helper's name -- confirm against ``FunctionAssigner``.
function_helper = FunctionAssigner(FUNCTION_HELPERS)
dispatched_function = FunctionAssigner(DISPATCHED_FUNCTIONS)


def is_distribution(x):
    """Return `True` if ``x`` is a ``Distribution`` instance, `False` otherwise."""
    # The class is imported at call time rather than at module level.
    # NOTE(review): presumably this avoids a circular import with
    # ``astropy.uncertainty`` -- confirm before moving it to the top.
    from astropy.uncertainty import Distribution as distribution_cls

    return isinstance(x, distribution_cls)


def get_n_samples(*arrays):
    """Get ``n_samples`` from the first |Distribution| among ``arrays``.

    Only the first |Distribution| encountered is inspected: if other
    distributions have a different ``n_samples``, the code using the
    result is expected to raise an appropriate exception later.

    Raises
    ------
    RuntimeError
        If none of ``arrays`` is a |Distribution|.
    """
    # TODO: add verification if another function needs it.
    first = next((item for item in arrays if is_distribution(item)), None)
    if first is None:
        raise RuntimeError("no Distribution found! Please raise an issue.")

    return first.n_samples


@function_helper
def empty_like(prototype, dtype=None, *args, **kwargs):
    # Helper for ``empty_like``: replace the requested dtype by the
    # distribution dtype that ``prototype`` constructs for it and its
    # ``n_samples``; all other arguments are passed on unchanged.
    # (Kept as comments on purpose: helpers that have a docstring are added
    # to ``__all__`` at the end of this module.)
    make_dtype = prototype._get_distribution_dtype
    base_dtype = prototype.dtype if dtype is None else dtype
    dtype = make_dtype(base_dtype, prototype.n_samples)
    # ``out`` is None: there is no output argument.
    return (prototype, dtype, *args), kwargs, None


@function_helper
def broadcast_arrays(*args, subok=False):
    """Broadcast arrays to a common shape.

    Like `numpy.broadcast_arrays`, applied to both distributions and other data.
    Note that ``subok`` is taken to mean whether or not subclasses of
    the distribution are allowed, i.e., for ``subok=False``,
    `~astropy.uncertainty.NdarrayDistribution` instances will be returned.
    """

    def as_base_array(value):
        # Array instances are viewed as plain ndarray; anything else is
        # converted to an array first.
        if isinstance(value, np.ndarray):
            return value.view(np.ndarray)
        return np.array(value)

    if not subok:
        args = tuple(map(as_base_array, args))
    # numpy is always called with ``subok=True``; the final `True` signals
    # that no further processing of the result should be done.
    return args, {"subok": True}, True


@function_helper
def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"):
    """Concatenate arrays.

    Like `numpy.concatenate`, but any array that is not already a |Distribution|
    is turned into one with identical samples.

    Raises
    ------
    NotImplementedError
        If ``out`` is given but is not a |Distribution|.
    """
    n_samples = get_n_samples(*arrays, out)

    def with_sample_axis(array):
        # Distributions are replaced by their samples.
        if is_distribution(array):
            return array.distribution
        # Anything without a (non-empty) shape is passed on untouched.
        if not getattr(array, "shape", False):
            return array
        # Other arrays get a trailing axis on which the values are repeated
        # ``n_samples`` times (as a broadcast, so without copying).
        return np.broadcast_to(
            array[..., np.newaxis], array.shape + (n_samples,), subok=True
        )

    converted = tuple(with_sample_axis(array) for array in arrays)
    # The converted arrays carry an extra trailing axis, so a negative axis
    # has to be shifted by one to keep pointing at the same data axis.
    if axis < 0:
        axis = axis - 1  # not in-place, just in case.
    kwargs = {"axis": axis, "dtype": dtype, "casting": casting}
    if out is not None:
        if not is_distribution(out):
            raise NotImplementedError
        kwargs["out"] = out.distribution
    return (converted,), kwargs, out


# Helpers and dispatched functions that carry a docstring are exported via
# __all__, so that sphinx typesets them. The logic is that a docstring
# presumably signals that the way distributions are dealt with is not
# entirely obvious.
__all__ += sorted(  # noqa: PLE0605
    func.__name__
    for func in {*FUNCTION_HELPERS.values(), *DISPATCHED_FUNCTIONS.values()}
    if func.__doc__
)
