from __future__ import annotations

import warnings
from itertools import product

import pytest

from dask._task_spec import Task

np = pytest.importorskip("numpy")
import math

import dask
import dask.array as da
from dask.array.rechunk import (
    _breakpoints,
    _intersect_1d,
    cumdims_label,
    divide_to_width,
    intersect_chunks,
    merge_to_number,
    normalize_chunks,
    old_to_new,
    plan_rechunk,
    rechunk,
)
from dask.array.utils import assert_eq
from dask.utils import funcname


def test_rechunk_internals_1():
    """Check ``cumdims_label``, ``_breakpoints`` and ``_intersect_1d``.

    The second half repeats the checks on the second dimension after a
    zero-width chunk has been appended to the new layout.
    """
    unit = slice(0, 1)
    old_labels = cumdims_label(((4,), (1,) * 5), "o")

    new_labels = cumdims_label(((1, 1, 2), (1, 5, 1)), "n")
    breakpoints = tuple(_breakpoints(o, n) for o, n in zip(old_labels, new_labels))
    assert breakpoints[0] == (
        ("o", 0),
        ("n", 0),
        ("n", 1),
        ("n", 2),
        ("o", 4),
        ("n", 4),
    )
    dim1_breaks = (
        (("o", 0), ("n", 0), ("o", 1), ("n", 1))
        + tuple(("o", position) for position in range(2, 6))
        + (("n", 6), ("n", 7))
    )
    assert breakpoints[1] == dim1_breaks

    pieces = [_intersect_1d(b) for b in breakpoints]
    assert pieces[0] == [[(0, unit)], [(0, slice(1, 2))], [(0, slice(2, 4))]]
    dim1_pieces = [
        [(0, unit)],
        [(index, unit) for index in range(1, 6)],
        [(5, slice(1, 2))],
    ]
    assert pieces[1] == dim1_pieces

    # Same layout plus a trailing empty chunk in the second dimension
    new_labels = cumdims_label(((1, 1, 2), (1, 5, 1, 0)), "n")
    breakpoints = tuple(_breakpoints(o, n) for o, n in zip(old_labels, new_labels))
    assert breakpoints[1] == dim1_breaks + (("n", 7),)
    pieces = [_intersect_1d(b) for b in breakpoints]
    assert pieces[1] == dim1_pieces + [[(5, slice(2, 2))]]


def test_intersect_1():
    """Intersect five 1-d chunks of 10 with new chunks (25, 5, 20)."""
    full = slice(0, 10)
    expected = [
        (((0, full),), ((1, full),), ((2, slice(0, 5)),)),
        (((2, slice(5, 10)),),),
        (((3, full),), ((4, full),)),
    ]
    result = intersect_chunks(old_chunks=((10,) * 5,), new_chunks=((25, 5, 20),))
    assert expected == list(result)


def test_intersect_2():
    """Intersect five 1-d chunks of 20 with new chunks (58, 4, 20, 18)."""
    full = slice(0, 20)
    head, tail = slice(0, 2), slice(2, 20)
    expected = [
        (((0, full),), ((1, full),), ((2, slice(0, 18)),)),
        (((2, slice(18, 20)),), ((3, head),)),
        (((3, tail),), ((4, head),)),
        (((4, tail),),),
    ]
    result = intersect_chunks(old_chunks=((20,) * 5,), new_chunks=((58, 4, 20, 18),))
    assert expected == list(result)


def test_rechunk_1d():
    """Rechunk a random 1-d array from three chunks of 10 to six chunks of 5."""
    data = np.random.default_rng().uniform(0, 1, 30)
    target = ((5,) * 6,)
    rechunked = rechunk(da.from_array(data, chunks=((10,) * 3,)), chunks=target)
    assert rechunked.chunks == target
    assert np.all(rechunked.compute() == data)


def test_rechunk_2d():
    """Rechunk a random 10x30 array onto a different grid and compare values."""
    data = np.random.default_rng().uniform(0, 1, 300).reshape((10, 30))
    source = da.from_array(data, chunks=((1, 2, 3, 4), (5,) * 6))
    target = ((5, 5), (15,) * 2)
    rechunked = rechunk(source, chunks=target)
    assert rechunked.chunks == target
    assert np.all(rechunked.compute() == data)


def test_rechunk_4d():
    """Rechunk a random 4-d array from (5, 5) per axis to one chunk per axis."""
    ndim = 4
    data = np.random.default_rng().uniform(0, 1, 10000).reshape((10,) * ndim)
    source = da.from_array(data, chunks=((5, 5),) * ndim)
    target = ((10,),) * ndim
    rechunked = rechunk(source, chunks=target)
    assert rechunked.chunks == target
    assert np.all(rechunked.compute() == data)


def test_rechunk_expand():
    """Rechunk a 10x10 array from 5x5 blocks to (3, 3, 3, 1) along both axes."""
    data = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
    pieces = (3, 3, 3, 1)
    rechunked = da.from_array(data, chunks=(5, 5)).rechunk(chunks=(pieces, pieces))
    assert np.all(rechunked.compute() == data)


def test_rechunk_expand2():
    """Rechunk a small square array between two-way and three-way splits."""
    side, ndim = 3, 2
    data = np.random.default_rng().uniform(0, 1, side**ndim).reshape((side,) * ndim)
    offsets = range(1, side - 1)
    for old_off, new_off in product(offsets, offsets):
        source = da.from_array(data, chunks=((side - old_off, old_off),) * ndim)
        two_way = ((side - new_off, new_off),) * ndim
        assert np.all(source.rechunk(chunks=two_way).compute() == data)
        if side - old_off - new_off <= 0:
            continue
        three_way = ((old_off, side - new_off - old_off, new_off),) * ndim
        computed = source.rechunk(chunks=three_way).compute()
        assert np.all(computed == data)


def test_rechunk_method():
    """``Array.rechunk`` (the method form) yields the requested chunks and values."""
    ndim = 4
    data = np.random.default_rng().uniform(0, 1, 10000).reshape((10,) * ndim)
    source = da.from_array(data, chunks=((5, 2, 3),) * ndim)
    target = ((3, 3, 3, 1),) * ndim
    rechunked = source.rechunk(chunks=target)
    assert rechunked.chunks == target
    assert np.all(rechunked.compute() == data)


