"""
Main Lomb-Scargle Implementation.

The ``lombscargle`` function here is essentially a sophisticated switch
statement for the various implementations available in this submodule.
"""

__all__ = ["lombscargle", "available_methods"]

import numpy as np

from .chi2_impl import lombscargle_chi2
from .cython_impl import lombscargle_cython
from .fast_impl import lombscargle_fast
from .fastchi2_impl import lombscargle_fastchi2
from .scipy_impl import lombscargle_scipy
from .slow_impl import lombscargle_slow

# Dispatch table: concrete method name -> implementing function.
# "auto" is deliberately absent; ``validate_method`` resolves it to one of
# these keys before the lookup happens.
METHODS = dict(
    slow=lombscargle_slow,
    fast=lombscargle_fast,
    chi2=lombscargle_chi2,
    scipy=lombscargle_scipy,
    fastchi2=lombscargle_fastchi2,
    cython=lombscargle_cython,
)


def available_methods():
    """Return the list of method names usable in this environment.

    Returns
    -------
    methods : list of str
        A new list on every call. It always starts with ``"auto"``,
        ``"slow"``, ``"chi2"``, ``"cython"``, ``"fast"`` and ``"fastchi2"``;
        ``"scipy"`` is added at the end only when ``scipy`` can be imported.
    """
    always_available = ["auto", "slow", "chi2", "cython", "fast", "fastchi2"]

    # The scipy-backed algorithm is optional: probe for the package and
    # simply leave the method out when it is missing.
    try:
        import scipy  # noqa: F401
    except ImportError:
        return always_available

    return always_available + ["scipy"]


def _is_regular(frequency):
    frequency = np.asarray(frequency)

    if frequency.ndim != 1:
        return False
    elif len(frequency) == 1:
        return True
    else:
        diff = np.diff(frequency)
        return np.allclose(diff[0], diff)


def _get_frequency_grid(frequency, assume_regular_frequency=False):
    """Utility to get grid parameters from a frequency array.

    Parameters
    ----------
    frequency : array-like or `~astropy.units.Quantity` ['frequency']
        input frequency grid
    assume_regular_frequency : bool (default = False)
        if True, then do not check whether frequency is a regular grid

    Returns
    -------
    f0, df, N : scalar
        Parameters such that all(frequency == f0 + df * np.arange(N))
    """
    frequency = np.asarray(frequency)
    if frequency.ndim != 1:
        raise ValueError("frequency grid must be 1 dimensional")
    elif len(frequency) == 1:
        return frequency[0], frequency[0], 1
    elif not (assume_regular_frequency or _is_regular(frequency)):
        raise ValueError("frequency must be a regular grid")

    return frequency[0], frequency[1] - frequency[0], len(frequency)


def validate_method(method, dy, fit_mean, nterms, frequency, assume_regular_frequency):
    """
    Validate the method argument, and if method='auto'
    choose the appropriate method.

    Parameters
    ----------
    method : str
        requested method name, or 'auto'
    dy : float, array-like or None
        observational errors; 'auto' only picks 'scipy' when this is None
    fit_mean : bool
        whether a constant offset is fitted; 'auto' only picks 'scipy'
        when this is False
    nterms : int
        number of Fourier terms; anything other than 1 makes 'auto' pick
        one of the chi2 methods
    frequency : array-like
        frequency grid; only inspected when method is 'auto'
    assume_regular_frequency : bool
        if True, skip the check that ``frequency`` is a regular grid

    Returns
    -------
    method : str
        a key of ``METHODS``

    Raises
    ------
    ValueError
        If the (resolved) method is not a key of ``METHODS``.
    """
    # automatically choose the appropriate method
    if method == "auto":
        # The regularity check scans the whole grid, so it is only done
        # here, where its outcome actually influences the choice. The fast
        # methods are preferred for regular grids of more than 200 points.
        prefer_fast = len(frequency) > 200 and (
            assume_regular_frequency or _is_regular(frequency)
        )

        if nterms != 1:
            method = "fastchi2" if prefer_fast else "chi2"
        elif prefer_fast:
            method = "fast"
        elif dy is None and not fit_mean and "scipy" in available_methods():
            # cheap conditions first; the scipy import probe runs last
            method = "scipy"
        else:
            method = "cython"

    if method not in METHODS:
        raise ValueError(f"invalid method: {method}")

    return method


