# Licensed under a 3-clause BSD style license - see LICENSE.rst

import inspect
import sys
from io import StringIO

import numpy as np
import pytest

from astropy import units as u
from astropy.cosmology import core, flrw
from astropy.cosmology.funcs import _z_at_scalar_value, z_at_value
from astropy.cosmology.realizations import (
    WMAP1,
    WMAP3,
    WMAP5,
    WMAP7,
    WMAP9,
    Planck13,
    Planck15,
    Planck18,
)
from astropy.units import allclose
from astropy.utils.compat.optional_deps import HAS_SCIPY
from astropy.utils.exceptions import AstropyUserWarning


@pytest.mark.skipif(not HAS_SCIPY, reason="test requires scipy")
def test_z_at_value_scalar():
    """Check scalar ``z_at_value`` results against known Planck13 redshifts.

    The expected values are hard-coded, so they are checked less strictly
    than the self-consistency checks in ``test_z_at_value_roundtrip``.
    Here the cosmological calculations may give slightly different values
    on different architectures. The round-trip test only compares numbers
    computed on the same machine and can therefore be more demanding.
    """
    cosmo = Planck13

    # Each entry: (function, target value, extra kwargs, expected z, rtol).
    cases = [
        (cosmo.age, 2 * u.Gyr, {}, 3.19812268, 1e-6),
        (cosmo.lookback_time, 7 * u.Gyr, {}, 0.795198375, 1e-6),
        (cosmo.distmod, 46 * u.mag, {}, 1.991389168, 1e-6),
        (cosmo.luminosity_distance, 1e4 * u.Mpc, {}, 1.36857907, 1e-6),
        (
            cosmo.luminosity_distance,
            26.037193804 * u.Gpc,
            {"ztol": 1e-10},
            3,
            1e-9,
        ),
        (
            cosmo.angular_diameter_distance,
            1500 * u.Mpc,
            {"zmax": 2},
            0.681277696,
            1e-6,
        ),
        (
            cosmo.angular_diameter_distance,
            1500 * u.Mpc,
            {"zmin": 2.5},
            3.7914908,
            1e-6,
        ),
    ]
    for func, fval, options, expected, tolerance in cases:
        assert allclose(z_at_value(func, fval, **options), expected, rtol=tolerance)

    # When the solution lies outside the z limits, a CosmologyError must be
    # raised (preceded by a warning that fval is not bracketed).
    for limits in ({"zmax": 0.5}, {"zmin": 4.0}):
        with pytest.raises(core.CosmologyError):
            with pytest.warns(AstropyUserWarning, match=r"fval is not bracketed"):
                z_at_value(cosmo.angular_diameter_distance, 1500 * u.Mpc, **limits)


@pytest.mark.skipif(not HAS_SCIPY, reason="test requires scipy")
class Test_ZatValue:
    """Array-valued and broadcasting behaviour of ``z_at_value``."""

    def setup_class(self):
        self.cosmo = Planck13

    def test_broadcast_arguments(self):
        """Test broadcast of arguments."""
        cosmo = self.cosmo
        ang_dist = cosmo.angular_diameter_distance
        target = 1500 * u.Mpc
        both_roots = [0.681277696, 3.7914908]

        # An array of target values yields one redshift per value.
        assert allclose(
            z_at_value(cosmo.age, [2, 7] * u.Gyr),
            [3.1981206134773115, 0.7562044333305182],
            rtol=1e-6,
        )

        # zmin / zmax with matching 1-D shapes pair up element-wise.
        paired = z_at_value(ang_dist, target, zmin=[0, 2.5], zmax=[2, 4])
        assert allclose(paired, both_roots, rtol=1e-6)

        # A (1, 2)-shaped zmin broadcasts against a (2,)-shaped zmax.
        nested = z_at_value(ang_dist, target, zmin=[[0, 2.5]], zmax=[2, 4])
        assert allclose(nested, [both_roots], rtol=1e-6)

    def test_broadcast_bracket(self):
        """`bracket` has special requirements."""
        age = self.cosmo.age
        expected = 3.1981206134773115

        # No bracket at all, then a plain two-element bracket.
        for bracket in (None, [0, 4]):
            assert allclose(
                z_at_value(age, 2 * u.Gyr, bracket=bracket), expected, rtol=1e-6
            )

        # A sequence of unsupported length is rejected.
        with pytest.raises(ValueError, match="sequence"):
            z_at_value(age, 2 * u.Gyr, bracket=[0, 4, 4, 5])

        # A numeric (non-object) ndarray is rejected because of its dtype.
        with pytest.raises(TypeError, match="dtype"):
            z_at_value(age, 2 * u.Gyr, bracket=np.array([0, 4]))

        # An object array supplies one bracket per element.
        per_element = np.array([[0, 4], [0, 3, 4]], dtype=object)
        assert allclose(
            z_at_value(age, 2 * u.Gyr, bracket=per_element),
            [expected, expected],
            rtol=1e-6,
        )

    def test_bad_broadcast(self):
        """Shapes mismatch as expected"""
        ang_dist = self.cosmo.angular_diameter_distance
        with pytest.raises(ValueError, match="broadcast"):
            z_at_value(ang_dist, 1500 * u.Mpc, zmin=[0, 2.5, 0.1], zmax=[2, 4])

    def test_scalar_input_to_output(self):
        """Test scalar input returns a scalar."""
        ang_dist = self.cosmo.angular_diameter_distance
        result = z_at_value(ang_dist, 1500 * u.Mpc, zmin=0, zmax=2)

        assert isinstance(result, u.Quantity)
        assert result.dtype == np.float64
        assert result.shape == ()
        assert z.shape == ()


