# Licensed under a 3-clause BSD style license - see LICENSE.rst

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal_nulp, assert_equal

import astropy.units as u
from astropy.stats.biweight import (
    biweight_location,
    biweight_midcorrelation,
    biweight_midcovariance,
    biweight_midvariance,
    biweight_scale,
)
from astropy.tests.helper import assert_quantity_allclose
from astropy.utils.misc import NumpyRNGContext


def test_biweight_location():
    with NumpyRNGContext(12345):
        # test that it runs
        randvar = np.random.randn(10000)
        val = biweight_location(randvar)
        assert abs(val) < 1e-2


def test_biweight_location_constant():
    arr = np.ones((10, 5))
    val = biweight_location(arr)
    assert val == 1.0
    val = biweight_location(arr, axis=(0, 1))
    assert val == 1.0


def test_biweight_location_constant_axis_2d():
    shape = (10, 5)
    data = np.ones(shape)
    val = biweight_location(data, axis=0)
    assert_allclose(val, np.ones(shape[1]))
    val = biweight_location(data, axis=1)
    assert_allclose(val, np.ones(shape[0]))

    val1 = 100.0
    val2 = 2.0
    val3 = 5.0
    data = np.arange(50).reshape(10, 5)
    data[2] = val1
    data[7] = val2
    data[8] = [val3, 0.8, val3, -0.8, val3]
    val = biweight_location(data, axis=1)
    assert_allclose(val[2], val1)
    assert_allclose(val[7], val2)
    assert_allclose(val[8], val3)


def test_biweight_location_constant_axis_3d():
    shape = (10, 5, 2)
    data = np.ones(shape)
    val = biweight_location(data, axis=0)
    assert_allclose(val, np.ones((shape[1], shape[2])))
    val = biweight_location(data, axis=1)
    assert_allclose(val, np.ones((shape[0], shape[2])))
    val = biweight_location(data, axis=2)
    assert_allclose(val, np.ones((shape[0], shape[1])))


def test_biweight_location_small():
    bw_loc = biweight_location([1, 3, 5, 500, 2])
    assert_allclose(bw_loc, 2.7456117)


def test_biweight_location_axis():
    """Test a 2D array with the axis keyword."""
    with NumpyRNGContext(12345):
        ny = 100
        nx = 200
        data = np.random.normal(5, 2, (ny, nx))

        bw = biweight_location(data, axis=0)
        bwi = []
        for i in range(nx):
            bwi.append(biweight_location(data[:, i]))
        bwi = np.array(bwi)
        assert_allclose(bw, bwi)

        bw = biweight_location(data, axis=1)
        bwi = []
        for i in range(ny):
            bwi.append(biweight_location(data[i, :]))
        bwi = np.array(bwi)
        assert_allclose(bw, bwi)


@pytest.mark.parametrize("nx, ny, nz", [(5, 4, 3), (5, 1, 3), (1, 4, 3), (1, 1, 5)])
def test_biweight_location_axis_3d(nx, ny, nz):
    """Test a 3D array with the axis keyword."""
    with NumpyRNGContext(12345):
        data = np.random.normal(5, 2, (nz, ny, nx))
        bw = biweight_location(data, axis=0)
        assert bw.shape == (ny, nx)

        y = 0
        bwi = []
        for i in range(nx):
            bwi.append(biweight_location(data[:, y, i]))
        bwi = np.array(bwi)
        assert_allclose(bw[y], bwi)