def lombscargle(
    t,
    y,
    dy=None,
    frequency=None,
    method="auto",
    assume_regular_frequency=False,
    normalization="standard",
    fit_mean=True,
    center_data=True,
    method_kwds=None,
    nterms=1,
):
    """
    Compute the Lomb-Scargle periodogram with a given method.

    Parameters
    ----------
    t : array-like
        sequence of observation times
    y : array-like
        sequence of observations associated with times t
    dy : float or array-like, optional
        error or sequence of observational errors associated with times t
    frequency : array-like
        frequencies (not angular frequencies) at which to evaluate the
        periodogram. Although the argument has a ``None`` default, it is
        required: this function does not build a frequency grid itself.
        Any shape is accepted; the grid is flattened for the computation.
        Note that in order to use method='fast' or method='fastchi2',
        frequencies must be regularly spaced.
    method : str, optional
        specify the lomb scargle implementation to use. Options are:

        - 'auto': choose the best method based on the input
        - 'fast': use the O[N log N] fast method. Note that this requires
          evenly-spaced frequencies: by default this will be checked unless
          ``assume_regular_frequency`` is set to True.
        - 'slow': use the O[N^2] pure-python implementation
        - 'cython': use the cython implementation
        - 'chi2': use the O[N^2] chi2/linear-fitting implementation
        - 'fastchi2': use the O[N log N] chi2 implementation. Note that this
          requires evenly-spaced frequencies: by default this will be checked
          unless ``assume_regular_frequency`` is set to True.
        - 'scipy': use ``scipy.signal.lombscargle``, which is an O[N^2]
          implementation written in C. Note that this does not support
          heteroskedastic errors, nor ``fit_mean=True``.

    assume_regular_frequency : bool, optional
        if True, assume that the input frequency is of the form
        freq = f0 + df * np.arange(N). Only referenced if method is 'auto',
        'fast' or 'fastchi2'.
    normalization : str, optional
        Normalization to use for the periodogram.
        Options are 'standard' or 'psd'.
    fit_mean : bool, optional
        if True, include a constant offset as part of the model at each
        frequency. This can lead to more accurate results, especially in the
        case of incomplete phase coverage.
    center_data : bool, optional
        if True, pre-center the data by subtracting the weighted mean
        of the input data. This is especially important if ``fit_mean=False``
    method_kwds : dict, optional
        additional keywords to pass to the lomb-scargle method
    nterms : int, optional
        number of Fourier terms to use in the periodogram.
        Only supported by the 'chi2' and 'fastchi2' methods.

    Returns
    -------
    PLS : array-like
        Lomb-Scargle power associated with each frequency, reshaped to the
        shape of the input ``frequency``

    Raises
    ------
    ValueError
        If ``frequency`` is None, if the method is invalid, if the 'scipy'
        method is combined with ``fit_mean=True`` or non-uniform ``dy``, if
        a fast method is given an unusable frequency grid, or if
        ``nterms != 1`` is requested from a non-chi2 method.
    """
    if frequency is None:
        raise ValueError(
            "frequency must be specified: no default frequency grid is computed here"
        )

    # Accept any array-like (lists, tuples, scalars). asanyarray hands
    # ndarray subclasses back untouched, so existing array inputs are
    # processed exactly as before.
    frequency = np.asanyarray(frequency)

    # frequencies should be one-dimensional arrays
    output_shape = frequency.shape
    frequency = frequency.ravel()

    # we'll need to adjust args and kwds for each method
    args = (t, y, dy)
    kwds = dict(
        frequency=frequency,
        center_data=center_data,
        fit_mean=fit_mean,
        normalization=normalization,
        nterms=nterms,
        **(method_kwds or {}),
    )

    method = validate_method(
        method,
        dy=dy,
        fit_mean=fit_mean,
        nterms=nterms,
        frequency=frequency,
        assume_regular_frequency=assume_regular_frequency,
    )

    # scipy doesn't support dy or fit_mean=True
    if method == "scipy":
        # fit_mean is removed from kwds: it is not forwarded to this method
        if kwds.pop("fit_mean"):
            raise ValueError("scipy method does not support fit_mean=True")
        if dy is not None:
            # a constant uncertainty is tolerated, but it is not forwarded
            dy = np.ravel(np.asarray(dy))
            if not np.allclose(dy[0], dy):
                raise ValueError("scipy method only supports uniform uncertainties dy")
        args = (t, y)

    # fast methods require frequency expressed as a grid
    if method.startswith("fast"):
        f0, df, Nf = _get_frequency_grid(
            kwds.pop("frequency"), assume_regular_frequency
        )
        kwds.update(f0=f0, df=df, Nf=Nf)

    # only chi2 methods support nterms
    if not method.endswith("chi2"):
        # nterms is removed from kwds: non-chi2 methods do not receive it
        if kwds.pop("nterms") != 1:
            raise ValueError(
                "nterms != 1 only supported with 'chi2' or 'fastchi2' methods"
            )

    PLS = METHODS[method](*args, **kwds)
    return PLS.reshape(output_shape)
