# Licensed under a 3-clause BSD style license - see LICENSE.rst

import os
import textwrap
from contextlib import nullcontext

import numpy as np
import pytest

from astropy import log
from astropy import units as u
from astropy.io import fits
from astropy.nddata import _testing as nd_testing
from astropy.nddata.ccddata import CCDData
from astropy.nddata.nduncertainty import (
    InverseVariance,
    MissingDataAssociationException,
    StdDevUncertainty,
    VarianceUncertainty,
)
from astropy.table import Table
from astropy.tests.helper import PYTEST_LT_8_0
from astropy.utils import NumpyRNGContext
from astropy.utils.data import (
    get_pkg_data_contents,
    get_pkg_data_filename,
    get_pkg_data_filenames,
)
from astropy.utils.exceptions import AstropyWarning
from astropy.wcs import WCS, FITSFixedWarning

DEFAULT_DATA_SIZE = 100

with NumpyRNGContext(123):
    _random_array = np.random.normal(size=[DEFAULT_DATA_SIZE, DEFAULT_DATA_SIZE])
    _random_psf = np.random.normal(size=(20, 20))


@pytest.fixture
def home_is_tmpdir(tmp_path, monkeypatch, request):
    """
    Pytest fixture to run a test case with tilde-prefixed paths.

    In the tilde-path case, environment variables will be temporarily
    modified so that '~' resolves to the temp directory.
    """
    # For Unix
    monkeypatch.setenv("HOME", str(tmp_path))
    # For Windows
    monkeypatch.setenv("USERPROFILE", str(tmp_path))


def create_ccd_data():
    """
    Return a CCDData object of size DEFAULT_DATA_SIZE x DEFAULT_DATA_SIZE
    with units of ADU.
    """
    data = _random_array.copy()
    fake_meta = {"my_key": 42, "your_key": "not 42"}
    ccd = CCDData(data, unit=u.adu)
    ccd.header = fake_meta
    return ccd


def test_ccddata_empty():
    with pytest.raises(TypeError):
        CCDData()  # empty initializer should fail


def test_ccddata_must_have_unit():
    with pytest.raises(ValueError):
        CCDData(np.zeros([2, 2]))


def test_ccddata_unit_cannot_be_set_to_none():
    ccd_data = create_ccd_data()
    with pytest.raises(TypeError):
        ccd_data.unit = None


def test_ccddata_meta_header_conflict():
    with pytest.raises(ValueError, match=".*can't have both header and meta.*"):
        CCDData([1, 2, 3], unit="", meta={1: 1}, header={2: 2})


def test_ccddata_simple():
    ccd_data = create_ccd_data()
    assert ccd_data.shape == (DEFAULT_DATA_SIZE, DEFAULT_DATA_SIZE)
    assert ccd_data.size == DEFAULT_DATA_SIZE * DEFAULT_DATA_SIZE
    assert ccd_data.dtype == np.dtype(float)


def test_ccddata_init_with_string_electron_unit():
    ccd = CCDData(np.zeros([2, 2]), unit="electron")
    assert ccd.unit is u.electron


def test_initialize_from_FITS(tmp_path):
    ccd_data = create_ccd_data()
    hdu = fits.PrimaryHDU(ccd_data)
    hdulist = fits.HDUList([hdu])
    filename = str(tmp_path / "a_file.fits")
    hdulist.writeto(filename)
    cd = CCDData.read(filename, unit=u.electron)
    assert cd.shape == (DEFAULT_DATA_SIZE, DEFAULT_DATA_SIZE)
    assert cd.size == DEFAULT_DATA_SIZE * DEFAULT_DATA_SIZE
    assert np.issubdtype(cd.data.dtype, np.floating)
    for k, v in hdu.header.items():
        assert cd.meta[k] == v


def test_initialize_from_fits_with_unit_in_header(tmp_path):
    fake_img = np.zeros([2, 2])
    hdu = fits.PrimaryHDU(fake_img)
    hdu.header["bunit"] = u.adu.to_string()
    filename = str(tmp_path / "a_file.fits")
    hdu.writeto(filename)
    ccd = CCDData.read(filename)
    # ccd should pick up the unit adu from the fits header...did it?
    assert ccd.unit is u.adu

    # An explicit unit in the read overrides any unit in the FITS file
    ccd2 = CCDData.read(filename, unit="photon")
    assert ccd2.unit is u.photon


def test_initialize_from_fits_with_ADU_in_header(tmp_path):
    fake_img = np.zeros([2, 2])
    hdu = fits.PrimaryHDU(fake_img)
    hdu.header["bunit"] = "ADU"
    filename = str(tmp_path / "a_file.fits")
    hdu.writeto(filename)
    ccd = CCDData.read(filename)
    # ccd should pick up the unit adu from the fits header...did it?
    assert ccd.unit is u.adu


def test_initialize_from_fits_with_invalid_unit_in_header(tmp_path):
    hdu = fits.PrimaryHDU(np.ones((2, 2)))
    hdu.header["bunit"] = "definetely-not-a-unit"
    filename = str(tmp_path / "a_file.fits")
    hdu.writeto(filename)
    with pytest.raises(ValueError):
        CCDData.read(filename)


def test_initialize_from_fits_with_technically_invalid_but_not_really(tmp_path):
    hdu = fits.PrimaryHDU(np.ones((2, 2)))
    hdu.header["bunit"] = "ELECTRONS/S"
    filename = str(tmp_path / "a_file.fits")
    hdu.writeto(filename)
    ccd = CCDData.read(filename)
    assert ccd.unit == u.electron / u.s