def test_rechunk_blockshape():
    """A block shape passed to ``rechunk`` gives what ``normalize_chunks`` gives."""
    shape, blockshape = (10, 10), (4, 3)
    data = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
    source = da.from_array(data, chunks=((4, 4, 2), (3, 3, 3, 1)))
    rechunked = rechunk(source, chunks=blockshape)
    assert rechunked.chunks == normalize_chunks(blockshape, shape)
    assert np.all(rechunked.compute() == data)


def test_dtype():
    """The rechunked array reports the same dtype as its source."""
    source = da.ones(5, chunks=(2,))
    rechunked = source.rechunk(chunks=(1,))
    assert rechunked.dtype == source.dtype


def test_rechunk_with_dict():
    """Rechunk with ``{axis: spec}`` dicts and check the resulting chunks."""
    cases = [
        ({0: 12}, ((12, 12), (8, 8, 8))),
        ({0: (12, 12)}, ((12, 12), (8, 8, 8))),
        ({0: -1}, ((24,), (8, 8, 8))),
        ({0: None, 1: "auto"}, ((4, 4, 4, 4, 4, 4), (24,))),
    ]
    for spec, expected in cases:
        # A fresh source array for every case, all with the same layout
        source = da.ones((24, 24), chunks=(4, 8))
        rechunked = source.rechunk(chunks=spec)
        assert rechunked.chunks == expected


def test_rechunk_with_empty_input():
    """An empty dict keeps the chunks; an empty tuple raises ``ValueError``."""
    source = da.ones((24, 24), chunks=(4, 8))
    unchanged = source.rechunk(chunks={})
    assert unchanged.chunks == source.chunks
    with pytest.raises(ValueError):
        source.rechunk(chunks=())


def test_rechunk_with_null_dimensions():
    """``None`` for an axis, in tuple or dict form, keeps that axis's chunks."""
    source = da.from_array(np.ones((24, 24)), chunks=(4, 8))
    tuple_form = source.rechunk(chunks=(None, 4))
    assert tuple_form.chunks == da.ones((24, 24), chunks=(4, 4)).chunks
    dict_form = source.rechunk(chunks={0: None, 1: 4})
    assert dict_form.chunks == da.ones((24, 24), chunks=(4, 4)).chunks


def test_rechunk_with_integer():
    """A bare integer chunk size is applied to a 1-d array."""
    source = da.from_array(np.arange(5), chunks=4)
    rechunked = source.rechunk(3)
    assert rechunked.chunks == ((3, 2),)
    matches = source.compute() == rechunked.compute()
    assert matches.all()


def test_rechunk_0d():
    """A 0-d array can be rechunked with an empty chunk tuple."""
    scalar = np.array(42)
    rechunked = da.from_array(scalar, chunks=()).rechunk(())
    assert rechunked.chunks == ()
    assert rechunked.compute() == scalar


@pytest.mark.parametrize(
    "arr", [da.array([]), da.array([[], []]), da.array([[[]], [[]]])]
)
def test_rechunk_empty_array(arr):
    """Rechunking an empty array with default arguments keeps it empty.

    The assertions inspect the array returned by ``rechunk`` so that the test
    checks the result itself, not only that the call did not raise.
    """
    result = arr.rechunk()
    assert result.shape == arr.shape
    assert result.size == 0


def test_rechunk_empty():
    """Rechunk a (0, 10) array; the empty axis ends up as a single 0 chunk."""
    source = da.ones((0, 10), chunks=(5, 5))
    rechunked = source.rechunk((2, 2))
    expected_chunks = ((0,), (2, 2, 2, 2, 2))
    assert rechunked.chunks == expected_chunks
    assert_eq(source, rechunked)


def test_rechunk_zero_dim_array():
    """Rechunk axis 0 of a (4, 0) array into a single chunk."""
    source = da.zeros((4, 0), chunks=3)
    rechunked = source.rechunk({0: 4})
    expected_chunks = ((4,), (0,))
    assert rechunked.chunks == expected_chunks
    assert_eq(source, rechunked)


def test_rechunk_zero_dim_array_II():
    """Rechunk two axes of a 4-d array that has one zero-length axis."""
    source = da.zeros((4, 0, 6, 10), chunks=3)
    rechunked = source.rechunk({0: 4, 2: 2})
    expected_chunks = ((4,), (0,), (2,) * 3, (3, 3, 3, 1))
    assert rechunked.chunks == expected_chunks
    assert_eq(source, rechunked)


def test_rechunk_same():
    """Rechunking to the current chunks returns the very same object."""
    source = da.ones((24, 24), chunks=(4, 8))
    assert source is source.rechunk(source.chunks)


def test_rechunk_same_fully_unknown():
    """Rechunking to identical all-NaN chunks returns the very same object."""
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    known = da.ones(shape=(10, 10), chunks=(5, 10))
    unknown = dd.from_array(known).values
    same_chunks = ((np.nan, np.nan), (10,))
    assert unknown.chunks == same_chunks
    assert unknown is unknown.rechunk(same_chunks)


def test_rechunk_same_fully_unknown_floats():
    """Like ``test_rechunk_same_fully_unknown``, but the requested chunks are
    spelled with ``float("nan")`` rather than the recommended ``np.nan``.
    """
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    known = da.ones(shape=(10, 10), chunks=(5, 10))
    unknown = dd.from_array(known).values
    float_nan_chunks = ((float("nan"), float("nan")), (10,))
    assert unknown is unknown.rechunk(float_nan_chunks)


def test_rechunk_same_partially_unknown():
    """Rechunking to identical, partly NaN chunks returns the very same object."""
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    known = da.ones(shape=(10, 10), chunks=(5, 10))
    unknown = dd.from_array(known).values
    mixed = da.concatenate([known, unknown])
    same_chunks = ((5, 5, np.nan, np.nan), (10,))
    assert mixed.chunks == same_chunks
    assert mixed is mixed.rechunk(same_chunks)


def test_rechunk_with_zero_placeholders():
    """A requested layout containing a zero-width chunk is reproduced as given."""
    reference = da.ones((24, 24), chunks=((12, 12), (24, 0)))
    source = da.ones((24, 24), chunks=((12, 12), (12, 12)))
    rechunked = source.rechunk(((12, 12), (24, 0)))
    assert reference.chunks == rechunked.chunks


def test_rechunk_minus_one():
    """``-1`` for an axis collapses that axis into a single chunk."""
    source = da.ones((24, 24), chunks=(4, 8))
    rechunked = source.rechunk((-1, 8))
    expected_chunks = ((24,), (8, 8, 8))
    assert rechunked.chunks == expected_chunks
    assert_eq(source, rechunked)


