# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Define Numpy Ufuncs as Models.
"""

import numpy as np

from astropy.modeling.core import Model

# Names of the Numpy trigonometric, hyperbolic and angle-conversion ufuncs
# (plus ``hypot``) for which model classes are generated.
trig_ufuncs = [
    "sin",
    "cos",
    "tan",
    "arcsin",
    "arccos",
    "arctan",
    "arctan2",
    "hypot",
    "sinh",
    "cosh",
    "tanh",
    "arcsinh",
    "arccosh",
    "arctanh",
    "deg2rad",
    "rad2deg",
]


# Names of the Numpy arithmetic, exponential and logarithmic ufuncs for
# which model classes are generated.
math_ops = [
    "add",
    "subtract",
    "multiply",
    "logaddexp",
    "logaddexp2",
    "true_divide",
    "floor_divide",
    "negative",
    "positive",
    "power",
    "remainder",
    "fmod",
    "divmod",
    "absolute",
    "fabs",
    "rint",
    "exp",
    "exp2",
    "log",
    "log2",
    "log10",
    "expm1",
    "log1p",
    "sqrt",
    "square",
    "cbrt",
    "reciprocal",
    # The entries below are the keys of ``alias_ufuncs``; they are kept
    # after the names they alias ("true_divide" and "remainder").
    "divide",
    "mod",
]


# Every ufunc name that gets a model class.  The list order is the order in
# which the names are processed at the bottom of this module.
supported_ufuncs = trig_ufuncs + math_ops


# Maps an alias name to the name of the ufunc it aliases; per the numpy API
# both names refer to the same ufunc object.  The alias name must occur
# later in the lists above than the name it aliases, because the model class
# of the aliased ufunc is looked up (not created) when the alias is processed.
alias_ufuncs = {
    "divide": "true_divide",
    "mod": "remainder",
}


class _NPUfuncModel(Model):
    """Common base class of the model classes created by `ufunc_model`."""

    # NOTE(review): presumably marks these classes as generated at runtime;
    # how the flag is consumed is not visible in this file -- confirm in
    # astropy.modeling.core.
    _is_dynamic = True

    def __init__(self, **kwargs):
        # Accepts keyword arguments only and forwards them unchanged to the
        # parent initializer.
        super().__init__(**kwargs)


def _make_class_name(name):
    """Make a ufunc model class name from the name of the ufunc."""
    return name[0].upper() + name[1:] + "Ufunc"


def ufunc_model(name):
    """Create a Model subclass wrapping the Numpy ufunc called ``name``.

    Parameters
    ----------
    name : str
        Attribute name of the ufunc in the ``numpy`` namespace.

    Returns
    -------
    type
        A new subclass of ``_NPUfuncModel`` whose ``evaluate`` method passes
        its inputs on to the ufunc.  The number of model inputs and outputs
        is taken from the ufunc's ``nin`` and ``nout``.
    """
    np_func = getattr(np, name)
    n_in = np_func.nin

    # ``evaluate`` takes one input for single-input ufuncs and two otherwise.
    if n_in != 1:

        def evaluate(self, x, y):
            return self.func(x, y)

    else:

        def evaluate(self, x):
            return self.func(x)

    namespace = {
        "n_inputs": n_in,
        "n_outputs": np_func.nout,
        "func": np_func,
        "linear": False,
        "fittable": False,
        # Only the single-input models are flagged as separable.
        "_separable": n_in == 1,
        "_is_dynamic": True,
        "evaluate": evaluate,
    }

    model_cls = type(_make_class_name(name), (_NPUfuncModel,), namespace)
    # Report the public module rather than wherever the class was created.
    model_cls.__module__ = "astropy.modeling.math_functions"
    return model_cls


__all__ = []

# Publish one model class per supported ufunc name as a module attribute and
# list it in ``__all__``.
for name in supported_ufuncs:
    if name not in alias_ufuncs:
        # Regular ufunc: build a new model class and bind it under its name.
        m = ufunc_model(name)
        klass_name = m.__name__
        globals()[klass_name] = m
    else:
        # Alias: bind the alias class name to the class already created for
        # the aliased ufunc (which therefore has to be processed earlier).
        klass_name = _make_class_name(name)
        alias_klass_name = _make_class_name(alias_ufuncs[name])
        globals()[klass_name] = globals()[alias_klass_name]
    __all__.append(klass_name)  # noqa: PYI056