def test_initialize_from_fits_with_data_in_different_extension(tmp_path):
    fake_img = np.arange(4).reshape(2, 2)
    hdu1 = fits.PrimaryHDU()
    hdu2 = fits.ImageHDU(fake_img)
    hdus = fits.HDUList([hdu1, hdu2])
    filename = str(tmp_path / "a_file.fits")
    hdus.writeto(filename)
    ccd = CCDData.read(filename, unit="adu")
    # ccd should pick up the unit adu from the fits header...did it?
    np.testing.assert_array_equal(ccd.data, fake_img)
    # check that the header is the combined header
    assert hdu2.header + hdu1.header == ccd.header


def test_initialize_from_fits_with_extension(tmp_path):
    fake_img1 = np.zeros([2, 2])
    fake_img2 = np.arange(4).reshape(2, 2)
    hdu0 = fits.PrimaryHDU()
    hdu1 = fits.ImageHDU(fake_img1, name="first", ver=1)
    hdu2 = fits.ImageHDU(fake_img2, name="second", ver=1)
    hdus = fits.HDUList([hdu0, hdu1, hdu2])
    filename = str(tmp_path / "a_file.fits")
    hdus.writeto(filename)
    ccd = CCDData.read(filename, hdu=2, unit="adu")
    # ccd should pick up the unit adu from the fits header...did it?
    np.testing.assert_array_equal(ccd.data, fake_img2)
    # check hdu string parameter
    ccd = CCDData.read(filename, hdu="second", unit="adu")
    np.testing.assert_array_equal(ccd.data, fake_img2)
    # check hdu tuple parameter
    ccd = CCDData.read(filename, hdu=("second", 1), unit="adu")
    np.testing.assert_array_equal(ccd.data, fake_img2)


def test_write_unit_to_hdu():
    ccd_data = create_ccd_data()
    ccd_unit = ccd_data.unit
    hdulist = ccd_data.to_hdu()
    assert "bunit" in hdulist[0].header
    assert hdulist[0].header["bunit"] == ccd_unit.to_string()


def test_initialize_from_FITS_bad_keyword_raises_error(tmp_path):
    # There are two fits.open keywords that are not permitted in ccdproc:
    #     do_not_scale_image_data and scale_back
    ccd_data = create_ccd_data()
    filename = str(tmp_path / "test.fits")
    ccd_data.write(filename)

    with pytest.raises(TypeError):
        CCDData.read(filename, unit=ccd_data.unit, do_not_scale_image_data=True)
    with pytest.raises(TypeError):
        CCDData.read(filename, unit=ccd_data.unit, scale_back=True)


def test_ccddata_writer(tmp_path):
    ccd_data = create_ccd_data()
    filename = str(tmp_path / "test.fits")
    ccd_data.write(filename)

    ccd_disk = CCDData.read(filename, unit=ccd_data.unit)
    np.testing.assert_array_equal(ccd_data.data, ccd_disk.data)


def test_ccddata_writer_as_imagehdu(tmp_path):
    ccd_data = create_ccd_data()
    filename = str(tmp_path / "test.fits")
    ccd_data.write(filename, as_image_hdu=False)
    with fits.open(filename) as hdus:
        assert len(hdus) == 1

    filename = str(tmp_path / "test2.fits")
    ccd_data.write(filename, as_image_hdu=True)
    with fits.open(filename) as hdus:
        assert len(hdus) == 2
        assert isinstance(hdus[1], fits.ImageHDU)


def test_ccddata_meta_is_case_sensitive():
    ccd_data = create_ccd_data()
    key = "SoMeKEY"
    ccd_data.meta[key] = 10
    assert key.lower() not in ccd_data.meta
    assert key.upper() not in ccd_data.meta
    assert key in ccd_data.meta


def test_ccddata_meta_is_not_fits_header():
    ccd_data = create_ccd_data()
    ccd_data.meta = {"OBSERVER": "Edwin Hubble"}
    assert not isinstance(ccd_data.meta, fits.Header)


def test_fromMEF(tmp_path):
    ccd_data = create_ccd_data()
    hdu = fits.PrimaryHDU(ccd_data)
    hdu2 = fits.PrimaryHDU(2 * ccd_data.data)
    hdulist = fits.HDUList(hdu)
    hdulist.append(hdu2)
    filename = str(tmp_path / "a_file.fits")
    hdulist.writeto(filename)
    # by default, we reading from the first extension
    cd = CCDData.read(filename, unit=u.electron)
    np.testing.assert_array_equal(cd.data, ccd_data.data)
    # but reading from the second should work too
    cd = CCDData.read(filename, hdu=1, unit=u.electron)
    np.testing.assert_array_equal(cd.data, 2 * ccd_data.data)


def test_metafromheader():
    hdr = fits.header.Header()
    hdr["observer"] = "Edwin Hubble"
    hdr["exptime"] = "3600"

    d1 = CCDData(np.ones((5, 5)), meta=hdr, unit=u.electron)
    assert d1.meta["OBSERVER"] == "Edwin Hubble"
    assert d1.header["OBSERVER"] == "Edwin Hubble"


def test_metafromdict():
    dic = {"OBSERVER": "Edwin Hubble", "EXPTIME": 3600}
    d1 = CCDData(np.ones((5, 5)), meta=dic, unit=u.electron)
    assert d1.meta["OBSERVER"] == "Edwin Hubble"


def test_header2meta():
    hdr = fits.header.Header()
    hdr["observer"] = "Edwin Hubble"
    hdr["exptime"] = "3600"

    d1 = CCDData(np.ones((5, 5)), unit=u.electron)
    d1.header = hdr
    assert d1.meta["OBSERVER"] == "Edwin Hubble"
    assert d1.header["OBSERVER"] == "Edwin Hubble"


def test_metafromstring_fail():
    hdr = "this is not a valid header"
    with pytest.raises(TypeError):
        CCDData(np.ones((5, 5)), meta=hdr, unit=u.adu)