def test_rechunk_intermediates():
    """Going from (10, 1) chunks to (1, 10) chunks gives a graph of over 30 keys."""
    rng = da.random.default_rng()
    by_column = rng.normal(10, 0.1, (10, 10), chunks=(10, 1))
    by_row = by_column.rechunk((1, 10))
    assert len(by_row.dask) > 30


def test_divide_to_width():
    """Check ``divide_to_width`` for widths of 10 and 4."""
    assert divide_to_width((8, 9, 10), 10) == (8, 9, 10)

    # Note how 9 gives (3, 3, 3), not (4, 4, 1) or whatever
    expected = (4, 4, 2, 3, 3, 3, 3, 3, 4, 3, 4, 4, 4, 4, 4)
    assert divide_to_width((8, 2, 9, 10, 11, 12), 4) == expected


def test_merge_to_number():
    """Check ``merge_to_number`` against a table of (chunks, n, expected)."""
    even4 = (10,) * 4
    even10 = (10,) * 10
    mixed = (5, 1, 1, 15, 10)
    ones = (1, 1, 1, 1, 3, 1, 1)
    cases = [
        (even4, 5, (10, 10, 10, 10)),
        (even4, 4, (10, 10, 10, 10)),
        (even4, 3, (20, 10, 10)),
        (even4, 2, (20, 20)),
        (even4, 1, (40,)),
        (even10, 2, (50,) * 2),
        (even10, 3, (40, 30, 30)),
        (mixed, 4, (5, 2, 15, 10)),
        (mixed, 3, (7, 15, 10)),
        (mixed, 2, (22, 10)),
        (mixed, 1, (32,)),
        (ones, 6, (2, 1, 1, 3, 1, 1)),
        (ones, 5, (2, 2, 3, 1, 1)),
        (ones, 4, (2, 2, 3, 2)),
        (ones, 3, (4, 3, 2)),
        (ones, 2, (4, 5)),
        (ones, 1, (9,)),
    ]
    for chunks, number, expected in cases:
        merged = merge_to_number(chunks, number)
        assert merged == expected


def _plan(old_chunks, new_chunks, itemsize=1, block_size_limit=1e7, threshold=4):
    """Call ``plan_rechunk`` with this module's defaults for the tuning arguments."""
    tuning = {
        "itemsize": itemsize,
        "block_size_limit": block_size_limit,
        "threshold": threshold,
    }
    return plan_rechunk(old_chunks, new_chunks, **tuning)


def _assert_steps(steps, expected):
    """Assert that ``steps`` has the same length and contents as ``expected``."""
    n_steps, n_expected = len(steps), len(expected)
    # Length first, so a plan of the wrong size fails on the simpler check
    assert n_steps == n_expected
    assert steps == expected


def test_plan_rechunk():
    """Check the step lists returned by ``plan_rechunk`` (called via ``_plan``).

    Each case passes old and new chunks and compares the returned list of
    chunk layouts with the expected one, varying ``threshold``,
    ``block_size_limit`` and ``itemsize`` where noted.
    """
    c = (20,) * 2  # coarse
    f = (2,) * 20  # fine
    nc = (np.nan,) * 2  # nan-coarse
    nf = (np.nan,) * 20  # nan-fine

    # Trivial cases
    steps = _plan((), ())
    _assert_steps(steps, [()])
    steps = _plan((c, ()), (f, ()))
    _assert_steps(steps, [(f, ())])

    # No intermediate required
    steps = _plan((c,), (f,))
    _assert_steps(steps, [(f,)])
    steps = _plan((f,), (c,))
    _assert_steps(steps, [(c,)])
    steps = _plan((c, c), (f, f))
    _assert_steps(steps, [(f, f)])
    steps = _plan((f, f), (c, c))
    _assert_steps(steps, [(c, c)])
    steps = _plan((f, c), (c, c))
    _assert_steps(steps, [(c, c)])
    steps = _plan((c, c, c, c), (c, f, c, c))
    _assert_steps(steps, [(c, f, c, c)])

    # An intermediate is used to reduce graph size
    steps = _plan((f, c), (c, f))
    _assert_steps(steps, [(c, c), (c, f)])

    steps = _plan((c + c, c + f), (f + f, c + c))
    _assert_steps(steps, [(c + c, c + c), (f + f, c + c)])

    # Same, with unknown dim.  Comparing ``steps`` with itself could never
    # fail, so instead require that the plan ends on the requested chunks.
    # NaN != NaN, but sequence comparison treats identical objects as equal,
    # so this holds as long as the plan reuses the ``np.nan`` objects passed in.
    target = (nc + nf, f + f, c + c)
    steps = _plan((nc + nf, c + c, c + f), target)
    assert steps[-1] == target

    # Regression test for #5908
    steps = _plan((c, c), (f, f), threshold=1)
    _assert_steps(steps, [(f, f)])

    # Just at the memory limit => an intermediate is used
    steps = _plan((f, c), (c, f), block_size_limit=400)
    _assert_steps(steps, [(c, c), (c, f)])

    # Hitting the memory limit => partial merge
    m = (10,) * 4  # mid

    steps = _plan((f, c), (c, f), block_size_limit=399)
    _assert_steps(steps, [(m, c), (c, f)])

    steps2 = _plan((f, c), (c, f), block_size_limit=3999, itemsize=10)
    _assert_steps(steps2, steps)

    # Larger problem size => more intermediates
    c = (1000,) * 2  # coarse
    f = (2,) * 1000  # fine

    steps = _plan((f, c), (c, f), block_size_limit=99999)
    assert len(steps) == 3
    assert steps[-1] == (c, f)
    for prev, succ in zip(steps, steps[1:]):
        # Merging on the first dim, splitting on the second dim
        assert len(succ[0]) <= len(prev[0]) / 2.0
        assert len(succ[1]) >= len(prev[1]) * 2.0


def test_plan_rechunk_5d():
    """Check ``plan_rechunk`` step lists for three 5-d layouts."""
    c = (10,)  # coarse
    f = (1,) * 10  # fine
    all_fine = (f,) * 5
    final = (c, c, c, f, f)
    via_merge = [(c, c, c, f, c), final]

    _assert_steps(_plan((c,) * 5, all_fine), [all_fine])
    _assert_steps(_plan((f, f, f, f, c), final), via_merge)
    # Only 1 dim can be merged at first
    _assert_steps(_plan((c, c, f, f, c), final, block_size_limit=2e4), via_merge)