@pytest.mark.parametrize("shape", [(2, 3, 4), (12, 1, 2), (1, 12, 2), (1, 1, 24)])
def test_biweight_location_axis_tuple(shape):
    """Test a 3D array with a tuple axis keyword."""

    data = np.arange(24).reshape(shape)
    data[0, 0] = 100.0

    assert_equal(biweight_location(data, axis=0), biweight_location(data, axis=(0,)))
    assert_equal(biweight_location(data, axis=-1), biweight_location(data, axis=(2,)))
    assert_equal(
        biweight_location(data, axis=(0, 1)), biweight_location(data, axis=(1, 0))
    )
    assert_equal(
        biweight_location(data, axis=(0, 2)), biweight_location(data, axis=(0, -1))
    )
    assert_equal(
        biweight_location(data, axis=(0, 1, 2)), biweight_location(data, axis=(2, 0, 1))
    )
    assert_equal(
        biweight_location(data, axis=(0, 1, 2)), biweight_location(data, axis=None)
    )


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered")
@pytest.mark.filterwarnings("ignore:Invalid value encountered in median")
def test_biweight_location_ignore_nan():
    data1d = np.array([1, 3, 5, 500, 2, np.nan])
    data2d = np.array([data1d, data1d])

    assert np.isnan(biweight_location(data1d, ignore_nan=False))

    biw_expected = biweight_location(data1d[:-1], ignore_nan=False)
    assert_equal(biweight_location(data1d, ignore_nan=True), biw_expected)

    assert_equal(biweight_location(data2d, axis=0, ignore_nan=True), data1d)
    assert_equal(
        biweight_location(data2d, axis=1, ignore_nan=True), [biw_expected, biw_expected]
    )


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered")
@pytest.mark.filterwarnings("ignore:Invalid value encountered in median")
def test_biweight_location_nan():
    data1d = np.array([1, 3, 5, 500, 2, np.nan])
    all_nan = data1d.copy()
    all_nan[:] = np.nan
    data2d = np.array([data1d, data1d, all_nan])
    data1d_masked = np.ma.masked_invalid(data1d)
    data1d_masked.data[0] = np.nan
    data2d_masked = np.ma.masked_invalid(data2d)

    assert np.isnan(biweight_location(data1d))
    bw_loc = biweight_location(data1d_masked)
    assert not isinstance(bw_loc, np.ma.MaskedArray)
    assert np.isnan(biweight_location(data2d))

    for axis in (0, 1):
        assert np.all(np.isnan(biweight_location(data2d, axis=axis)))
        assert isinstance(
            biweight_location(data2d_masked, axis=axis), np.ma.MaskedArray
        )


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered")
@pytest.mark.filterwarnings("ignore:Invalid value encountered in median")
def test_biweight_location_masked():
    data1d = np.array([1, 3, 5, 500, 2, np.nan])
    data2d = np.array([data1d, data1d])

    data1d_masked = np.ma.masked_invalid(data1d)
    data2d_masked = np.ma.masked_invalid(data2d)

    assert_equal(
        biweight_location(data1d, ignore_nan=True), biweight_location(data1d_masked)
    )
    assert_equal(
        biweight_location(data2d, ignore_nan=True), biweight_location(data2d_masked)
    )

    bw_loc = biweight_location(data1d_masked)
    assert_allclose(bw_loc, 2.7456117)
    assert np.isscalar(bw_loc)

    bw_loc = biweight_location(data2d, ignore_nan=True, axis=1)
    bw_loc_masked = biweight_location(data2d_masked, axis=1)
    assert isinstance(bw_loc_masked, np.ma.MaskedArray)
    assert ~np.any(bw_loc_masked.mask)  # mask is all False
    assert_equal(bw_loc, bw_loc_masked.data)

    bw_loc = biweight_location(data2d, ignore_nan=True, axis=0)
    bw_loc_masked = biweight_location(data2d_masked, axis=0)
    assert_equal(bw_loc_masked.data[:-1], bw_loc[:-1])
    assert bw_loc_masked.mask[-1]  # last mask element is True

    data1d_masked.data[0] = np.nan  # unmasked NaN
    bw_loc = biweight_location(data1d_masked)
    assert not isinstance(bw_loc, np.ma.MaskedArray)
    assert np.isscalar(bw_loc)
    assert np.isnan(bw_loc)
    assert_equal(
        biweight_location(data1d_masked, ignore_nan=True),
        biweight_location(data1d[1:], ignore_nan=True),
    )

    # ensure that input masked array is not modified
    assert np.isnan(data1d_masked[0])


def test_biweight_scale():
    # NOTE:  biweight_scale is covered by biweight_midvariance tests
    data = [1, 3, 5, 500, 2]
    scl = biweight_scale(data)
    var = biweight_midvariance(data)
    assert_allclose(scl, np.sqrt(var))

    data = np.ma.masked_invalid([1, 3, 5, 500, 2, np.nan])
    data[0] = np.nan
    scl = biweight_scale(data, ignore_nan=True)
    var = biweight_midvariance(data, ignore_nan=True)
    assert_allclose(scl, np.sqrt(var))


def test_biweight_midvariance():
    with NumpyRNGContext(12345):
        # test that it runs
        randvar = np.random.randn(10000)
        var = biweight_midvariance(randvar)
        assert_allclose(var, 1.0, rtol=0.02)


def test_biweight_midvariance_small():
    data = [1, 3, 5, 500, 2]
    var = biweight_midvariance(data)
    assert_allclose(var, 2.9238456)  # verified with R

    var = biweight_midvariance(data, modify_sample_size=True)
    assert_allclose(var, 2.3390765)