def test_setting_bad_uncertainty_raises_error():
    ccd_data = create_ccd_data()
    with pytest.raises(TypeError):
        # Uncertainty is supposed to be an instance of NDUncertainty
        ccd_data.uncertainty = 10


def test_setting_uncertainty_with_array():
    ccd_data = create_ccd_data()
    ccd_data.uncertainty = None
    fake_uncertainty = np.sqrt(np.abs(ccd_data.data))
    ccd_data.uncertainty = fake_uncertainty.copy()
    np.testing.assert_array_equal(ccd_data.uncertainty.array, fake_uncertainty)


def test_setting_uncertainty_wrong_shape_raises_error():
    ccd_data = create_ccd_data()
    with pytest.raises(ValueError):
        ccd_data.uncertainty = np.zeros([3, 4])


def test_to_hdu():
    ccd_data = create_ccd_data()
    ccd_data.meta = {"observer": "Edwin Hubble"}
    fits_hdulist = ccd_data.to_hdu()
    assert isinstance(fits_hdulist, fits.HDUList)
    for k, v in ccd_data.meta.items():
        assert fits_hdulist[0].header[k] == v
    np.testing.assert_array_equal(fits_hdulist[0].data, ccd_data.data)


def test_to_hdu_as_imagehdu():
    ccd_data = create_ccd_data()
    fits_hdulist = ccd_data.to_hdu(as_image_hdu=False)
    assert isinstance(fits_hdulist[0], fits.PrimaryHDU)
    fits_hdulist = ccd_data.to_hdu(as_image_hdu=True)
    assert isinstance(fits_hdulist[0], fits.ImageHDU)


def test_copy():
    ccd_data = create_ccd_data()
    ccd_copy = ccd_data.copy()
    np.testing.assert_array_equal(ccd_copy.data, ccd_data.data)
    assert ccd_copy.unit == ccd_data.unit
    assert ccd_copy.meta == ccd_data.meta


@pytest.mark.parametrize(
    "operation,affects_uncertainty",
    [
        ("multiply", True),
        ("divide", True),
    ],
)
@pytest.mark.parametrize(
    "operand",
    [
        2.0,
        2 * u.dimensionless_unscaled,
        2 * u.photon / u.adu,
    ],
)
@pytest.mark.parametrize("with_uncertainty", [True, False])
def test_mult_div_overload(operand, with_uncertainty, operation, affects_uncertainty):
    ccd_data = create_ccd_data()
    if with_uncertainty:
        ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data))
    method = getattr(ccd_data, operation)
    np_method = getattr(np, operation)
    result = method(operand)
    assert result is not ccd_data
    assert isinstance(result, CCDData)
    assert result.uncertainty is None or isinstance(
        result.uncertainty, StdDevUncertainty
    )
    try:
        op_value = operand.value
    except AttributeError:
        op_value = operand

    np.testing.assert_array_equal(result.data, np_method(ccd_data.data, op_value))
    if with_uncertainty:
        if affects_uncertainty:
            np.testing.assert_array_equal(
                result.uncertainty.array,
                np_method(ccd_data.uncertainty.array, op_value),
            )
        else:
            np.testing.assert_array_equal(
                result.uncertainty.array, ccd_data.uncertainty.array
            )
    else:
        assert result.uncertainty is None

    if isinstance(operand, u.Quantity):
        # Need the "1 *" below to force arguments to be Quantity to work around
        # astropy/astropy#2377
        expected_unit = np_method(1 * ccd_data.unit, 1 * operand.unit).unit
        assert result.unit == expected_unit
    else:
        assert result.unit == ccd_data.unit


@pytest.mark.parametrize(
    "operation,affects_uncertainty",
    [
        ("add", False),
        ("subtract", False),
    ],
)
@pytest.mark.parametrize(
    "operand,expect_failure",
    [
        (2.0, u.UnitsError),  # fail--units don't match image
        (2 * u.dimensionless_unscaled, u.UnitsError),  # same
        (2 * u.adu, False),
    ],
)
@pytest.mark.parametrize("with_uncertainty", [True, False])
def test_add_sub_overload(
    operand, expect_failure, with_uncertainty, operation, affects_uncertainty
):
    ccd_data = create_ccd_data()
    if with_uncertainty:
        ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data))
    method = getattr(ccd_data, operation)
    np_method = getattr(np, operation)
    if expect_failure:
        with pytest.raises(expect_failure):
            result = method(operand)
        return
    else:
        result = method(operand)
    assert result is not ccd_data
    assert isinstance(result, CCDData)
    assert result.uncertainty is None or isinstance(
        result.uncertainty, StdDevUncertainty
    )
    try:
        op_value = operand.value
    except AttributeError:
        op_value = operand

    np.testing.assert_array_equal(result.data, np_method(ccd_data.data, op_value))
    if with_uncertainty:
        if affects_uncertainty:
            np.testing.assert_array_equal(
                result.uncertainty.array,
                np_method(ccd_data.uncertainty.array, op_value),
            )
        else:
            np.testing.assert_array_equal(
                result.uncertainty.array, ccd_data.uncertainty.array
            )
    else:
        assert result.uncertainty is None

    if isinstance(operand, u.Quantity):
        assert result.unit == ccd_data.unit and result.unit == operand.unit
    else:
        assert result.unit == ccd_data.unit


def test_arithmetic_overload_fails():
    ccd_data = create_ccd_data()
    with pytest.raises(TypeError):
        ccd_data.multiply("five")

    with pytest.raises(TypeError):
        ccd_data.divide("five")

    with pytest.raises(TypeError):
        ccd_data.add("five")

    with pytest.raises(TypeError):
        ccd_data.subtract("five")


def test_arithmetic_no_wcs_compare():
    ccd = CCDData(np.ones((10, 10)), unit="")
    assert ccd.add(ccd, compare_wcs=None).wcs is None
    assert ccd.subtract(ccd, compare_wcs=None).wcs is None
    assert ccd.multiply(ccd, compare_wcs=None).wcs is None
    assert ccd.divide(ccd, compare_wcs=None).wcs is None


