# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
This module tests some of the methods related to YAML serialization.
"""

from io import StringIO

import numpy as np
import pytest

from astropy import units as u
from astropy.coordinates import (
    Angle,
    CartesianDifferential,
    CartesianRepresentation,
    EarthLocation,
    Latitude,
    Longitude,
    SkyCoord,
    SphericalCosLatDifferential,
    SphericalDifferential,
    SphericalRepresentation,
    UnitSphericalRepresentation,
)
from astropy.coordinates.tests.test_representation import representation_equal
from astropy.io.misc.yaml import dump, load, load_all
from astropy.table import QTable, SerializedColumn
from astropy.time import Time


@pytest.mark.parametrize(
    "c",
    [
        True,
        np.uint8(8),
        np.int16(4),
        np.int32(1),
        np.int64(3),
        np.int64(2**63 - 1),
        2.0,
        np.float64(),
        3 + 4j,
        np.complex_(3 + 4j),
        np.complex64(3 + 4j),
        np.complex128(1.0 - 2**-52 + 1j * (1.0 - 2**-52)),
    ],
)
def test_numpy_types(c):
    cy = load(dump(c))
    assert c == cy


@pytest.mark.parametrize(
    "c", [u.m, u.m / u.s, u.hPa, u.dimensionless_unscaled, u.Unit("m, (cm, um)")]
)
def test_unit(c):
    cy = load(dump(c))
    if isinstance(c, (u.CompositeUnit, u.StructuredUnit)):
        assert c == cy
    else:
        assert c is cy


@pytest.mark.parametrize("c", [u.Unit("bakers_dozen", 13 * u.one), u.def_unit("magic")])
def test_custom_unit(c):
    s = dump(c)
    with pytest.warns(u.UnitsWarning, match=f"'{c!s}' did not parse") as w:
        cy = load(s)
    assert len(w) == 1
    assert isinstance(cy, u.UnrecognizedUnit)
    assert str(cy) == str(c)

    with u.add_enabled_units(c):
        cy2 = load(s)
        assert cy2 is c


@pytest.mark.parametrize(
    "c",
    [
        Angle("1 2 3", unit="deg"),
        Longitude("1 2 3", unit="deg"),
        Latitude("1 2 3", unit="deg"),
        [[1], [3]] * u.m,
        np.array([[1, 2], [3, 4]], order="F"),
        np.array([[1, 2], [3, 4]], order="C"),
        np.array([1, 2, 3, 4])[::2],
        np.array([(1.0, 2), (3.0, 4)], dtype="f8,i4"),  # array with structured dtype.
        np.array((1.0, 2), dtype="f8,i4"),  # array scalar with structured dtype.
        np.array((1.0, 2), dtype="f8,i4")[()],  # numpy void.
        np.array((1.0, 2.0), dtype="f8,f8") * u.s,  # Quantity structured scalar.
        [
            ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)),  # Quantity with structured unit.
            ((11.0, 12.0, 13.0), (14.0, 15.0, 16.0)),
        ]
        * u.Unit("m, m/s"),
        np.array(
            [
                ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)),
                ((11.0, 12.0, 13.0), (14.0, 15.0, 16.0)),
            ],
            dtype=[("p", "3f8"), ("v", "3f8")],
        )
        * u.Unit("m, m/s"),
    ],
)
def test_ndarray_subclasses(c):
    cy = load(dump(c))

    assert np.all(c == cy)
    assert c.shape == cy.shape
    assert c.dtype == cy.dtype
    assert type(c) is type(cy)

    cc = "C_CONTIGUOUS"
    fc = "F_CONTIGUOUS"
    if c.flags[cc] or c.flags[fc]:
        assert c.flags[cc] == cy.flags[cc]
        assert c.flags[fc] == cy.flags[fc]
    else:
        # Original was not contiguous but round-trip version
        # should be c-contig.
        assert cy.flags[cc]

    if hasattr(c, "unit"):
        assert c.unit == cy.unit


def compare_coord(c, cy):
    assert c.shape == cy.shape
    assert c.frame.name == cy.frame.name

    assert list(c.frame_attributes) == list(cy.frame_attributes)
    for attr in c.frame_attributes:
        assert getattr(c, attr) == getattr(cy, attr)

    assert list(c.representation_component_names) == list(
        cy.representation_component_names
    )
    for name in c.representation_component_names:
        assert np.all(getattr(c, attr) == getattr(cy, attr))


@pytest.mark.parametrize("frame", ["fk4", "altaz"])
def test_skycoord(frame):
    c = SkyCoord(
        [[1, 2], [3, 4]],
        [[5, 6], [7, 8]],
        unit="deg",
        frame=frame,
        obstime=Time("2016-01-02"),
        location=EarthLocation(1000, 2000, 3000, unit=u.km),
    )
    cy = load(dump(c))
    compare_coord(c, cy)


@pytest.mark.parametrize(
    "rep",
    [
        CartesianRepresentation(1 * u.m, 2.0 * u.m, 3.0 * u.m),
        SphericalRepresentation(
            [[1, 2], [3, 4]] * u.deg, [[5, 6], [7, 8]] * u.deg, 10 * u.pc
        ),
        UnitSphericalRepresentation(0 * u.deg, 10 * u.deg),
        SphericalCosLatDifferential(
            [[1.0], [2.0]] * u.mas / u.yr,
            [4.0, 5.0] * u.mas / u.yr,
            [[[10]], [[20]]] * u.km / u.s,
        ),
        CartesianDifferential([10, 20, 30] * u.km / u.s),
        CartesianRepresentation(
            [1, 2, 3] * u.m,
            differentials=CartesianDifferential([10, 20, 30] * u.km / u.s),
        ),
        SphericalRepresentation(
            [[1, 2], [3, 4]] * u.deg,
            [[5, 6], [7, 8]] * u.deg,
            10 * u.pc,
            differentials={
                "s": SphericalDifferential(
                    [[0.0, 1.0], [2.0, 3.0]] * u.mas / u.yr,
                    [[4.0, 5.0], [6.0, 7.0]] * u.mas / u.yr,
                    10 * u.km / u.s,
                )
            },
        ),
    ],
)
def test_representations(rep):
    rrep = load(dump(rep))
    assert np.all(representation_equal(rrep, rep))


def _get_time():
    t = Time(
        [[1], [2]], format="cxcsec", location=EarthLocation(1000, 2000, 3000, unit=u.km)
    )
    t.format = "iso"
    t.precision = 5
    t.delta_ut1_utc = np.array([[3.0], [4.0]])
    t.delta_tdb_tt = np.array([[5.0], [6.0]])
    t.out_subfmt = "date_hm"

    return t


def compare_time(t, ty):
    assert type(t) is type(ty)
    assert np.all(t == ty)
    for attr in (
        "shape",
        "jd1",
        "jd2",
        "format",
        "scale",
        "precision",
        "in_subfmt",
        "out_subfmt",
        "location",
        "delta_ut1_utc",
        "delta_tdb_tt",
    ):
        assert np.all(getattr(t, attr) == getattr(ty, attr))


def test_time():
    t = _get_time()
    ty = load(dump(t))
    compare_time(t, ty)


def test_timedelta():
    t = _get_time()
    dt = t - t + 0.1234556 * u.s
    dty = load(dump(dt))

    assert type(dt) is type(dty)
    for attr in ("shape", "jd1", "jd2", "format", "scale"):
        assert np.all(getattr(dt, attr) == getattr(dty, attr))


def test_serialized_column():
    sc = SerializedColumn({"name": "hello", "other": 1, "other2": 2.0})
    scy = load(dump(sc))

    assert sc == scy


def test_load_all():
    t = _get_time()
    unit = u.m / u.s
    c = SkyCoord(
        [[1, 2], [3, 4]],
        [[5, 6], [7, 8]],
        unit="deg",
        frame="fk4",
        obstime=Time("2016-01-02"),
        location=EarthLocation(1000, 2000, 3000, unit=u.km),
    )

    # Make a multi-document stream
    out = "---\n" + dump(t) + "---\n" + dump(unit) + "---\n" + dump(c)

    ty, unity, cy = list(load_all(out))

    compare_time(t, ty)
    compare_coord(c, cy)
    assert unity == unit


def test_ecsv_astropy_objects_in_meta():
    """
    Test that astropy core objects in ``meta`` are serialized.
    """
    t = QTable([[1, 2] * u.m, [4, 5]], names=["a", "b"])
    tm = _get_time()
    c = SkyCoord(
        [[1, 2], [3, 4]],
        [[5, 6], [7, 8]],
        unit="deg",
        frame="fk4",
        obstime=Time("2016-01-02"),
        location=EarthLocation(1000, 2000, 3000, unit=u.km),
    )
    unit = u.m / u.s

    t.meta = {"tm": tm, "c": c, "unit": unit}
    out = StringIO()
    t.write(out, format="ascii.ecsv")
    t2 = QTable.read(out.getvalue(), format="ascii.ecsv")

    compare_time(tm, t2.meta["tm"])
    compare_coord(c, t2.meta["c"])
    assert t2.meta["unit"] == unit


def test_yaml_dump_of_object_arrays_fail():
    """Test that dumping and loading object arrays fails."""
    with pytest.raises(TypeError, match="cannot serialize"):
        dump(np.array([1, 2, 3], dtype=object))


def test_yaml_load_of_object_arrays_fail():
    """Test that dumping and loading object arrays fails.

    The string to load was obtained by suppressing the exception and dumping
    ``np.array([1, 2, 3], dtype=object)`` to a yaml file.
    """
    with pytest.raises(TypeError, match="cannot load numpy array"):
        load(
            """!numpy.ndarray
            buffer: !!binary |
              WndBQUFISUFBQUJwQUFBQQ==
            dtype: object
            order: C
            shape: !!python/tuple [3]"""
        )