def test_biweight_midvariance_5127():
    # test a regression introduced in #5127
    rand = np.random.default_rng(12345)
    data = rand.normal(loc=0.0, scale=20.0, size=(100, 100))
    var = biweight_midvariance(data)
    assert_allclose(var, 409.87135608846205)


def test_biweight_midvariance_axis():
    """Test a 2D array with the axis keyword."""
    with NumpyRNGContext(12345):
        ny = 100
        nx = 200
        data = np.random.normal(5, 2, (ny, nx))

        bw = biweight_midvariance(data, axis=0)
        bwi = []
        for i in range(nx):
            bwi.append(biweight_midvariance(data[:, i]))
        bwi = np.array(bwi)
        assert_allclose(bw, bwi)

        bw = biweight_midvariance(data, axis=1)
        bwi = []
        for i in range(ny):
            bwi.append(biweight_midvariance(data[i, :]))
        bwi = np.array(bwi)
        assert_allclose(bw, bwi)


@pytest.mark.parametrize("nx, ny, nz", [(5, 4, 3), (5, 1, 3), (1, 4, 3), (1, 1, 5)])
def test_biweight_midvariance_axis_3d(nx, ny, nz):
    """Test a 3D array with the axis keyword."""
    with NumpyRNGContext(12345):
        data = np.random.normal(5, 2, (nz, ny, nx))
        bw = biweight_midvariance(data, axis=0)
        assert bw.shape == (ny, nx)

        y = 0
        bwi = []
        for i in range(nx):
            bwi.append(biweight_midvariance(data[:, y, i]))
        bwi = np.array(bwi)
        assert_allclose(bw[y], bwi)


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered")
@pytest.mark.filterwarnings("ignore:Invalid value encountered in median")
def test_biweight_midvariance_ignore_nan():
    data1d = np.array([1, 3, 5, 500, 2, np.nan])
    data2d = np.array([data1d, data1d])

    assert np.isnan(biweight_midvariance(data1d, ignore_nan=False))

    biw_var = biweight_midvariance(data1d[:-1], ignore_nan=False)
    biw_var_nonan = biweight_midvariance(data1d, ignore_nan=True)
    assert_equal(biw_var_nonan, biw_var)

    assert_equal(
        biweight_midvariance(data2d, axis=0, ignore_nan=True),
        [0.0, 0.0, 0.0, 0.0, 0.0, np.nan],
    )
    assert_equal(
        biweight_midvariance(data2d, axis=1, ignore_nan=True),
        [biw_var_nonan, biw_var_nonan],
    )


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered")
@pytest.mark.filterwarnings("ignore:Invalid value encountered in median")
def test_biweight_scale_nan():
    data1d = np.array([1, 3, 5, 500, 2, np.nan])
    all_nan = data1d.copy()
    all_nan[:] = np.nan
    data2d = np.array([data1d, data1d, all_nan])
    data1d_masked = np.ma.masked_invalid(data1d)
    data1d_masked.data[0] = np.nan
    data2d_masked = np.ma.masked_invalid(data2d)

    assert np.isnan(biweight_scale(data1d))
    bw_scl = biweight_scale(data1d_masked)
    assert not isinstance(bw_scl, np.ma.MaskedArray)
    assert np.isnan(bw_scl)
    assert np.isnan(biweight_scale(data2d))
    assert_allclose(biweight_scale(data2d_masked), 1.709926, atol=1e-5)

    for axis in (0, 1):
        assert np.all(np.isnan(biweight_scale(data2d, axis=axis)))
        assert isinstance(biweight_scale(data2d_masked, axis=axis), np.ma.MaskedArray)


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered")
@pytest.mark.filterwarnings("ignore:Invalid value encountered in median")
def test_biweight_midvariance_masked():
    data1d = np.array([1, 3, 5, 500, 2, np.nan])
    data2d = np.array([data1d, data1d])

    data1d_masked = np.ma.masked_invalid(data1d)
    data2d_masked = np.ma.masked_invalid(data2d)

    assert_allclose(
        biweight_midvariance(data1d, ignore_nan=True),
        biweight_midvariance(data1d_masked),
    )
    assert_allclose(
        biweight_midvariance(data2d, ignore_nan=True),
        biweight_midvariance(data2d_masked),
    )

    bw_scl = biweight_midvariance(data1d_masked)
    assert_allclose(bw_scl, 2.9238456)
    assert np.isscalar(bw_scl)

    bw_loc = biweight_midvariance(data2d, ignore_nan=True, axis=1)
    bw_loc_masked = biweight_midvariance(data2d_masked, axis=1)
    assert isinstance(bw_loc_masked, np.ma.MaskedArray)
    assert ~np.any(bw_loc_masked.mask)  # mask is all False
    assert_allclose(bw_loc, bw_loc_masked.data)

    bw_loc = biweight_midvariance(data2d, ignore_nan=True, axis=0)
    bw_loc_masked = biweight_midvariance(data2d_masked, axis=0)
    assert_allclose(bw_loc_masked.data[:-1], bw_loc[:-1])
    assert bw_loc_masked.mask[-1]  # last mask element is True

    data1d_masked.data[0] = np.nan  # unmasked NaN
    bw_scl = biweight_midvariance(data1d_masked)
    assert not isinstance(bw_scl, np.ma.MaskedArray)
    assert np.isscalar(bw_scl)
    assert np.isnan(bw_scl)
    assert_allclose(
        biweight_midvariance(data1d_masked, ignore_nan=True),
        biweight_midvariance(data1d[1:], ignore_nan=True),
    )

    # ensure that input masked array is not modified
    assert np.isnan(data1d_masked[0])