def test_arithmetic_with_wcs_compare():
    def return_true(_, __):
        return True

    wcs1, wcs2 = nd_testing.create_two_equal_wcs(naxis=2)
    ccd1 = CCDData(np.ones((10, 10)), unit="", wcs=wcs1)
    ccd2 = CCDData(np.ones((10, 10)), unit="", wcs=wcs2)
    nd_testing.assert_wcs_seem_equal(ccd1.add(ccd2, compare_wcs=return_true).wcs, wcs1)
    nd_testing.assert_wcs_seem_equal(
        ccd1.subtract(ccd2, compare_wcs=return_true).wcs, wcs1
    )
    nd_testing.assert_wcs_seem_equal(
        ccd1.multiply(ccd2, compare_wcs=return_true).wcs, wcs1
    )
    nd_testing.assert_wcs_seem_equal(
        ccd1.divide(ccd2, compare_wcs=return_true).wcs, wcs1
    )


def test_arithmetic_with_wcs_compare_fail():
    def return_false(_, __):
        return False

    ccd1 = CCDData(np.ones((10, 10)), unit="", wcs=WCS())
    ccd2 = CCDData(np.ones((10, 10)), unit="", wcs=WCS())
    with pytest.raises(ValueError):
        ccd1.add(ccd2, compare_wcs=return_false)
    with pytest.raises(ValueError):
        ccd1.subtract(ccd2, compare_wcs=return_false)
    with pytest.raises(ValueError):
        ccd1.multiply(ccd2, compare_wcs=return_false)
    with pytest.raises(ValueError):
        ccd1.divide(ccd2, compare_wcs=return_false)


def test_arithmetic_overload_ccddata_operand():
    ccd_data = create_ccd_data()
    ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data))
    operand = ccd_data.copy()
    result = ccd_data.add(operand)
    assert len(result.meta) == 0
    np.testing.assert_array_equal(result.data, 2 * ccd_data.data)
    np.testing.assert_array_almost_equal_nulp(
        result.uncertainty.array, np.sqrt(2) * ccd_data.uncertainty.array
    )

    result = ccd_data.subtract(operand)
    assert len(result.meta) == 0
    np.testing.assert_array_equal(result.data, 0 * ccd_data.data)
    np.testing.assert_array_almost_equal_nulp(
        result.uncertainty.array, np.sqrt(2) * ccd_data.uncertainty.array
    )

    result = ccd_data.multiply(operand)
    assert len(result.meta) == 0
    np.testing.assert_array_equal(result.data, ccd_data.data**2)
    expected_uncertainty = (
        np.sqrt(2) * np.abs(ccd_data.data) * ccd_data.uncertainty.array
    )
    np.testing.assert_allclose(result.uncertainty.array, expected_uncertainty)

    result = ccd_data.divide(operand)
    assert len(result.meta) == 0
    np.testing.assert_array_equal(result.data, np.ones_like(ccd_data.data))
    expected_uncertainty = (
        np.sqrt(2) / np.abs(ccd_data.data) * ccd_data.uncertainty.array
    )
    np.testing.assert_allclose(result.uncertainty.array, expected_uncertainty)


def test_arithmetic_overload_differing_units():
    a = np.array([1, 2, 3]) * u.m
    b = np.array([1, 2, 3]) * u.cm
    ccddata = CCDData(a)

    # TODO: Could also be parametrized.
    res = ccddata.add(b)
    np.testing.assert_array_almost_equal(res.data, np.add(a, b).value)
    assert res.unit == np.add(a, b).unit

    res = ccddata.subtract(b)
    np.testing.assert_array_almost_equal(res.data, np.subtract(a, b).value)
    assert res.unit == np.subtract(a, b).unit

    res = ccddata.multiply(b)
    np.testing.assert_array_almost_equal(res.data, np.multiply(a, b).value)
    assert res.unit == np.multiply(a, b).unit

    res = ccddata.divide(b)
    np.testing.assert_array_almost_equal(res.data, np.divide(a, b).value)
    assert res.unit == np.divide(a, b).unit


def test_arithmetic_add_with_array():
    ccd = CCDData(np.ones((3, 3)), unit="")
    res = ccd.add(np.arange(3))
    np.testing.assert_array_equal(res.data, [[1, 2, 3]] * 3)

    ccd = CCDData(np.ones((3, 3)), unit="adu")
    with pytest.raises(ValueError):
        ccd.add(np.arange(3))


def test_arithmetic_subtract_with_array():
    ccd = CCDData(np.ones((3, 3)), unit="")
    res = ccd.subtract(np.arange(3))
    np.testing.assert_array_equal(res.data, [[1, 0, -1]] * 3)

    ccd = CCDData(np.ones((3, 3)), unit="adu")
    with pytest.raises(ValueError):
        ccd.subtract(np.arange(3))


def test_arithmetic_multiply_with_array():
    ccd = CCDData(np.ones((3, 3)) * 3, unit=u.m)
    res = ccd.multiply(np.ones((3, 3)) * 2)
    np.testing.assert_array_equal(res.data, [[6, 6, 6]] * 3)
    assert res.unit == ccd.unit


def test_arithmetic_divide_with_array():
    ccd = CCDData(np.ones((3, 3)), unit=u.m)
    res = ccd.divide(np.ones((3, 3)) * 2)
    np.testing.assert_array_equal(res.data, [[0.5, 0.5, 0.5]] * 3)
    assert res.unit == ccd.unit


