# Licensed under a 3-clause BSD style license - see LICENSE.rst
# pylint: disable=invalid-name

"""
Optimization algorithms used in `~astropy.modeling.fitting`.
"""

import abc
import warnings

import numpy as np

from astropy.utils.exceptions import AstropyUserWarning

__all__ = ["Optimization", "SLSQP", "Simplex"]

# Maximum number of iterations
DEFAULT_MAXITER = 100

# Step for the forward difference approximation of the Jacobian
DEFAULT_EPS = np.sqrt(np.finfo(float).eps)

# Default requested accuracy
DEFAULT_ACC = 1e-07

DEFAULT_BOUNDS = (-(10**12), 10**12)


class Optimization(metaclass=abc.ABCMeta):
    """
    Base class for optimizers.

    Parameters
    ----------
    opt_method : callable
        Callable that implements the optimization method

    Notes
    -----
    The base `Optimization` class does not support any constraints by
    default; each subclass should explicitly set ``supported_constraints``
    to the constraint types it supports.
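
    Examples
    --------
    A concrete optimizer wraps a solver callable and declares the constraint
    types it can handle.  A hypothetical subclass (shown for illustration
    only) wrapping `scipy.optimize.fmin_powell` might look like::

        class Powell(Optimization):
            supported_constraints = ["fixed", "tied"]

            def __init__(self):
                from scipy.optimize import fmin_powell

                super().__init__(fmin_powell)

            def __call__(self, objfunc, initval, fargs, **kwargs):
                kwargs.setdefault("maxiter", self.maxiter)
                fitparams = self.opt_method(objfunc, initval, args=fargs, **kwargs)
                return fitparams, {}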

    """

    supported_constraints = []

    def __init__(self, opt_method):
        self._opt_method = opt_method
        self._maxiter = DEFAULT_MAXITER
        self._eps = DEFAULT_EPS
        self._acc = DEFAULT_ACC

    @property
    def maxiter(self):
        """Maximum number of iterations."""
        return self._maxiter

    @maxiter.setter
    def maxiter(self, val):
        """Set maxiter."""
        self._maxiter = val

    @property
    def eps(self):
        """Step for the forward difference approximation of the Jacobian."""
        return self._eps

    @eps.setter
    def eps(self, val):
        """Set eps value."""
        self._eps = val

    @property
    def acc(self):
        """Requested accuracy."""
        return self._acc

    @acc.setter
    def acc(self, val):
        """Set accuracy."""
        self._acc = val

    def __repr__(self):
        fmt = f"{self.__class__.__name__}()"
        return fmt

    @property
    def opt_method(self):
        """Return the optimization method."""
        return self._opt_method

    @abc.abstractmethod
    def __call__(self):
        raise NotImplementedError("Subclasses should implement this method")


class SLSQP(Optimization):
    """
    Sequential Least Squares Programming optimization algorithm.

    The algorithm is described in [1]_. It supports tied and fixed
    parameters as well as bound, equality, and inequality constraints.
    Uses `scipy.optimize.fmin_slsqp`.

    References
    ----------
    .. [1] http://www.netlib.org/toms/733
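
    Examples
    --------
    This class is normally used through
    `~astropy.modeling.fitting.SLSQPLSQFitter`.  A rough sketch of direct
    use is shown below; the ``objfunc`` here is a plain sum of squared
    residuals chosen purely for illustration, not the statistic the
    fitters use::

        import numpy as np

        from astropy.modeling import models
        from astropy.modeling.optimizers import SLSQP

        def objfunc(fps, model, x, y):
            # Update the model with the trial parameters, return the statistic
            model.parameters = fps
            return np.sum((model(x) - y) ** 2)

        model = models.Gaussian1D(amplitude=1.0, mean=0.0, stddev=2.0)
        x = np.linspace(-3, 3, 50)
        y = 2.0 * np.exp(-0.5 * ((x - 0.5) / 1.0) ** 2)

        opt = SLSQP()
        fitparams, fit_info = opt(objfunc, model.parameters, (model, x, y))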
    """

    supported_constraints = ["bounds", "eqcons", "ineqcons", "fixed", "tied"]

    def __init__(self):
        from scipy.optimize import fmin_slsqp

        super().__init__(fmin_slsqp)
        self.fit_info = {
            "final_func_val": None,
            "numiter": None,
            "exit_mode": None,
            "message": None,
        }

    def __call__(self, objfunc, initval, fargs, **kwargs):
        """
        Run the solver.

        Parameters
        ----------
        objfunc : callable
            objective function
        initval : iterable
            initial guess for the parameter values
        fargs : tuple
            other arguments to be passed to the statistic function
        kwargs : dict
            other keyword arguments to be passed to the solver

        """
        kwargs["iter"] = kwargs.pop("maxiter", self._maxiter)

        if "epsilon" not in kwargs:
            kwargs["epsilon"] = self._eps
        if "acc" not in kwargs:
            kwargs["acc"] = self._acc
        # Get the verbosity level
        disp = kwargs.pop("verblevel", None)

        # set the values of constraints to match the requirements of fmin_slsqp
        model = fargs[0]
        pars = [getattr(model, name) for name in model.param_names]
        bounds = [par.bounds for par in pars if not (par.fixed or par.tied)]
        bounds = np.asarray(bounds)
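        # Replace open-ended (None) limits with large finite defaults so the
        # bounds can be passed to fmin_slsqp as a float array.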
        for i in bounds:
            if i[0] is None:
                i[0] = DEFAULT_BOUNDS[0]
            if i[1] is None:
                i[1] = DEFAULT_BOUNDS[1]
        # older versions of scipy require this array to be float
        bounds = np.asarray(bounds, dtype=float)
        eqcons = np.array(model.eqcons)
        ineqcons = np.array(model.ineqcons)
        fitparams, final_func_val, numiter, exit_mode, mess = self.opt_method(
            objfunc,
            initval,
            args=fargs,
            full_output=True,
            disp=disp,
            bounds=bounds,
            eqcons=eqcons,
            ieqcons=ineqcons,
            **kwargs,
        )

        self.fit_info["final_func_val"] = final_func_val
        self.fit_info["numiter"] = numiter
        self.fit_info["exit_mode"] = exit_mode
        self.fit_info["message"] = mess

        if exit_mode != 0:
            warnings.warn(
                "The fit may be unsuccessful; check "
                "fit_info['message'] for more information.",
                AstropyUserWarning,
            )

        return fitparams, self.fit_info


class Simplex(Optimization):
    """
    Nelder-Mead (downhill simplex) algorithm.

    This algorithm [1]_ only uses function values, not derivatives.
    Uses `scipy.optimize.fmin`.

    References
    ----------
    .. [1] Nelder, J.A. and Mead, R. (1965), "A simplex method for function
       minimization", The Computer Journal, 7, pp. 308-313
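
    Examples
    --------
    Normally used through `~astropy.modeling.fitting.SimplexLSQFitter`.  A
    rough sketch of direct use, reusing the illustrative ``objfunc`` and
    ``model`` from the `SLSQP` example above, might be::

        opt = Simplex()
        fitparams, fit_info = opt(
            objfunc, model.parameters, (model, x, y), maxiter=500, xtol=1e-8
        )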
    """

    supported_constraints = ["bounds", "fixed", "tied"]

    def __init__(self):
        from scipy.optimize import fmin as simplex

        super().__init__(simplex)
        self.fit_info = {
            "final_func_val": None,
            "numiter": None,
            "exit_mode": None,
            "num_function_calls": None,
        }

    def __call__(self, objfunc, initval, fargs, **kwargs):
        """
        Run the solver.

        Parameters
        ----------
        objfunc : callable
            objective function
        initval : iterable
            initial guess for the parameter values
        fargs : tuple
            other arguments to be passed to the statistic function
        kwargs : dict
            other keyword arguments to be passed to the solver

        """
        if "maxiter" not in kwargs:
            kwargs["maxiter"] = self._maxiter
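        # scipy's fmin takes the convergence tolerance as ``xtol``; accept
        # ``acc`` as an alias so both optimizers can be driven the same way.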
        if "acc" in kwargs:
            self._acc = kwargs["acc"]
            kwargs.pop("acc")
        if "xtol" in kwargs:
            self._acc = kwargs["xtol"]
            kwargs.pop("xtol")
        # Get the verbosity level
        disp = kwargs.pop("verblevel", None)

        fitparams, final_func_val, numiter, funcalls, exit_mode = self.opt_method(
            objfunc,
            initval,
            args=fargs,
            xtol=self._acc,
            disp=disp,
            full_output=True,
            **kwargs,
        )
        self.fit_info["final_func_val"] = final_func_val
        self.fit_info["numiter"] = numiter
        self.fit_info["exit_mode"] = exit_mode
        self.fit_info["num_function_calls"] = funcalls
        if self.fit_info["exit_mode"] == 1:
            warnings.warn(
                "The fit may be unsuccessful; "
                "maximum number of function evaluations reached.",
                AstropyUserWarning,
            )
        elif self.fit_info["exit_mode"] == 2:
            warnings.warn(
                "The fit may be unsuccessful; maximum number of iterations reached.",
                AstropyUserWarning,
            )
        return fitparams, self.fit_info