def test_biweight_scale_axis_tuple():
    """Test a 3D array with a tuple axis keyword."""

    data = np.arange(24).reshape(2, 3, 4)
    data[0, 0] = 100.0

    assert_equal(biweight_scale(data, axis=0), biweight_scale(data, axis=(0,)))
    assert_equal(biweight_scale(data, axis=-1), biweight_scale(data, axis=(2,)))
    assert_equal(biweight_scale(data, axis=(0, 1)), biweight_scale(data, axis=(1, 0)))
    assert_equal(biweight_scale(data, axis=(0, 2)), biweight_scale(data, axis=(0, -1)))
    assert_equal(
        biweight_scale(data, axis=(0, 1, 2)), biweight_scale(data, axis=(2, 0, 1))
    )
    assert_equal(biweight_scale(data, axis=(0, 1, 2)), biweight_scale(data, axis=None))
    assert_equal(
        biweight_scale(data, axis=(0, 2), modify_sample_size=True),
        biweight_scale(data, axis=(0, -1), modify_sample_size=True),
    )


def test_biweight_midvariance_constant_axis():
    bw = biweight_midvariance(np.ones((10, 5)))
    assert bw == 0.0


def test_biweight_midvariance_constant_axis_2d():
    shape = (10, 5)
    data = np.ones(shape)
    cbl = biweight_midvariance(data, axis=0)
    assert_allclose(cbl, np.zeros(shape[1]))
    cbl = biweight_midvariance(data, axis=1)
    assert_allclose(cbl, np.zeros(shape[0]))

    data = np.arange(50).reshape(10, 5)
    data[2] = 100.0
    data[7] = 2.0
    data[8] = [5.0, 0.8, 5.0, -0.8, 5.0]
    bw = biweight_midvariance(data, axis=1)
    assert_allclose(bw[2], 0.0)
    assert_allclose(bw[7], 0.0)
    assert_allclose(bw[8], 0.0)


def test_biweight_midvariance_constant_axis_3d():
    shape = (10, 5, 2)
    data = np.ones(shape)
    cbl = biweight_midvariance(data, axis=0)
    assert_allclose(cbl, np.zeros((shape[1], shape[2])))
    cbl = biweight_midvariance(data, axis=1)
    assert_allclose(cbl, np.zeros((shape[0], shape[2])))
    cbl = biweight_midvariance(data, axis=2)
    assert_allclose(cbl, np.zeros((shape[0], shape[1])))


def test_biweight_midcovariance_1d():
    d = [0, 1, 2]
    cov = biweight_midcovariance(d)
    var = biweight_midvariance(d)
    assert_allclose(cov, [[var]])


def test_biweight_midcovariance_2d():
    d = [[0, 1, 2], [2, 1, 0]]
    cov = biweight_midcovariance(d)
    val = 0.70121809
    assert_allclose(cov, [[val, -val], [-val, val]])  # verified with R

    d = [[5, 1, 10], [500, 5, 2]]
    cov = biweight_midcovariance(d)
    assert_allclose(
        cov,
        [[14.54159077, -7.79026256], [-7.79026256, 6.92087252]],  # verified with R
    )

    cov = biweight_midcovariance(d, modify_sample_size=True)
    assert_allclose(cov, [[14.54159077, -5.19350838], [-5.19350838, 4.61391501]])