def test_history_preserved_if_metadata_is_fits_header(tmp_path):
    fake_img = np.zeros([2, 2])
    hdu = fits.PrimaryHDU(fake_img)
    hdu.header["history"] = "one"
    hdu.header["history"] = "two"
    hdu.header["history"] = "three"
    assert len(hdu.header["history"]) == 3
    tmp_file = str(tmp_path / "temp.fits")
    hdu.writeto(tmp_file)

    ccd_read = CCDData.read(tmp_file, unit="adu")
    assert ccd_read.header["history"] == hdu.header["history"]


def test_infol_logged_if_unit_in_fits_header(tmp_path):
    ccd_data = create_ccd_data()
    tmpfile = str(tmp_path / "temp.fits")
    ccd_data.write(tmpfile)
    log.setLevel("INFO")
    explicit_unit_name = "photon"
    with log.log_to_list() as log_list:
        _ = CCDData.read(tmpfile, unit=explicit_unit_name)
        assert explicit_unit_name in log_list[0].message


def test_wcs_attribute(tmp_path):
    """
    Check that WCS attribute gets added to header, and that if a CCDData
    object is created from a FITS file with a header, and the WCS attribute
    is modified, then the CCDData object is turned back into an hdu, the
    WCS object overwrites the old WCS information in the header.
    """
    ccd_data = create_ccd_data()
    tmpfile = str(tmp_path / "temp.fits")
    # This wcs example is taken from the astropy.wcs docs.
    wcs = WCS(naxis=2)
    wcs.wcs.crpix = np.array(ccd_data.shape) / 2
    wcs.wcs.cdelt = np.array([-0.066667, 0.066667])
    wcs.wcs.crval = [0, -90]
    wcs.wcs.ctype = ["RA---AIR", "DEC--AIR"]
    wcs.wcs.set_pv([(2, 1, 45.0)])
    ccd_data.header = ccd_data.to_hdu()[0].header
    ccd_data.header.extend(wcs.to_header(), useblanks=False)
    ccd_data.write(tmpfile)

    # Get the header length after it has been extended by the WCS keywords
    original_header_length = len(ccd_data.header)

    ccd_new = CCDData.read(tmpfile)
    # WCS attribute should be set for ccd_new
    assert ccd_new.wcs is not None
    # WCS attribute should be equal to wcs above.
    assert ccd_new.wcs.wcs == wcs.wcs

    # Converting CCDData object with wcs to an hdu shouldn't
    # create duplicate wcs-related entries in the header.
    ccd_new_hdu = ccd_new.to_hdu()[0]
    assert len(ccd_new_hdu.header) == original_header_length

    # Making a CCDData with WCS (but not WCS in the header) should lead to
    # WCS information in the header when it is converted to an HDU.
    ccd_wcs_not_in_header = CCDData(ccd_data.data, wcs=wcs, unit="adu")
    hdu = ccd_wcs_not_in_header.to_hdu()[0]
    wcs_header = wcs.to_header()
    for k in wcs_header.keys():
        # Skip these keywords if they are in the WCS header because they are
        # not WCS-specific.
        if k in ["", "COMMENT", "HISTORY"]:
            continue
        # No keyword from the WCS should be in the header.
        assert k not in ccd_wcs_not_in_header.header
        # Every keyword in the WCS should be in the header of the HDU
        assert hdu.header[k] == wcs_header[k]

    # Now check that if WCS of a CCDData is modified, then the CCDData is
    # converted to an HDU, the WCS keywords in the header are overwritten
    # with the appropriate keywords from the header.
    #
    # ccd_new has a WCS and WCS keywords in the header, so try modifying
    # the WCS.
    ccd_new.wcs.wcs.cdelt *= 2
    ccd_new_hdu_mod_wcs = ccd_new.to_hdu()[0]
    assert ccd_new_hdu_mod_wcs.header["CDELT1"] == ccd_new.wcs.wcs.cdelt[0]
    assert ccd_new_hdu_mod_wcs.header["CDELT2"] == ccd_new.wcs.wcs.cdelt[1]


def test_wcs_keywords_removed_from_header():
    """
    Test, for the file included with the nddata tests, that WCS keywords are
    properly removed from header.
    """
    from astropy.nddata.ccddata import _KEEP_THESE_KEYWORDS_IN_HEADER

    keepers = set(_KEEP_THESE_KEYWORDS_IN_HEADER)
    data_file = get_pkg_data_filename("data/sip-wcs.fits")
    ccd = CCDData.read(data_file)
    with pytest.warns(
        AstropyWarning, match=r"Some non-standard WCS keywords were excluded"
    ):
        wcs_header = ccd.wcs.to_header()
    assert not (set(wcs_header) & set(ccd.meta) - keepers)

    # Make sure that exceptions are not raised when trying to remove missing
    # keywords. o4sp040b0_raw.fits of io.fits is missing keyword 'PC1_1'.
    data_file1 = get_pkg_data_filename(
        "data/o4sp040b0_raw.fits", package="astropy.io.fits.tests"
    )
    if PYTEST_LT_8_0:
        ctx = nullcontext()
    else:
        ctx = pytest.warns(FITSFixedWarning, match="'datfix' made the change")

    with pytest.warns(FITSFixedWarning, match="'unitfix' made the change"), ctx:
        ccd = CCDData.read(data_file1, unit="count")


def test_wcs_SIP_coefficient_keywords_removed():
    # If SIP polynomials are present, check that no more polynomial
    # coefficients remain in the header. See #8598

    # The SIP paper is ambiguous as to whether keywords like
    # A_0_0 can appear in the header for a 2nd order or higher
    # polynomial. The paper clearly says that the corrections
    # are only for quadratic or higher order, so A_0_0 and the like
    # should be zero if they are present, but they apparently can be
    # there (or at least astrometry.net produces them).

    # astropy WCS does not write those coefficients, so they were
    # not being removed from the header even though they are WCS-related.

    data_file = get_pkg_data_filename("data/sip-wcs.fits")
    test_keys = ["A_0_0", "B_0_1"]

    # Make sure the keywords added to this file for testing are there
    with fits.open(data_file) as hdu:
        for key in test_keys:
            assert key in hdu[0].header

    ccd = CCDData.read(data_file)

    # Now the test...the two keywords above should have been removed.
    for key in test_keys:
        assert key not in ccd.header