def test_plan_rechunk_heterogeneous():
    """Check ``plan_rechunk`` step lists when an axis mixes coarse and fine chunks."""
    c = (10,)  # coarse
    f = (1,) * 10  # fine
    cc, cf, fc, ff = c + c, c + f, f + c, f + f

    # No intermediate required
    _assert_steps(_plan((cc, cf), (ff, ff)), [(ff, ff)])
    _assert_steps(_plan((cf, fc), (ff, cf)), [(ff, cf)])

    # An intermediate is used to reduce graph size
    _assert_steps(_plan((cc, cf), (ff, cc)), [(cc, cc), (ff, cc)])
    _assert_steps(_plan((cc, cf, cc), (ff, cc, cf)), [(cc, cc, cc), (ff, cc, cf)])

    # Imposing a memory limit => the first intermediate is constrained:
    #  * cc -> ff would increase the graph size: no
    #  * ff -> cf would increase the block size too much: no
    #  * cf -> cc fits the bill (graph size /= 10, block size neutral)
    #  * cf -> fc also fits the bill (graph size and block size neutral)
    limited = _plan((cc, ff, cf), (ff, cf, cc), block_size_limit=100)
    _assert_steps(limited, [(cc, ff, cc), (ff, cf, cc)])


def test_plan_rechunk_asymmetric():
    """A very lopsided rechunk is planned in several steps with a bounded graph."""
    split_rows = ((1,) * 1000, (80000000,))
    split_cols = ((1000,), (80000,) * 1000)
    plan = plan_rechunk(split_rows, split_cols, itemsize=8)
    assert len(plan) > 1

    source = da.ones((1000, 80000000), chunks=(1, 80000000))
    n_cols = source.shape[1]
    rechunked = source.rechunk((1000, n_cols // 1000))
    assert len(rechunked.dask) < 100000


def test_rechunk_warning():
    """Rechunking (1, N, 100) chunks to (N, 1, 100) records no warning."""
    size = 20
    rng = da.random.default_rng()
    arr = rng.normal(size=(size, size, 100), chunks=(1, size, 100))
    with warnings.catch_warnings(record=True) as recorded:
        arr = arr.rechunk((size, 1, 100))

    assert not recorded


@pytest.mark.parametrize(
    "shape,chunks", [[(4,), (2,)], [(4, 4), (2, 2)], [(4, 4), (4, 2)]]
)
def test_dont_concatenate_single_chunks(shape, chunks):
    """Splitting a single-chunk array adds no ``Task`` whose function name
    starts with ``concat``."""
    single = da.ones(shape, chunks=shape)
    graph = dict(single.rechunk(chunks).dask)
    for task in graph.values():
        if not isinstance(task, Task):
            continue
        assert not funcname(task.func).startswith("concat")


def test_intersect_nan():
    """Intersect chunks where the first axis holds two NaN chunks."""
    whole = slice(0, None, None)
    halves = (slice(0, 4, None), slice(4, 8, None))
    expected = [(((row, whole), (0, half)),) for row in range(2) for half in halves]

    unknown = (np.nan, np.nan)
    result = intersect_chunks((unknown, (8,)), (unknown, (4, 4)))
    assert list(result) == expected


def test_intersect_nan_single():
    """Intersect chunks where the first axis holds a single NaN chunk."""
    whole = slice(0, None, None)
    halves = (slice(0, 5, None), slice(5, 10, None))
    expected = [(((0, whole), (0, half)),) for half in halves]

    result = intersect_chunks(((np.nan,), (10,)), ((np.nan,), (5, 5)))
    assert list(result) == expected


def test_intersect_nan_long():
    """Intersect chunks where the first axis holds four NaN chunks."""
    whole = slice(0, None, None)
    halves = (slice(0, 5, None), slice(5, 10, None))
    expected = [(((row, whole), (0, half)),) for row in range(4) for half in halves]

    old_chunks = (tuple([np.nan] * 4), (10,))
    new_chunks = (tuple([np.nan] * 4), (5, 5))
    assert list(intersect_chunks(old_chunks, new_chunks)) == expected


def test_rechunk_unknown_from_pandas():
    """Rechunk the known axis of an array that came from a pandas DataFrame."""
    pd = pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")

    data = np.random.default_rng().standard_normal((50, 10))
    unknown = dd.from_pandas(pd.DataFrame(data), 2).values
    spec = (None, (5, 5))
    result = unknown.rechunk(spec)
    # The first axis stays unknown before and after; the second is rechunked
    assert np.isnan(unknown.chunks[0]).all()
    assert np.isnan(result.chunks[0]).all()
    assert result.chunks[1] == (5, 5)
    known = da.from_array(data, chunks=((25, 25), (10,)))
    assert_eq(result, known.rechunk(spec))


def test_rechunk_unknown_from_array():
    """Rechunk the known axis of an array whose first axis has NaN chunks.

    The array is obtained by passing a dask array through
    ``dd.from_array(...).values``.
    """
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    x = dd.from_array(da.ones(shape=(4, 4), chunks=(2, 2))).values
    result = x.rechunk((None, 4))
    # The first axis is unknown on both the input and the output
    assert np.isnan(x.chunks[0]).all()
    assert np.isnan(result.chunks[0]).all()
    assert x.chunks[1] == (4,)
    # Check the rechunked array too, not only the input
    assert result.chunks[1] == (4,)
    assert_eq(x, result)


@pytest.mark.parametrize(
    "x, chunks",
    [
        (da.ones(shape=shape, chunks=old), new)
        for shape, old in [
            ((50, 10), (25, 10)),
            ((1000, 10), (5, 10)),
            ((10, 10), (10, 10)),
            ((10, 10), (10, 2)),
        ]
        for new in [(None, 5), {1: 5}, (None, (5, 5))]
    ],
)
def test_rechunk_with_fully_unknown_dimension(x, chunks):
    """Rechunking an array with an all-NaN first axis matches rechunking the
    same data with known chunks."""
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    unknown = dd.from_array(x).values
    result = unknown.rechunk(chunks)
    expected = x.rechunk(chunks)
    assert_chunks_match(result.chunks, expected.chunks)
    assert_eq(result, expected)


@pytest.mark.parametrize(
    "x, chunks",
    [
        (da.ones(shape=shape, chunks=old), new)
        for shape, old in [
            ((50, 10), (25, 10)),
            ((1000, 10), (5, 10)),
            ((10, 10), (10, 10)),
            ((10, 10), (10, 2)),
        ]
        for new in [(None, 5), {1: 5}, (None, (5, 5))]
    ],
)
def test_rechunk_with_partially_unknown_dimension(x, chunks):
    """Rechunking an array whose first axis mixes known and NaN chunks matches
    rechunking the same data with known chunks."""
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    unknown = dd.from_array(x).values
    mixed = da.concatenate([x, unknown])
    known = da.concatenate([x, x])
    result = mixed.rechunk(chunks)
    expected = known.rechunk(chunks)
    assert_chunks_match(result.chunks, expected.chunks)
    assert_eq(result, expected)


@pytest.mark.parametrize(
    "new_chunks",
    [
        ((np.nan, np.nan), (5, 5)),
        ((math.nan, math.nan), (5, 5)),
        ((float("nan"), float("nan")), (5, 5)),
    ],
)
def test_rechunk_with_fully_unknown_dimension_explicit(new_chunks):
    """NaN chunks spelled out explicitly, in three NaN flavours, give the same
    result as ``None`` for that axis on the known array."""
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    known = da.ones(shape=(10, 10), chunks=(5, 2))
    unknown = dd.from_array(known).values
    result = unknown.rechunk(new_chunks)
    expected = known.rechunk((None, (5, 5)))
    assert_chunks_match(result.chunks, expected.chunks)
    assert_eq(result, expected)


@pytest.mark.parametrize(
    "new_chunks",
    [
        ((5, 5, np.nan, np.nan), (5, 5)),
        ((5, 5, math.nan, math.nan), (5, 5)),
        ((5, 5, float("nan"), float("nan")), (5, 5)),
    ],
)
def test_rechunk_with_partially_unknown_dimension_explicit(new_chunks):
    """Partly NaN chunks spelled out explicitly give the same result as
    ``None`` for that axis on the fully known array."""
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")
    known = da.ones(shape=(10, 10), chunks=(5, 2))
    unknown = dd.from_array(known).values
    mixed = da.concatenate([known, unknown])
    doubled = da.concatenate([known, known])
    result = mixed.rechunk(new_chunks)
    expected = doubled.rechunk((None, (5, 5)))
    assert_chunks_match(result.chunks, expected.chunks)
    assert_eq(result, expected)


def assert_chunks_match(left, right):
    """Assert that two chunk layouts agree wherever ``left`` is known.

    Parameters
    ----------
    left, right
        Sequences with one entry per dimension, each a sequence of chunk
        sizes. A chunk in ``left`` may be NaN, in which case it matches
        anything in ``right``.

    Raises
    ------
    AssertionError
        If the layouts differ in number of dimensions, in number of chunks
        along a dimension, or in any chunk of ``left`` that is not NaN.
    """
    # ``zip`` stops at the shorter input, so compare lengths explicitly first
    assert len(left) == len(right)
    for ldim, rdim in zip(left, right):
        assert len(ldim) == len(rdim)
        assert all(np.isnan(l) or l == r for l, r in zip(ldim, rdim))


def test_rechunk_unknown_raises():
    """Invalid rechunk requests on arrays with NaN chunks raise ``ValueError``."""
    pytest.importorskip("pandas")
    dd = pytest.importorskip("dask.dataframe")

    x = da.ones(shape=(10, 10), chunks=(5, 5))
    y = dd.from_array(x).values
    # Build all inputs up front so that every ``pytest.raises`` block wraps
    # only the ``rechunk`` call under test.
    z = da.concatenate([x, y])

    with pytest.raises(ValueError, match="Chunks do not add"):
        y.rechunk((None, (5, 5, 5)))

    with pytest.raises(ValueError, match="Chunks must be unchanging"):
        y.rechunk(((np.nan, np.nan, np.nan), (5, 5)))

    with pytest.raises(ValueError, match="Chunks must be unchanging"):
        y.rechunk(((5, 5), (5, 5)))

    with pytest.raises(ValueError, match="Chunks must be unchanging"):
        z.rechunk(((5, 3, 2, np.nan, np.nan), (5, 5)))


def test_old_to_new_single():
    """``old_to_new`` with two NaN chunks on axis 0 and a split on axis 1."""
    whole = slice(0, None, None)
    unknown = (np.nan, np.nan)
    result = old_to_new((unknown, (8,)), (unknown, (4, 4)))

    assert result == [
        [[(0, whole)], [(1, whole)]],
        [[(0, slice(0, 4, None))], [(0, slice(4, 8, None))]],
    ]


def test_old_to_new():
    """``old_to_new`` with one NaN chunk on axis 0 and a split on axis 1."""
    whole = slice(0, None, None)
    result = old_to_new(((np.nan,), (10,)), ((np.nan,), (5, 5)))

    assert result == [
        [[(0, whole)]],
        [[(0, slice(0, 5, None))], [(0, slice(5, 10, None))]],
    ]


def test_old_to_new_large():
    """``old_to_new`` with four NaN chunks on axis 0 and a split on axis 1."""
    whole = slice(0, None, None)
    old_chunks = (tuple([np.nan] * 4), (10,))
    new_chunks = (tuple([np.nan] * 4), (5, 5))

    expected = [
        [[(index, whole)] for index in range(4)],
        [[(0, slice(0, 5, None))], [(0, slice(5, 10, None))]],
    ]
    assert old_to_new(old_chunks, new_chunks) == expected


def test_old_to_new_known():
    """``old_to_new`` for fully known 1-d chunks: five of 10 to (25, 5, 20)."""
    full = slice(0, 10, None)
    per_new_chunk = [
        [(0, full), (1, full), (2, slice(0, 5, None))],
        [(2, slice(5, 10, None))],
        [(3, full), (4, full)],
    ]
    result = old_to_new(((10,) * 5,), ((25, 5, 20),))
    assert result == [per_new_chunk]


def test_rechunk_zero_dim():
    """Rechunk an array with a zero-length leading axis and compute it."""
    # ``da`` is the module-level ``dask.array`` import; no local re-import needed
    x = da.ones((0, 10, 100), chunks=(0, 10, 10)).rechunk((0, 10, 50))
    assert len(x.compute()) == 0


def test_rechunk_empty_chunks():
    """Rechunk an array whose second axis contains several zero-width chunks."""
    with_gaps = ((7,), (10, 0, 0, 9, 0, 5))
    source = da.zeros((7, 24), chunks=with_gaps)
    assert_eq(source, source.rechunk((2, 3)))


def test_rechunk_avoid_needless_chunking():
    """Merging eight chunks of 2 into two chunks of 8 needs at most 10 graph keys."""
    source = da.ones(16, chunks=2)
    graph = source.rechunk(8).__dask_graph__()
    assert len(graph) <= 8 + 2


@pytest.mark.parametrize(
    "shape,chunks,bs,expected",
    [
        (100, 1, 10, (10,) * 10),
        (100, 50, 10, (10,) * 10),
        (100, 100, 10, (10,) * 10),
        (20, 7, 10, (7, 7, 6)),
        # This should be smarter
        (20, (1, 1, 1, 1, 6, 2, 1, 7), 5, (5, 5, 5, 5)),
        # ensure that we squash together properly
        (21, (1,) * 21, 10, (10, 10, 1)),
        (30, (3,) * 10, 7, (6,) * 5),
        (20, (2, 2, 2, 9, 1, 2, 2), 5, (4, 4, 4, 4, 4)),
        # these should split branches evenly
        (20, (10, 10), 4, (4,) * 5),
        (21, (10, 11), 4, (4, 4, 4, 4, 4, 1)),
        (20, (1, 18, 1), 5, (5,) * 4),
        (38, (10, 18, 10), 5, (5, 5, 5, 5, 5, 5, 5, 3)),
    ],
)
def test_rechunk_auto_1d(shape, chunks, bs, expected):
    """``"auto"`` on a 1-d array with a limit of ``bs`` items gives ``expected``."""
    source = da.ones(shape, chunks=(chunks,))
    limit = bs * source.dtype.itemsize
    rechunked = source.rechunk({0: "auto"}, block_size_limit=limit)
    assert rechunked.chunks == (expected,)


@pytest.mark.parametrize(
    "previous_chunks,bs,expected",
    [
        (((1, 1, 1), (10, 10, 10, 10, 10, 10, 10, 10)), 160, ((3,), (50, 30))),
        (((2, 2), (20,)), 5, ((1, 1, 1, 1), (5,) * 4)),
        (((1, 1), (20,)), 5, ((1, 1), (5,) * 4)),
    ],
)
def test_normalize_chunks_auto_2d(previous_chunks, bs, expected):
    """``normalize_chunks`` with ``"auto"`` on both axes, given previous chunks
    and a byte limit, for an ``int8`` array."""
    shape = tuple(sum(dim) for dim in previous_chunks)
    auto_both = {0: "auto", 1: "auto"}
    normalized = normalize_chunks(
        auto_both,
        shape,
        limit=bs,
        dtype=np.int8,
        previous_chunks=previous_chunks,
    )
    assert normalized == expected


def test_rechunk_auto_2d():
    """``"auto"`` on one axis of a 20x20 array under several block size limits."""
    source = da.ones((20, 20), chunks=(2, 2))
    itemsize = source.dtype.itemsize
    rechunked = source.rechunk({0: -1, 1: "auto"}, block_size_limit=20 * itemsize)
    assert rechunked.chunks == ((20,), (1,) * 20)

    source = da.ones((20, 20), chunks=(2, 2))
    itemsize = source.dtype.itemsize
    rechunked = source.rechunk((-1, "auto"), block_size_limit=80 * itemsize)
    assert rechunked.chunks == ((20,), (4,) * 5)

    source = da.ones((20, 20), chunks=(2, 2))
    itemsize = source.dtype.itemsize
    rechunked = source.rechunk({0: "auto"}, block_size_limit=20 * itemsize)
    assert rechunked.chunks[1] == source.chunks[1]
    assert rechunked.chunks[0] == (10, 10)

    source = da.ones((20, 20), chunks=((2,) * 10, (2, 2, 2, 2, 2, 5, 5)))
    itemsize = source.dtype.itemsize
    rechunked = source.rechunk({0: "auto"}, block_size_limit=20 * itemsize)
    assert rechunked.chunks[1] == source.chunks[1]
    assert rechunked.chunks[0] == (4, 4, 4, 4, 4)  # limited by largest


def test_rechunk_auto_3d():
    """``"auto"`` on two of three axes leaves the third axis untouched."""
    source = da.ones((20, 20, 20), chunks=(2, 2, 2))
    limit = 200 * source.dtype.itemsize
    rechunked = source.rechunk({0: "auto", 1: "auto"}, block_size_limit=limit)
    first, second, third = rechunked.chunks
    assert third == source.chunks[2]
    assert first == (10, 10)
    assert second == (10, 10)  # even split


@pytest.mark.parametrize("n", [100, 1000])
def test_rechunk_auto_image_stack(n):
    """``rechunk("auto")`` on a stack of ``n`` images under several values of
    the ``array.chunk-size`` config."""
    frame = (1000, 1000)
    whole_frame = ((1000,), (1000,))

    with dask.config.set({"array.chunk-size": "10MiB"}):
        stack = da.ones((n,) + frame, chunks=(1,) + frame, dtype="uint8")
        grouped = stack.rechunk("auto")
        assert grouped.chunks == ((10,) * (n // 10),) + whole_frame
        assert grouped.rechunk("auto").chunks == grouped.chunks  # idempotent

    with dask.config.set({"array.chunk-size": "7MiB"}):
        regrouped = stack.rechunk("auto")
        leading = (7,) * 14 + (2,) if n == 100 else (7,) * 142 + (6,)
        assert regrouped.chunks == (leading,) + whole_frame

    with dask.config.set({"array.chunk-size": "1MiB"}):
        stack = da.ones((n,) + frame, chunks=(1,) + frame, dtype="float64")
        tiled = stack.rechunk("auto")
        tiles = (362, 362, 276)
        assert tiled.chunks == ((1,) * n, tiles, tiles)

    with dask.config.set({"array.chunk-size": "1MiB"}):
        stack = da.ones((n, 2000, 2000), chunks=(1,) + frame, dtype="float64")
        tiled = stack.rechunk("auto")
        tiles = (362,) * 5 + (190,)
        assert tiled.chunks == ((1,) * n, tiles, tiles)


def test_rechunk_down():
    """Rechunk with ``"auto"`` at 10MiB, then again under a 1MiB chunk size."""
    frame_chunks = ((1000,), (1000,))

    with dask.config.set({"array.chunk-size": "10MiB"}):
        stack = da.ones((100, 1000, 1000), chunks=(1, 1000, 1000), dtype="uint8")
        grouped = stack.rechunk("auto")
        assert grouped.chunks == ((10,) * 10,) + frame_chunks

    with dask.config.set({"array.chunk-size": "1MiB"}):
        all_axes = grouped.rechunk("auto")
        assert all_axes.chunks == ((4,) * 25, (511, 489), (511, 489))

    with dask.config.set({"array.chunk-size": "1MiB"}):
        only_first = grouped.rechunk({0: "auto"})
        assert only_first.chunks == ((1,) * 100,) + frame_chunks

        only_second = grouped.rechunk({1: "auto"})
        assert only_second.chunks == ((10,) * 10, (104,) * 9 + (64,), (1000,))


def test_rechunk_zero():
    """With a 1-byte chunk size, ``"auto"`` gives one-element chunks."""
    with dask.config.set({"array.chunk-size": "1B"}):
        source = da.ones(10, chunks=(5,))
        rechunked = source.rechunk("auto")
        assert rechunked.chunks == ((1,) * 10,)


def test_rechunk_bad_keys():
    """Negative axis keys are accepted; non-integer and out-of-range keys raise
    an error whose message names the key."""
    x = da.zeros((2, 3, 4), chunks=1)
    last_axis = x.rechunk({-1: 4})
    assert last_axis.chunks == ((1, 1), (1, 1, 1), (4,))
    first_axis = x.rechunk({-x.ndim: 2})
    assert first_axis.chunks == ((2,), (1, 1, 1), (1, 1, 1, 1))

    with pytest.raises(TypeError) as excinfo:
        x.rechunk({"blah": 4})

    assert "blah" in str(excinfo.value)

    for bad_axis in (100, -100):
        with pytest.raises(ValueError) as excinfo:
            x.rechunk({bad_axis: 4})

        assert str(bad_axis) in str(excinfo.value)


def test_balance_basics():
    """``balance=True`` evens out a 220-element array chunked by 100."""
    source = da.from_array(np.arange(220), chunks=100)
    by_flag = {
        flag: source.rechunk(chunks=100, balance=flag) for flag in (True, False)
    }
    assert by_flag[False].chunks[0] == (100, 100, 20)
    assert by_flag[True].chunks[0] == (110, 110)


def test_balance_chunks_unchanged():
    """Same as ``test_balance_basics`` but ``from_array`` gets no ``chunks``."""
    source = da.from_array(np.arange(220))
    by_flag = {
        flag: source.rechunk(chunks=100, balance=flag) for flag in (True, False)
    }
    assert by_flag[False].chunks[0] == (100, 100, 20)
    assert by_flag[True].chunks[0] == (110, 110)


def test_balance_small():
    """Balanced versus unbalanced chunks for two short arrays."""
    cases = [
        (13, 4, (5, 5, 3), (4, 4, 4, 1)),
        (7, 3, (4, 3), (3, 3, 1)),
    ]
    for length, size, want_balanced, want_unbalanced in cases:
        source = da.from_array(np.arange(length))
        balanced = source.rechunk(chunks=size, balance=True)
        unbalanced = source.rechunk(chunks=size, balance=False)
        assert balanced.chunks[0] == want_balanced
        assert unbalanced.chunks[0] == want_unbalanced


def test_balance_n_chunks_size():
    """Balancing 100 elements at chunk size ``100 // 8`` gives 8 chunks, not 9."""
    length, n_chunks = 100, 8
    size = length // n_chunks

    source = da.from_array(np.arange(length))
    balanced = source.rechunk(chunks=size, balance=True)
    unbalanced = source.rechunk(chunks=size, balance=False)
    assert balanced.chunks[0] == (13,) * 7 + (9,)
    assert unbalanced.chunks[0] == (12,) * 8 + (4,)


def test_balance_raises():
    """Balancing warns when it cannot be done, and is silent when it can.

    With ``100 // 11`` sized chunks a ``UserWarning`` is emitted and the
    result has the same chunks as the unbalanced rechunk.  With
    ``100 // 10`` sized chunks that warning must not be emitted.
    """
    arr_len = 100
    x = da.from_array(np.arange(arr_len))

    n_chunks = 11
    with pytest.warns(UserWarning, match="Try increasing the chunk size"):
        balanced = x.rechunk(chunks=arr_len // n_chunks, balance=True)
    unbalanced = x.rechunk(chunks=arr_len // n_chunks, balance=False)
    assert balanced.chunks == unbalanced.chunks

    n_chunks = 10
    with warnings.catch_warnings(record=True) as record:
        # Record everything so an emitted warning cannot be hidden by a
        # filter or by the once-per-location registry.
        warnings.simplefilter("always")
        x.rechunk(chunks=arr_len // n_chunks, balance=True)
    balance_warnings = [
        w for w in record if "Try increasing the chunk size" in str(w.message)
    ]
    assert not balance_warnings


def test_balance_basics_2d():
    """Balancing is applied along both axes of a 210x210 array."""
    side = 210
    data = np.random.default_rng().uniform(size=(side, side))
    source = da.from_array(data)

    target = (100, 100)
    even = source.rechunk(chunks=target, balance=True)
    plain = source.rechunk(chunks=target, balance=False)

    # Both axes are expected to get the same layout.
    assert plain.chunks == ((100, 100, 10),) * 2
    assert even.chunks == ((105, 105),) * 2


def test_balance_2d_negative_dimension():
    """A ``-1`` chunk spec keeps its axis as one chunk, with or without
    balancing, while the other axis is balanced as usual."""
    side = 210
    data = np.random.default_rng().uniform(size=(side, side))
    source = da.from_array(data)

    target = (100, -1)
    even = source.rechunk(chunks=target, balance=True)
    plain = source.rechunk(chunks=target, balance=False)

    whole_axis = (side,)
    assert plain.chunks == ((100, 100, 10), whole_axis)
    assert even.chunks == ((105, 105), whole_axis)


def test_balance_different_inputs():
    """A byte-size string chunk spec (``"10MB"``) combined with ``-1``.

    Balanced and unbalanced rechunking must agree, and the ``-1`` axis
    must stay a single chunk.
    """
    side = 210
    data = np.random.default_rng().uniform(size=(side, side))
    source = da.from_array(data)

    target = ("10MB", -1)
    even = source.rechunk(chunks=target, balance=True)
    plain = source.rechunk(chunks=target, balance=False)

    assert even.chunks == plain.chunks
    assert even.chunks[1] == (side,)


def test_balance_split_into_n_chunks():
    """Balanced rechunking with chunk size ``N // nchunks`` yields exactly
    ``nchunks`` chunks, for prime lengths ``N`` and ``nchunks`` in 1..19."""
    # Some prime numbers around 1000
    array_lens = [
        991,
        997,
        1009,
        1013,
        1019,
        1021,
        1031,
        1033,
        1039,
        1049,
        1051,
        1061,
        1063,
        1069,
    ]

    for N in array_lens:
        # The input depends only on N, so build it once per length rather
        # than once per (N, nchunks) pair.
        x = da.from_array(np.random.default_rng().uniform(size=N))
        for nchunks in range(1, 20):
            y = x.rechunk(chunks=len(x) // nchunks, balance=True)
            assert len(y.chunks[0]) == nchunks, (N, nchunks, y.chunks[0])


def test_rechunk_with_zero():
    """Rechunk to, and back from, a layout that contains zero-length chunks."""
    a = da.ones((8, 8), chunks=(4, 4))
    result = a.rechunk(((4, 4), (4, 0, 0, 4)))
    expected = da.ones((8, 8), chunks=((4, 4), (4, 0, 0, 4)))
    # Check the forward direction before the names are reused below.
    assert result.chunks == expected.chunks
    assert_eq(result, expected)

    # reverse:
    a, expected = expected, a
    result = a.rechunk((4, 4))
    assert_eq(result, expected)


def test_intersect_chunks_with_nonzero():
    """``intersect_chunks`` when no chunk has zero length.

    Axis 0 merges two old chunks into one new chunk; axis 1 splits one old
    chunk into two new chunks.
    """
    old_chunks = ((4, 4), (2,))
    new_chunks = ((8,), (1, 1))

    # Each new block draws on both old row chunks in full ...
    row_pieces = ((0, slice(0, 4)), (1, slice(0, 4)))
    # ... and on one half of the only old column chunk.
    col_pieces = ((0, slice(0, 1)), (0, slice(1, 2)))
    expected = [
        tuple((row, col) for row in row_pieces) for col in col_pieces
    ]

    assert list(intersect_chunks(old_chunks, new_chunks)) == expected


def test_intersect_chunks_with_zero():
    """``intersect_chunks`` when either side has zero-length chunks.

    Each entry of the expected result lists the ``(old index, slice)`` pairs,
    one per axis, that make up a new block.
    """

    def one_source_each(row_pieces, col_pieces):
        # Every new block is fed by exactly one old block; rows vary slowest.
        return [((row, col),) for row in row_pieces for col in col_pieces]

    split_cols = [(0, slice(0, 1)), (0, slice(1, 2))]

    # Zero-length chunks in the middle of the new layout.
    found = list(intersect_chunks(((4, 4), (2,)), ((4, 0, 0, 4), (1, 1))))
    rows = [
        (0, slice(0, 4)),
        (1, slice(0, 0)),
        (1, slice(0, 0)),
        (1, slice(0, 4)),
    ]
    assert found == one_source_each(rows, split_cols)

    # Zero-length chunks in the old layout: each new block merges two old
    # column chunks and the empty old row chunks do not appear.
    found = list(intersect_chunks(((4, 0, 0, 4), (1, 1)), ((4, 4), (2,))))
    merged_cols = [(0, slice(0, 1)), (1, slice(0, 1))]
    wanted = [
        tuple((row, col) for col in merged_cols)
        for row in [(0, slice(0, 4)), (3, slice(0, 4))]
    ]
    assert found == wanted

    # Zero-length chunks that fall inside an old chunk.
    found = list(intersect_chunks(((4, 4), (2,)), ((2, 0, 0, 2, 4), (1, 1))))
    rows = [
        (0, slice(0, 2)),
        (0, slice(2, 2)),
        (0, slice(2, 2)),
        (0, slice(2, 4)),
        (1, slice(0, 4)),
    ]
    assert found == one_source_each(rows, split_cols)

    # Zero-length chunks at the very start of the new layout.
    found = list(intersect_chunks(((4, 4), (2,)), ((0, 0, 4, 4), (1, 1))))
    rows = [
        (0, slice(0, 0)),
        (0, slice(0, 0)),
        (0, slice(0, 4)),
        (1, slice(0, 4)),
    ]
    assert found == one_source_each(rows, split_cols)


def test_old_to_new_with_zero():
    """``old_to_new`` on 1-d layouts that include zero-length chunks."""
    # (old chunks, new chunks, expected mapping)
    cases = [
        (
            ((4, 4),),
            ((4, 0, 4),),
            [[[(0, slice(0, 4))], [(1, slice(0, 0))], [(1, slice(0, 4))]]],
        ),
        (
            ((4,),),
            ((4, 0),),
            [[[(0, slice(0, 4))], [(0, slice(4, 4))]]],
        ),
        (
            ((4, 0, 4),),
            ((4, 0, 2, 2),),
            [
                [
                    [(0, slice(0, 4))],
                    [(2, slice(0, 0))],
                    [(2, slice(0, 2))],
                    [(2, slice(2, 4))],
                ]
            ],
        ),
    ]

    for old_chunks, new_chunks, mapping in cases:
        assert old_to_new(old_chunks, new_chunks) == mapping


def test_rechunk_non_perfect_slicing_of_dimensions():
    """Rechunk a coarsened array whose last dimension does not divide evenly.

    Regression test for GH#7859: both ways of calling ``rechunk`` below must
    preserve the data and agree on the resulting chunks.
    """
    # GH#7859
    # this matters -- 1060 and 1058 work
    shape = (200, 100, 1059)
    final_chunks = (64, 64, 64)

    arr = da.coarsen(
        da.mean,
        da.zeros(shape, chunks=(1, -1, -1)),
        {0: 2, 1: 2, 2: 2},
        trim_excess=True,
    )
    # NOTE(review): unpacking passes the three values as separate positional
    # arguments rather than as one chunks tuple -- confirm this is intended.
    result = arr.rechunk(*final_chunks)
    assert_eq(arr, result)
    result_b = arr.rechunk(final_chunks)
    # Compare the second rechunk against the source array as well.
    assert_eq(arr, result_b)
    assert result.chunks == result_b.chunks
