from __future__ import annotations

import numpy as np
import pandas as pd
import pytest

import dask.dataframe as dd

# Shared, explicitly seeded RNG so the parametrized fixtures below are
# identical on every test run.
rs = np.random.RandomState(seed=96)


@pytest.mark.parametrize(
    "df",
    [
        pd.DataFrame(
            {
                "x": [1, 2, 3] * 3,
                "y": [1.2, 3.4, 5.6] * 3,
                "z": -(np.arange(9, dtype=np.int8)),
            }
        ),
        pd.DataFrame(
            {
                "x": rs.randint(0, 1000000, (10000,)),
                "y": rs.randn(10000),
                "z": rs.uniform(0, 9999999, (10000,)),
            }
        ),
        pd.DataFrame(
            {
                "x": np.repeat(rs.randint(0, 1000000, (1000,)), 3),
                "y": np.repeat(rs.randn(1000), 3),
                "z": np.repeat(rs.uniform(0, 9999999, (1000,)), 3),
            }
        ),
        pd.DataFrame({"x": rs.randint(0, 1000000, (10000,))}),
        pd.DataFrame(
            {
                "x": rs.randint(0, 1000000, (7,)),
                "y": ["a", "bet", "is", "a", "tax", "on", "bs"],
            }
        ),
        pd.DataFrame(
            {
                "w": np.zeros((20000,)),
                "x": np.zeros((20000,)),
                "y": np.zeros((20000,)) + 4803592,
                "z": np.zeros((20000,)),
            }
        ),
        pd.DataFrame({"x": [1, 2, 3] * 1000}),
        # Draw from the seeded module RNG (not the global np.random state) so
        # every parametrized case is reproducible across runs.
        pd.DataFrame({"x": rs.random_sample(1000)}),
        pd.DataFrame(
            {
                "a": [1, 2, 3] * 3,
                "b": [1.2, 3.4, 5.6] * 3,
                "c": [1 + 2j, 3 + 4j, 5 + 6j] * 3,
                "d": -(np.arange(9, dtype=np.int8)),
            }
        ),
        pd.Series([1, 2, 3] * 1000),
        pd.Series(rs.random_sample(1000)),
        pd.Series(rs.random_sample(1000), index=np.ones(1000)),
        pd.Series(rs.random_sample(1000), index=rs.random_sample(1000)),
    ],
)
@pytest.mark.parametrize("npartitions", [2, 20])
def test_basic(df, npartitions):
    """Compare ``nunique_approx`` with the exact distinct-row count.

    Covers small and large inputs, heavy duplication, constant columns,
    mixed dtypes (including strings and complex numbers), and Series with
    duplicated or non-default indexes.
    """
    ddf = dd.from_pandas(df, npartitions=npartitions)

    approx = ddf.nunique_approx().compute(scheduler="sync")
    exact = len(df.drop_duplicates())
    # Allow a small absolute slack (matters for tiny cardinalities) or a
    # relative error under 5%.
    assert abs(approx - exact) <= 2 or abs(approx - exact) / exact < 0.05


@pytest.mark.parametrize("split_every", [None, 2, 10])
@pytest.mark.parametrize("npartitions", [2, 20])
def test_split_every(split_every, npartitions):
    """Check ``nunique_approx`` accuracy for several ``split_every`` values."""
    series = pd.Series([1, 2, 3] * 1000)
    expected = len(series.drop_duplicates())

    dask_series = dd.from_pandas(series, npartitions=npartitions)
    estimate = dask_series.nunique_approx(split_every=split_every).compute(
        scheduler="sync"
    )

    error = abs(estimate - expected)
    # Tolerate a couple of units of absolute error or under 5% relative error.
    assert error <= 2 or error / expected < 0.05


def test_larger_data():
    """Check ``nunique_approx`` against the exact count on a larger frame.

    The frame covers about three months of 10-second rows, split into one
    partition per day.
    """
    df = dd.demo.make_timeseries(
        "2000-01-01",
        "2000-04-01",
        {"value": float, "id": int},
        freq="10s",
        partition_freq="1D",
        seed=1,
    )
    approx = df.nunique_approx().compute()
    exact = len(df.compute().drop_duplicates())
    # A loose lower bound would not catch an estimate that is off by a large
    # factor, so compare against the exact distinct-row count with the same
    # tolerance used by the other tests in this module.
    assert abs(approx - exact) <= 2 or abs(approx - exact) / exact < 0.05