@pytest.mark.skipif(not HAS_SCIPY, reason="test requires scipy")
def test_z_at_value_numpyvectorize():
    """Check that ``np.vectorize`` still cannot handle Quantity inputs.

    If this test starts failing, ``np.vectorize`` can replace the
    hand-written vectorization in ``z_at_value``. Please submit a PR
    making that change.
    """
    excluded_args = ["func", "method", "verbose"]
    vectorized = np.vectorize(_z_at_scalar_value, excluded=excluded_args)

    with pytest.raises(u.UnitConversionError, match="dimensionless quantities"):
        vectorized(Planck15.age, 10 * u.Gyr)


@pytest.mark.skipif(not HAS_SCIPY, reason="test requires scipy")
def test_z_at_value_verbose(monkeypatch):
    """``verbose=True`` must print the solver result to stdout."""
    # The verbose output goes through ``print``, so redirect sys.stdout
    # into an in-memory buffer that can be inspected afterwards.
    buffer = StringIO()
    monkeypatch.setattr(sys, "stdout", buffer)

    solution = z_at_value(Planck13.age, 2 * u.Gyr, verbose=True)

    printed = buffer.getvalue()
    assert str(solution.value) in printed


@pytest.mark.skipif(not HAS_SCIPY, reason="test requires scipy")
@pytest.mark.parametrize("method", ["Brent", "Golden", "Bounded"])
def test_z_at_value_bracketed(method):
    """
    Test 2 solutions for angular diameter distance by not constraining zmin, zmax,
    but setting `bracket` on the appropriate side of the turning point z.
    Setting zmin / zmax should override `bracket`.
    """
    cosmo = Planck13
    # The two redshifts at which d_A(z) = 1500 Mpc for Planck13: one below
    # and one above the turning point of the angular diameter distance.
    z_low, z_high = 0.6812777, 3.7914908

    def solve(**kwargs):
        """Solve d_A(z) = 1500 Mpc with the current ``method`` and options."""
        return z_at_value(
            cosmo.angular_diameter_distance, 1500 * u.Mpc, method=method, **kwargs
        )

    if method == "Bounded":
        # 'Bounded' ignores `bracket`. Find whichever root it reaches unaided,
        # then pass a bracket around the *other* root and check that the
        # answer does not change.
        with pytest.warns(AstropyUserWarning, match=r"fval is not bracketed"):
            z = solve()
        if z > 1.6:
            z, bracket = z_high, (0.9, 1.5)
        else:
            z, bracket = z_low, (1.6, 2.0)
        with pytest.warns(UserWarning, match=r"Option 'bracket' is ignored"):
            assert allclose(solve(bracket=bracket), z, rtol=1e-6)
    else:
        # A bracket on one side of the turning point selects that side's root.
        # A 3-element bracket is accepted too.
        assert allclose(solve(bracket=(0.3, 1.0)), z_low, rtol=1e-6)
        assert allclose(solve(bracket=(2.0, 4.0)), z_high, rtol=1e-6)
        assert allclose(solve(bracket=(0.1, 1.5)), z_low, rtol=1e-6)
        assert allclose(solve(bracket=(0.1, 1.0, 2.0)), z_low, rtol=1e-6)

        # Brackets whose end values do not enclose fval trigger a warning.
        # Use a separate ``warns`` block for each call. A single shared block
        # is satisfied by the first warning, so the second call's warning
        # would never be checked.
        with pytest.warns(AstropyUserWarning, match=r"fval is not bracketed"):
            z = solve(bracket=(0.9, 1.5))
        assert allclose(z, z_low, rtol=1e-6)

        with pytest.warns(AstropyUserWarning, match=r"fval is not bracketed"):
            z = solve(bracket=(1.6, 2.0))
        assert allclose(z, z_high, rtol=1e-6)

        # zmin / zmax take precedence over `bracket`.
        assert allclose(solve(bracket=(1.6, 2.0), zmax=1.6), z_low, rtol=1e-6)
        assert allclose(solve(bracket=(0.9, 1.5), zmin=1.5), z_high, rtol=1e-6)

    # A solution outside [zmin, zmax] raises for every method.
    with pytest.raises(core.CosmologyError):
        with pytest.warns(AstropyUserWarning, match=r"fval is not bracketed"):
            solve(bracket=(3.9, 5.0), zmin=4.0)