@pytest.mark.filterwarnings("ignore")
def test_wcs_keyword_removal_for_wcs_test_files():
    """
    Test, for the WCS test files, that keyword removal works as
    expected. Those cover a much broader range of WCS types than
    test_wcs_keywords_removed_from_header.

    Includes regression test for #8597
    """
    from astropy.nddata.ccddata import (
        _KEEP_THESE_KEYWORDS_IN_HEADER,
        _CDs,
        _generate_wcs_and_update_header,
        _PCs,
    )

    keepers = set(_KEEP_THESE_KEYWORDS_IN_HEADER)
    # NOTE: pyinstaller requires relative path here.
    wcs_headers = get_pkg_data_filenames("../../wcs/tests/data", pattern="*.hdr")

    for hdr in wcs_headers:
        # Skip the files that are expected to be bad...
        if (
            "invalid" in hdr
            or "nonstandard" in hdr
            or "segfault" in hdr
            or "chandra-pixlist-wcs" in hdr
        ):
            continue
        header_string = get_pkg_data_contents(hdr)
        header = fits.Header.fromstring(header_string)

        wcs = WCS(header_string)
        header_from_wcs = wcs.to_header(relax=True)

        new_header, new_wcs = _generate_wcs_and_update_header(header)
        new_wcs_header = new_wcs.to_header(relax=True)

        # Make sure all of the WCS-related keywords generated by astropy
        # have been removed.
        assert not (set(new_header) & set(new_wcs_header) - keepers)

        # Check that new_header contains no remaining WCS information.
        # Specifically, check that
        # 1. The combination of new_header and new_wcs does not contain
        #    both PCi_j and CDi_j keywords. See #8597.

        # Check for 1
        final_header = new_header + new_wcs_header
        final_header_set = set(final_header)

        if _PCs & final_header_set:
            assert not (_CDs & final_header_set)
        elif _CDs & final_header_set:
            assert not (_PCs & final_header_set)

        # Check that the new wcs is the same as the old.
        for k, v in new_wcs_header.items():
            if isinstance(v, str):
                assert header_from_wcs[k] == v
            else:
                np.testing.assert_almost_equal(header_from_wcs[k], v)


def test_read_wcs_not_creatable(tmp_path):
    # The following Header can't be converted to a WCS object. See also #6499.
    hdr_txt_example_WCS = textwrap.dedent(
        """
    SIMPLE  =                    T / Fits standard
    BITPIX  =                   16 / Bits per pixel
    NAXIS   =                    2 / Number of axes
    NAXIS1  =                 1104 / Axis length
    NAXIS2  =                 4241 / Axis length
    CRVAL1  =         164.98110962 / Physical value of the reference pixel X
    CRVAL2  =          44.34089279 / Physical value of the reference pixel Y
    CRPIX1  =                -34.0 / Reference pixel in X (pixel)
    CRPIX2  =               2041.0 / Reference pixel in Y (pixel)
    CDELT1  =           0.10380000 / X Scale projected on detector (#/pix)
    CDELT2  =           0.10380000 / Y Scale projected on detector (#/pix)
    CTYPE1  = 'RA---TAN'           / Pixel coordinate system
    CTYPE2  = 'WAVELENGTH'         / Pixel coordinate system
    CUNIT1  = 'degree  '           / Units used in both CRVAL1 and CDELT1
    CUNIT2  = 'nm      '           / Units used in both CRVAL2 and CDELT2
    CD1_1   =           0.20760000 / Pixel Coordinate translation matrix
    CD1_2   =           0.00000000 / Pixel Coordinate translation matrix
    CD2_1   =           0.00000000 / Pixel Coordinate translation matrix
    CD2_2   =           0.10380000 / Pixel Coordinate translation matrix
    C2YPE1  = 'RA---TAN'           / Pixel coordinate system
    C2YPE2  = 'DEC--TAN'           / Pixel coordinate system
    C2NIT1  = 'degree  '           / Units used in both C2VAL1 and C2ELT1
    C2NIT2  = 'degree  '           / Units used in both C2VAL2 and C2ELT2
    RADECSYS= 'FK5     '           / The equatorial coordinate system
    """
    )
    hdr = fits.Header.fromstring(hdr_txt_example_WCS, sep="\n")
    hdul = fits.HDUList([fits.PrimaryHDU(np.ones((4241, 1104)), header=hdr)])
    filename = str(tmp_path / "a_file.fits")
    hdul.writeto(filename)
    # The hdr cannot be converted to a WCS object because of an
    # InconsistentAxisTypesError but it should still open the file
    ccd = CCDData.read(filename, unit="adu")
    assert ccd.wcs is None


def test_header():
    ccd_data = create_ccd_data()
    a = {"Observer": "Hubble"}
    ccd = CCDData(ccd_data, header=a)
    assert ccd.meta == a


def test_wcs_arithmetic():
    ccd_data = create_ccd_data()
    wcs = WCS(naxis=2)
    ccd_data.wcs = wcs
    result = ccd_data.multiply(1.0)
    nd_testing.assert_wcs_seem_equal(result.wcs, wcs)


@pytest.mark.parametrize("operation", ["multiply", "divide", "add", "subtract"])
def test_wcs_arithmetic_ccd(operation):
    ccd_data = create_ccd_data()
    ccd_data2 = ccd_data.copy()
    ccd_data.wcs = WCS(naxis=2)
    method = getattr(ccd_data, operation)
    result = method(ccd_data2)
    nd_testing.assert_wcs_seem_equal(result.wcs, ccd_data.wcs)
    assert ccd_data2.wcs is None


