import numpy as np
import pytest
from numpy.testing import assert_allclose

from astropy.stats.spatial import RipleysKEstimator
from astropy.utils.misc import NumpyRNGContext

a = np.array([[1, 4], [2, 5], [3, 6]])
b = np.array([[-1, 1], [-2, 2], [-3, 3]])


@pytest.mark.parametrize("points, x_min, x_max", [(a, 0, 10), (b, -5, 5)])
def test_ripley_K_implementation(points, x_min, x_max):
    """
    Test against Ripley's K function implemented in R package `spatstat`
        +-+---------+---------+----------+---------+-+
      6 +                                          * +
        |                                            |
        |                                            |
    5.5 +                                            +
        |                                            |
        |                                            |
      5 +                     *                      +
        |                                            |
    4.5 +                                            +
        |                                            |
        |                                            |
      4 + *                                          +
        +-+---------+---------+----------+---------+-+
          1        1.5        2         2.5        3

        +-+---------+---------+----------+---------+-+
      3 + *                                          +
        |                                            |
        |                                            |
    2.5 +                                            +
        |                                            |
        |                                            |
      2 +                     *                      +
        |                                            |
    1.5 +                                            +
        |                                            |
        |                                            |
      1 +                                          * +
        +-+---------+---------+----------+---------+-+
         -3       -2.5       -2        -1.5       -1
    """

    area = 100
    r = np.linspace(0, 2.5, 5)
    Kest = RipleysKEstimator(
        area=area, x_min=x_min, y_min=x_min, x_max=x_max, y_max=x_max
    )

    ANS_NONE = np.array([0, 0, 0, 66.667, 66.667])
    assert_allclose(ANS_NONE, Kest(data=points, radii=r, mode="none"), atol=1e-3)

    ANS_TRANS = np.array([0, 0, 0, 82.304, 82.304])
    assert_allclose(
        ANS_TRANS, Kest(data=points, radii=r, mode="translation"), atol=1e-3
    )


with NumpyRNGContext(123):
    a = np.random.uniform(low=5, high=10, size=(100, 2))
    b = np.random.uniform(low=-5, high=-10, size=(100, 2))


@pytest.mark.parametrize("points", [a, b])
def test_ripley_uniform_property(points):
    # Ripley's K function without edge-correction converges to the area when
    # the number of points and the argument radii are large enough, i.e.,
    # K(x) --> area as x --> inf
    area = 50
    Kest = RipleysKEstimator(area=area)
    r = np.linspace(0, 20, 5)
    assert_allclose(area, Kest(data=points, radii=r, mode="none")[4])


with NumpyRNGContext(123):
    a = np.random.uniform(low=0, high=1, size=(500, 2))
    b = np.random.uniform(low=-1, high=0, size=(500, 2))


@pytest.mark.parametrize("points, low, high", [(a, 0, 1), (b, -1, 0)])
def test_ripley_large_density(points, low, high):
    Kest = RipleysKEstimator(area=1, x_min=low, x_max=high, y_min=low, y_max=high)
    r = np.linspace(0, 0.25, 25)
    Kpos = Kest.poisson(r)
    modes = ["ohser", "translation", "ripley"]
    for m in modes:
        Kest_r = Kest(data=points, radii=r, mode=m)
        assert_allclose(Kpos, Kest_r, atol=1e-1)


with NumpyRNGContext(123):
    a = np.random.uniform(low=5, high=10, size=(500, 2))
    b = np.random.uniform(low=-10, high=-5, size=(500, 2))


@pytest.mark.parametrize("points, low, high", [(a, 5, 10), (b, -10, -5)])
def test_ripley_modes(points, low, high):
    Kest = RipleysKEstimator(area=25, x_max=high, y_max=high, x_min=low, y_min=low)
    r = np.linspace(0, 1.2, 25)
    Kpos_mean = np.mean(Kest.poisson(r))
    modes = ["ohser", "translation", "ripley"]
    for m in modes:
        Kest_mean = np.mean(Kest(data=points, radii=r, mode=m))
        assert_allclose(Kpos_mean, Kest_mean, atol=1e-1, rtol=1e-1)


with NumpyRNGContext(123):
    a = np.random.uniform(low=0, high=1, size=(50, 2))
    b = np.random.uniform(low=-1, high=0, size=(50, 2))


@pytest.mark.parametrize("points, low, high", [(a, 0, 1), (b, -1, 0)])
def test_ripley_large_density_var_width(points, low, high):
    Kest = RipleysKEstimator(area=1, x_min=low, x_max=high, y_min=low, y_max=high)
    r = np.linspace(0, 0.25, 25)
    Kpos = Kest.poisson(r)
    Kest_r = Kest(data=points, radii=r, mode="var-width")
    assert_allclose(Kpos, Kest_r, atol=1e-1)


with NumpyRNGContext(123):
    a = np.random.uniform(low=5, high=10, size=(50, 2))
    b = np.random.uniform(low=-10, high=-5, size=(50, 2))


@pytest.mark.parametrize("points, low, high", [(a, 5, 10), (b, -10, -5)])
def test_ripley_var_width(points, low, high):
    Kest = RipleysKEstimator(area=25, x_max=high, y_max=high, x_min=low, y_min=low)
    r = np.linspace(0, 1.2, 25)
    Kest_ohser = np.mean(Kest(data=points, radii=r, mode="ohser"))
    Kest_var_width = np.mean(Kest(data=points, radii=r, mode="var-width"))
    assert_allclose(Kest_ohser, Kest_var_width, atol=1e-1, rtol=1e-1)