@pytest.mark.skipif(not HAS_SCIPY, reason="test requires scipy")
@pytest.mark.parametrize("method", ["Brent", "Golden", "Bounded"])
def test_z_at_value_unconverged(method):
    """
    Test warnings on non-converged solution when setting `maxfun` to too small iteration number -
    only 'Bounded' returns status value and specific message.
    """
    cosmo = Planck18
    # Relative tolerances for the (low-z, high-z) roots. With only 13 function
    # evaluations, some methods end up further from the true root.
    rtol_low, rtol_high = {
        "Brent": [1e-4, 1e-4],
        "Golden": [1e-3, 1e-2],
        "Bounded": [1e-3, 1e-1],
    }[method]

    # Only 'Bounded' reports a status code with a specific message.
    expected_message = (
        "Solver returned 1: Maximum number of function calls reached"
        if method == "Bounded"
        else "Solver returned None"
    )

    def solve_truncated(**limits):
        """Run a deliberately truncated solve and require the warning."""
        with pytest.warns(AstropyUserWarning, match=expected_message):
            return z_at_value(
                cosmo.angular_diameter_distance,
                1 * u.Gpc,
                maxfun=13,
                method=method,
                **limits,
            )

    z_below = solve_truncated(zmax=2)
    z_above = solve_truncated(zmin=2)

    assert allclose(z_below, 0.32442, rtol=rtol_low)
    assert allclose(z_above, 8.18551, rtol=rtol_high)


@pytest.mark.skipif(not HAS_SCIPY, reason="test requires scipy")
@pytest.mark.parametrize(
    "cosmo",
    [
        Planck13,
        Planck15,
        Planck18,
        WMAP1,
        WMAP3,
        WMAP5,
        WMAP7,
        WMAP9,
        flrw.LambdaCDM,
        flrw.FlatLambdaCDM,
        flrw.wpwaCDM,
        flrw.w0wzCDM,
        flrw.wCDM,
        flrw.FlatwCDM,
        flrw.w0waCDM,
        flrw.Flatw0waCDM,
    ],
)
def test_z_at_value_roundtrip(cosmo):
    """
    Calculate values from a known redshift, and then check that
    z_at_value returns the right answer.
    """
    # NOTE(review): for the class (non-instance) parameters above,
    # ``inspect.ismethod`` only matches bound classmethods. ``cosmo.name`` is
    # then presumably not a ``str``, so most checks below do not run for them.
    # Confirm this is intended.
    z = 0.5

    # Methods excluded from the round-trip loop:
    # - Ok, Otot, w and de_density_scale are redshift independent in the
    #   Planck cosmologies, and hence uninvertible.
    # - angular_diameter_distance_z1z2 takes two redshifts and is handled
    #   separately below.
    # - clone and is_equivalent are not redshift-dependent methods.
    # - nu_relative_density is not redshift-dependent in the WMAP cosmologies.
    skip = (
        "Ok",
        "Otot",
        "angular_diameter_distance_z1z2",
        "clone",
        "is_equivalent",
        "de_density_scale",
        "w",
    )
    if str(cosmo.name).startswith("WMAP"):
        skip += ("nu_relative_density",)

    methods = inspect.getmembers(cosmo, predicate=inspect.ismethod)

    for name, func in methods:
        if name.startswith("_") or name in skip:
            continue
        fval = func(z)
        # The bracket picks the right solution for angular_diameter_distance
        # and related methods, which have more than one root. The solve uses
        # ztol=1e-12, so the comparison allows a slightly more generous rtol.
        got = z_at_value(func, fval, bracket=[0.3, 1.0], ztol=1e-12)
        assert allclose(got, z, rtol=2e-11), f"Round-trip testing {name} failed"

    # Test distance functions between two redshifts; only for realizations
    if isinstance(cosmo.name, str):
        z2 = 2.0
        func_z1z2 = {
            "_comoving_distance_z1z2": (
                lambda z1: cosmo._comoving_distance_z1z2(z1, z2)
            ),
            "_comoving_transverse_distance_z1z2": (
                lambda z1: cosmo._comoving_transverse_distance_z1z2(z1, z2)
            ),
            "angular_diameter_distance_z1z2": (
                lambda z1: cosmo.angular_diameter_distance_z1z2(z1, z2)
            ),
        }
        for name, func in func_z1z2.items():
            fval = func(z)
            # zmax keeps the solver below the turning point of these distances.
            got = z_at_value(func, fval, zmax=1.5, ztol=1e-12)
            assert allclose(z, got, rtol=2e-11), f"Round-trip testing {name} failed"