def test_wcs_sip_handling():
    """
    Check whether the ctypes RA---TAN-SIP and DEC--TAN-SIP survive
    a roundtrip unchanged.
    """
    data_file = get_pkg_data_filename("data/sip-wcs.fits")

    def check_wcs_ctypes(header):
        expected_wcs_ctypes = {"CTYPE1": "RA---TAN-SIP", "CTYPE2": "DEC--TAN-SIP"}

        return [header[k] == v for k, v in expected_wcs_ctypes.items()]

    ccd_original = CCDData.read(data_file)
    # After initialization the keywords should be in the WCS, not in the
    # meta.
    with fits.open(data_file) as raw:
        good_ctype = check_wcs_ctypes(raw[0].header)
    assert all(good_ctype)

    ccd_new = ccd_original.to_hdu()
    good_ctype = check_wcs_ctypes(ccd_new[0].header)
    assert all(good_ctype)

    # Try converting to header with wcs_relax=False and
    # the header should contain the CTYPE keywords without
    # the -SIP

    ccd_no_relax = ccd_original.to_hdu(wcs_relax=False)
    good_ctype = check_wcs_ctypes(ccd_no_relax[0].header)
    assert not any(good_ctype)
    assert ccd_no_relax[0].header["CTYPE1"] == "RA---TAN"
    assert ccd_no_relax[0].header["CTYPE2"] == "DEC--TAN"


@pytest.mark.parametrize("operation", ["multiply", "divide", "add", "subtract"])
def test_mask_arithmetic_ccd(operation):
    ccd_data = create_ccd_data()
    ccd_data2 = ccd_data.copy()
    ccd_data.mask = ccd_data.data > 0
    method = getattr(ccd_data, operation)
    result = method(ccd_data2)
    np.testing.assert_equal(result.mask, ccd_data.mask)


def test_write_read_multiextensionfits_mask_default(tmp_path):
    # Test that if a mask is present the mask is saved and loaded by default.
    ccd_data = create_ccd_data()
    ccd_data.mask = ccd_data.data > 10
    filename = str(tmp_path / "a_file.fits")
    ccd_data.write(filename)
    ccd_after = CCDData.read(filename)
    assert ccd_after.mask is not None
    np.testing.assert_array_equal(ccd_data.mask, ccd_after.mask)


@pytest.mark.parametrize(
    "uncertainty_type", [StdDevUncertainty, VarianceUncertainty, InverseVariance]
)
def test_write_read_multiextensionfits_uncertainty_default(tmp_path, uncertainty_type):
    # Test that if a uncertainty is present it is saved and loaded by default.
    ccd_data = create_ccd_data()
    ccd_data.uncertainty = uncertainty_type(ccd_data.data * 10)
    filename = str(tmp_path / "a_file.fits")
    ccd_data.write(filename)
    ccd_after = CCDData.read(filename)
    assert ccd_after.uncertainty is not None
    assert type(ccd_after.uncertainty) is uncertainty_type
    np.testing.assert_array_equal(
        ccd_data.uncertainty.array, ccd_after.uncertainty.array
    )


@pytest.mark.parametrize(
    "uncertainty_type", [StdDevUncertainty, VarianceUncertainty, InverseVariance]
)
def test_write_read_multiextensionfits_uncertainty_different_uncertainty_key(
    tmp_path, uncertainty_type
):
    # Test that if a uncertainty is present it is saved and loaded by default.
    ccd_data = create_ccd_data()
    ccd_data.uncertainty = uncertainty_type(ccd_data.data * 10)
    filename = str(tmp_path / "a_file.fits")
    ccd_data.write(filename, key_uncertainty_type="Blah")
    ccd_after = CCDData.read(filename, key_uncertainty_type="Blah")
    assert ccd_after.uncertainty is not None
    assert type(ccd_after.uncertainty) is uncertainty_type
    np.testing.assert_array_equal(
        ccd_data.uncertainty.array, ccd_after.uncertainty.array
    )


def test_write_read_multiextensionfits_not(tmp_path):
    # Test that writing mask and uncertainty can be disabled
    ccd_data = create_ccd_data()
    ccd_data.mask = ccd_data.data > 10
    ccd_data.uncertainty = StdDevUncertainty(ccd_data.data * 10)
    filename = str(tmp_path / "a_file.fits")
    ccd_data.write(filename, hdu_mask=None, hdu_uncertainty=None)
    ccd_after = CCDData.read(filename)
    assert ccd_after.uncertainty is None
    assert ccd_after.mask is None


def test_write_read_multiextensionfits_custom_ext_names(tmp_path):
    # Test writing mask, uncertainty in another extension than default
    ccd_data = create_ccd_data()
    ccd_data.mask = ccd_data.data > 10
    ccd_data.uncertainty = StdDevUncertainty(ccd_data.data * 10)
    filename = str(tmp_path / "a_file.fits")
    ccd_data.write(filename, hdu_mask="Fun", hdu_uncertainty="NoFun")

    # Try reading with defaults extension names
    ccd_after = CCDData.read(filename)
    assert ccd_after.uncertainty is None
    assert ccd_after.mask is None

    # Try reading with custom extension names
    ccd_after = CCDData.read(filename, hdu_mask="Fun", hdu_uncertainty="NoFun")
    assert ccd_after.uncertainty is not None
    assert ccd_after.mask is not None
    np.testing.assert_array_equal(ccd_data.mask, ccd_after.mask)
    np.testing.assert_array_equal(
        ccd_data.uncertainty.array, ccd_after.uncertainty.array
    )


