# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
This module provides utility functions for the models package.
"""

import warnings

# pylint: disable=invalid-name
from collections import UserDict
from inspect import signature

import numpy as np

from astropy import units as u

# Names exported by a star-import of this module.
__all__ = ["poly_map_domain", "ellipse_extent"]


def make_binary_operator_eval(oper, f, g):
    """
    Build an evaluator that combines the outputs of two callables.

    ``f`` and ``g`` must accept the same ``(inputs, params)`` arguments and
    each return a tuple (which may be a 1-tuple).  The function returned here
    takes those same two arguments, calls ``f`` and then ``g`` with them, and
    applies the binary callable ``oper`` element-wise to the two results.
    Pairing stops at the end of the shorter result.

    Example
    -------
    >>> from operator import add
    >>> def prod(x, y):
    ...     return (x * y,)
    ...
    >>> sum_of_prod = make_binary_operator_eval(add, prod, prod)
    >>> sum_of_prod(3, 5)
    (30,)
    """

    def evaluate(inputs, params):
        left_outputs = f(inputs, params)
        right_outputs = g(inputs, params)
        # ``map`` with two iterables pairs elements exactly like ``zip``.
        return tuple(map(oper, left_outputs, right_outputs))

    return evaluate


def poly_map_domain(oldx, domain, window):
    """
    Linearly map coordinates from ``domain`` onto ``window``.

    The mapping is the shift-and-scale transform that sends ``domain[0]`` to
    ``window[0]`` and ``domain[1]`` to ``window[1]``.

    Parameters
    ----------
    oldx : array
          original coordinates
    domain : list or tuple of length 2
          function domain
    window : list or tuple of length 2
          range into which to map the domain

    Raises
    ------
    ValueError
        If ``domain`` or ``window`` does not have exactly two elements.
    """
    domain = np.array(domain, dtype=np.float64)
    window = np.array(window, dtype=np.float64)
    if not (domain.shape == (2,) and window.shape == (2,)):
        raise ValueError('Expected "domain" and "window" to be a tuple of size 2.')

    domain_lo, domain_hi = domain
    window_lo, window_hi = window
    domain_span = domain_hi - domain_lo

    scale = (window_hi - window_lo) / domain_span
    offset = (window_lo * domain_hi - window_hi * domain_lo) / domain_span
    return offset + scale * oldx


def _validate_domain_window(value):
    if value is not None:
        if np.asanyarray(value).shape != (2,):
            raise ValueError("domain and window should be tuples of size 2.")
        return tuple(value)
    return value


def array_repr_oneline(array):
    """
    Return the string form of a Numpy array collapsed onto a single line.

    The array is formatted with ``", "`` as element separator and with small
    values suppressed; the resulting lines are stripped of surrounding
    whitespace and joined with single spaces.
    """
    multiline = np.array2string(array, separator=", ", suppress_small=True)
    stripped_lines = [text.strip() for text in multiline.splitlines()]
    return " ".join(stripped_lines)


def combine_labels(left, right):
    """
    Concatenate the input/output labels of the two operands of the join
    operator ``&``.

    When the two label sets share no entry, the result is simply
    ``left + right``.  When *any* label appears on both sides, every
    left-hand label gets ``'0'`` appended and every right-hand label gets
    ``'1'`` appended, so that the combined labels are unambiguous.
    """
    if set(left).isdisjoint(right):
        return left + right

    renamed_left = tuple(label + "0" for label in left)
    renamed_right = tuple(label + "1" for label in right)
    return renamed_left + renamed_right


def ellipse_extent(a, b, theta):
    """
    Compute the half-width and half-height of the axis-aligned box that
    encloses a rotated 2D ellipse.

    Parameters
    ----------
    a : float or `~astropy.units.Quantity`
        The ellipse semimajor axis.
    b : float or `~astropy.units.Quantity`
        The ellipse semiminor axis.
    theta : float or `~astropy.units.Quantity` ['angle']
        The rotation angle as an angular quantity
        (`~astropy.units.Quantity` or `~astropy.coordinates.Angle`) or
        a value in radians (as a float). The rotation angle increases
        counterclockwise.

    Returns
    -------
    offsets : `~numpy.ndarray` or `~astropy.units.Quantity`
        The absolute value of the offset distances from the ellipse center that
        define its bounding box region, ``(dx, dy)``.  A
        `~astropy.units.Quantity` is returned when either offset is one.

    Examples
    --------
    .. plot::
        :include-source:

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling.models import Ellipse2D
        from astropy.modeling.utils import ellipse_extent, render_model

        amplitude = 1
        x0 = 50
        y0 = 50
        a = 30
        b = 10
        theta = np.pi / 4

        model = Ellipse2D(amplitude, x0, y0, a, b, theta)
        dx, dy = ellipse_extent(a, b, theta)
        limits = [x0 - dx, x0 + dx, y0 - dy, y0 + dy]
        model.bounding_box = limits

        image = render_model(model)

        plt.imshow(image, cmap='binary', interpolation='nearest', alpha=.5,
                  extent = limits)
        plt.show()
    """
    from .parameters import Parameter  # prevent circular import

    # A model Parameter is unwrapped to its Quantity when it has one, and to
    # its bare value otherwise.
    if isinstance(theta, Parameter):
        angle_quantity = theta.quantity
        theta = theta.value if angle_quantity is None else angle_quantity

    sin_theta = np.sin(theta)
    cos_theta = np.cos(theta)
    tan_theta = np.tan(theta)

    # Extent along x.
    t_x = np.arctan2(-b * tan_theta, a)
    dx = a * np.cos(t_x) * cos_theta - b * np.sin(t_x) * sin_theta

    # Extent along y.
    t_y = np.arctan2(b, a * tan_theta)
    dy = b * np.sin(t_y) * cos_theta + a * np.cos(t_y) * sin_theta

    if any(isinstance(extent, u.Quantity) for extent in (dx, dy)):
        return np.abs(u.Quantity([dx, dy], subok=True))
    return np.abs([dx, dy])


def get_inputs_and_params(func):
    """
    Split the signature of a callable into input variables and parameters.

    Arguments without a default value are treated as inputs; arguments with
    a default value are treated as parameters.

    Parameters
    ----------
    func : callable
        The callable whose signature is inspected.

    Returns
    -------
    inputs, params : tuple
        Each entry is a list of `inspect.Parameter` objects, in signature
        order.

    Raises
    ------
    ValueError
        If the signature contains ``*args`` or ``**kwargs``.
    """
    inputs = []
    params = []
    for param in signature(func).parameters.values():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            raise ValueError("Signature must not have *args or **kwargs")
        # ``empty`` is a sentinel and must be compared by identity: ``==``
        # would dispatch to the default value's own ``__eq__``, which for
        # array-like defaults yields an array with an ambiguous truth value.
        if param.default is param.empty:
            inputs.append(param)
        else:
            params.append(param)

    return inputs, params


def _combine_equivalency_dict(keys, eq1=None, eq2=None):
    # Given two dictionaries that give equivalencies for a set of keys, for
    # example input value names, return a dictionary that includes all the
    # equivalencies
    eq = {}
    for key in keys:
        eq[key] = []
        if eq1 is not None and key in eq1:
            eq[key].extend(eq1[key])
        if eq2 is not None and key in eq2:
            eq[key].extend(eq2[key])
    return eq


def _to_radian(value):
    """
    Convert ``value`` to radian.

    A `~astropy.units.Quantity` is converted with its own unit machinery;
    any other value is passed to `numpy.deg2rad`.
    """
    if not isinstance(value, u.Quantity):
        return np.deg2rad(value)
    return value.to(u.rad)


def _to_orig_unit(value, raw_unit=None, orig_unit=None):
    """Convert value with ``raw_unit`` to ``orig_unit``."""
    if raw_unit is not None:
        return (value * raw_unit).to(orig_unit)
    return np.rad2deg(value)


class _ConstraintsDict(UserDict):
    """
    Wrapper around UserDict to allow updating the constraints
    on a Parameter when the dictionary is updated.
    """

    def __init__(self, model, constraint_type):
        self._model = model
        self.constraint_type = constraint_type
        c = {}
        for name in model.param_names:
            param = getattr(model, name)
            c[name] = getattr(param, constraint_type)
        super().__init__(c)

    def __setitem__(self, key, val):
        super().__setitem__(key, val)
        param = getattr(self._model, key)
        setattr(param, self.constraint_type, val)


class _SpecialOperatorsDict(UserDict):
    """
    Wrapper around UserDict to allow for better tracking of the Special
    Operators for CompoundModels. This dictionary is structured so that
    one cannot inadvertently overwrite an existing special operator.

    Parameters
    ----------
    unique_id: int
        the last used unique_id for a SPECIAL OPERATOR
    special_operators: dict
        a dictionary containing the special_operators

    Notes
    -----
    Direct setting of operators (`dict[key] = value`) into the
    dictionary has been deprecated in favor of the `.add(name, value)`
    method, so that unique dictionary keys can be generated and tracked
    consistently.
    """

    def __init__(self, unique_id=0, special_operators={}):
        super().__init__(special_operators)
        self._unique_id = unique_id

    def _set_value(self, key, val):
        if key in self:
            raise ValueError(f'Special operator "{key}" already exists')
        else:
            super().__setitem__(key, val)

    def __setitem__(self, key, val):
        self._set_value(key, val)
        warnings.warn(
            DeprecationWarning(
                """
            Special operator dictionary assignment has been deprecated.
            Please use `.add` instead, so that you can capture a unique
            key for your operator.
            """
            )
        )

    def _get_unique_id(self):
        self._unique_id += 1

        return self._unique_id

    def add(self, operator_name, operator):
        """
        Adds a special operator to the dictionary, and then returns the
        unique key that the operator is stored under for later reference.

        Parameters
        ----------
        operator_name: str
            the name for the operator
        operator: function
            the actual operator function which will be used

        Returns
        -------
        the unique operator key for the dictionary
            `(operator_name, unique_id)`
        """
        key = (operator_name, self._get_unique_id())

        self._set_value(key, operator)

        return key