def test_biweight_midcovariance_constant():
    data = np.ones((3, 10))
    val3 = 5.0
    data[1] = [val3, 0.8, val3, -0.8, val3, val3, val3, 1.0, val3, -0.7]
    cov = biweight_midcovariance(data)
    assert_allclose(cov, np.zeros((3, 3)))

    rng = np.random.default_rng(123)
    data = rng.random((5, 5))
    val3 = 5.0
    data[1] = [val3, 0.8, val3, -0.8, val3]
    cov = biweight_midcovariance(data)
    assert_allclose(cov[1, :], 0.0)
    assert_allclose(cov[:, 1], 0.0)


def test_biweight_midcovariance_midvariance():
    """
    Test that biweight_midcovariance diagonal elements agree with
    biweight_midvariance.
    """

    rng = np.random.default_rng(1)
    d = rng.normal(0, 2, size=(100, 3))
    cov = biweight_midcovariance(d)
    var = [biweight_midvariance(a) for a in d]
    assert_allclose(cov.diagonal(), var)

    cov2 = biweight_midcovariance(d, modify_sample_size=True)
    var2 = [biweight_midvariance(a, modify_sample_size=True) for a in d]
    assert_allclose(cov2.diagonal(), var2)


def test_midcovariance_shape():
    """
    Test that biweight_midcovariance raises error with a 3D array.
    """

    d = np.ones(27).reshape(3, 3, 3)
    with pytest.raises(ValueError, match=r"The input array must be 2D or 1D\."):
        biweight_midcovariance(d)


def test_midcovariance_M_shape():
    """
    Test that biweight_midcovariance raises error when M is not a scalar
    or 1D array.
    """

    d = [0, 1, 2]
    M = [[0, 1], [2, 3]]
    with pytest.raises(ValueError, match=r"M must be a scalar or 1D array\."):
        biweight_midcovariance(d, M=M)


def test_biweight_midcovariance_symmetric():
    """
    Regression test to ensure that midcovariance matrix is symmetric
    when ``modify_sample_size=True`` (see #5972).
    """

    rng = np.random.default_rng(1)
    d = rng.gamma(2, 2, size=(3, 500))
    cov = biweight_midcovariance(d)
    assert_array_almost_equal_nulp(cov, cov.T, nulp=5)

    cov = biweight_midcovariance(d, modify_sample_size=True)
    assert_array_almost_equal_nulp(cov, cov.T, nulp=5)


def test_biweight_midcorrelation():
    x = [0, 1, 2]
    y = [2, 1, 0]
    assert_allclose(biweight_midcorrelation(x, x), 1.0)
    assert_allclose(biweight_midcorrelation(x, y), -1.0)

    x = [5, 1, 10, 12.4, 13.2]
    y = [500, 5, 2, 7.1, 0.9]
    # verified with R
    assert_allclose(biweight_midcorrelation(x, y), -0.14411038976763313)


def test_biweight_midcorrelation_inputs():
    a1 = np.ones((3, 3))
    a2 = np.ones(5)
    a3 = np.ones(7)

    with pytest.raises(ValueError, match=r"x must be a 1D array\."):
        biweight_midcorrelation(a1, a2)

    with pytest.raises(ValueError, match=r"y must be a 1D array\."):
        biweight_midcorrelation(a2, a1)

    with pytest.raises(ValueError, match=r"x and y must have the same shape\."):
        biweight_midcorrelation(a2, a3)


def test_biweight_32bit_runtime_warnings():
    """Regression test for #6905."""
    with NumpyRNGContext(12345):
        data = np.random.random(100).astype(np.float32)
        data[50] = 30000.0
        biweight_scale(data)
        biweight_midvariance(data)


def test_biweight_scl_var_constant_units():
    unit = u.km
    data = np.ones(10) << unit
    biwscl = biweight_scale(data)
    biwvar = biweight_midvariance(data)
    assert isinstance(biwscl, u.Quantity)
    assert isinstance(biwvar, u.Quantity)
    assert_quantity_allclose(biwscl, 0.0 << unit)
    assert_quantity_allclose(biwvar, 0.0 << unit**2)

    data = np.ones(10) << unit
    data[0] = np.nan
    biwscl = biweight_scale(data)
    biwvar = biweight_midvariance(data)
    assert isinstance(biwscl, u.Quantity)
    assert isinstance(biwvar, u.Quantity)
    assert_quantity_allclose(biwscl, np.nan << unit)
    assert_quantity_allclose(biwvar, np.nan << unit**2)