def test_read_old_style_multiextensionfits(tmp_path):
    # Regression test for https://github.com/astropy/ccdproc/issues/664
    #
    # Prior to astropy 3.1 there was no uncertainty type saved
    # in the multiextension fits files generated by CCDData
    # because the uncertainty had to be StandardDevUncertainty.
    #
    # Current version should be able to read those in.
    #
    size = 4
    # Value of the variables below are not important to the test.
    data = np.zeros([size, size])
    mask = data > 0.9
    uncert = np.sqrt(data)

    ccd = CCDData(data=data, mask=mask, uncertainty=uncert, unit="adu")
    # We'll create the file manually to ensure we have the
    # right extension names and no uncertainty type.
    hdulist = ccd.to_hdu()
    del hdulist[2].header["UTYPE"]
    file_name = str(tmp_path / "old_ccddata_mef.fits")
    hdulist.writeto(file_name)

    ccd = CCDData.read(file_name)

    assert isinstance(ccd.uncertainty, StdDevUncertainty)


def test_wcs():
    ccd_data = create_ccd_data()
    wcs = WCS(naxis=2)
    ccd_data.wcs = wcs
    assert ccd_data.wcs is wcs


def test_recognized_fits_formats_for_read_write(tmp_path):
    # These are the extensions that are supposed to be supported.
    ccd_data = create_ccd_data()
    supported_extensions = ["fit", "fits", "fts"]

    for ext in supported_extensions:
        path = str(tmp_path / f"test.{ext}")
        ccd_data.write(path)
        from_disk = CCDData.read(path)
        assert (ccd_data.data == from_disk.data).all()


def test_stddevuncertainty_compat_descriptor_no_parent():
    with pytest.raises(MissingDataAssociationException):
        StdDevUncertainty(np.ones((10, 10))).parent_nddata


def test_stddevuncertainty_compat_descriptor_no_weakref():
    # TODO: Remove this test if astropy 1.0 isn't supported anymore
    # This test might create a Memoryleak on purpose, so the last lines after
    # the assert are IMPORTANT cleanup.
    ccd = CCDData(np.ones((10, 10)), unit="")
    uncert = StdDevUncertainty(np.ones((10, 10)))
    uncert._parent_nddata = ccd
    assert uncert.parent_nddata is ccd
    uncert._parent_nddata = None


# https://github.com/astropy/astropy/issues/7595
def test_read_returns_image(tmp_path):
    # Test if CCData.read returns a image when reading a fits file containing
    # a table and image, in that order.
    tbl = Table(np.ones(10).reshape(5, 2))
    img = np.ones((5, 5))
    hdul = fits.HDUList(
        hdus=[fits.PrimaryHDU(), fits.TableHDU(tbl.as_array()), fits.ImageHDU(img)]
    )
    filename = str(tmp_path / "table_image.fits")
    hdul.writeto(filename)
    ccd = CCDData.read(filename, unit="adu")
    # Expecting to get (5, 5), the size of the image
    assert ccd.data.shape == (5, 5)


# https://github.com/astropy/astropy/issues/9664
def test_sliced_ccdata_to_hdu():
    wcs = WCS(naxis=2)
    wcs.wcs.crpix = 10, 10
    ccd = CCDData(np.ones((10, 10)), wcs=wcs, unit="pixel")
    trimmed = ccd[2:-2, 2:-2]
    hdul = trimmed.to_hdu()
    assert isinstance(hdul, fits.HDUList)
    assert hdul[0].header["CRPIX1"] == 8
    assert hdul[0].header["CRPIX2"] == 8


def test_read_write_tilde_paths(home_is_tmpdir):
    # Test for reading and writing to tilde-prefixed paths without errors
    ccd_data = create_ccd_data()
    filename = os.path.join("~", "test.fits")
    ccd_data.write(filename)

    ccd_disk = CCDData.read(filename, unit=ccd_data.unit)
    np.testing.assert_array_equal(ccd_data.data, ccd_disk.data)

    # Ensure the unexpanded path doesn't exist (e.g. no directory whose name is
    # a literal ~ was created)
    assert not os.path.exists(filename)


def test_ccddata_with_psf():
    psf = _random_psf.copy()
    ccd = CCDData(_random_array.copy(), unit=u.adu, psf=psf)
    assert (ccd.psf == psf).all()

    # cannot pass in non-ndarray
    with pytest.raises(TypeError, match="The psf must be a numpy array."):
        CCDData(_random_array.copy(), unit=u.adu, psf="something")


def test_psf_setter():
    psf = _random_psf.copy()
    ccd = CCDData(_random_array.copy(), unit=u.adu)
    ccd.psf = psf
    assert (ccd.psf == psf).all()

    # cannot set with non-ndarray
    with pytest.raises(TypeError, match="The psf must be a numpy array."):
        ccd.psf = 5


def test_write_read_psf(tmp_path):
    """Test that we can round-trip a CCDData with an attached PSF image."""
    ccd_data = create_ccd_data()
    ccd_data.psf = _random_psf

    filename = tmp_path / "test_write_read_psf.fits"
    ccd_data.write(filename)
    ccd_disk = CCDData.read(filename)
    np.testing.assert_array_equal(ccd_data.data, ccd_disk.data)
    np.testing.assert_array_equal(ccd_data.psf, ccd_disk.psf)

    # Try a different name for the PSF HDU.
    filename = tmp_path / "test_write_read_psf_hdu.fits"
    ccd_data.write(filename, hdu_psf="PSFOTHER")
    # psf will be None if we don't supply the new HDU name to the reader.
    ccd_disk = CCDData.read(filename)
    np.testing.assert_array_equal(ccd_data.data, ccd_disk.data)
    assert ccd_disk.psf is None
    # psf will round-trip if we do supply the new HDU name.
    ccd_disk = CCDData.read(filename, hdu_psf="PSFOTHER")
    np.testing.assert_array_equal(ccd_data.data, ccd_disk.data)
    np.testing.assert_array_equal(ccd_data.psf, ccd_disk.psf)
