# Licensed under a 3-clause BSD style license - see PYFITS.rst

import sys

import numpy as np

from astropy.io.fits.column import FITS2NUMPY, ColDefs, Column
from astropy.io.fits.fitsrec import FITS_rec, FITS_record
from astropy.io.fits.util import _is_int, _is_pseudo_integer, _pseudo_zero
from astropy.utils import lazyproperty

from .base import DELAYED, DTYPE2BITPIX
from .image import PrimaryHDU
from .table import _TableLikeHDU


class Group(FITS_record):
    """
    One group of the random group data.
    """

    def __init__(self, input, row=0, start=None, end=None, step=None, base=None):
        super().__init__(input, row, start, end, step, base)

    @property
    def parnames(self):
        return self.array.parnames

    @property
    def data(self):
        # The last column in the coldefs is the data portion of the group
        return self.field(self.array._coldefs.names[-1])

    @lazyproperty
    def _unique(self):
        return _par_indices(self.parnames)

    def par(self, parname):
        """
        Get the group parameter value.
        """
        if _is_int(parname):
            result = self.array[self.row][parname]
        else:
            indx = self._unique[parname.upper()]
            if len(indx) == 1:
                result = self.array[self.row][indx[0]]

            # if more than one group parameter have the same name
            else:
                result = self.array[self.row][indx[0]].astype("f8")
                for i in indx[1:]:
                    result += self.array[self.row][i]

        return result

    def setpar(self, parname, value):
        """
        Set the group parameter value.
        """
        # TODO: It would be nice if, instead of requiring a multi-part value to
        # be an array, there were an *option* to automatically split the value
        # into multiple columns if it doesn't already fit in the array data
        # type.

        if _is_int(parname):
            self.array[self.row][parname] = value
        else:
            indx = self._unique[parname.upper()]
            if len(indx) == 1:
                self.array[self.row][indx[0]] = value

            # if more than one group parameter have the same name, the
            # value must be a list (or tuple) containing arrays
            else:
                if isinstance(value, (list, tuple)) and len(indx) == len(value):
                    for i in range(len(indx)):
                        self.array[self.row][indx[i]] = value[i]
                else:
                    raise ValueError(
                        "Parameter value must be a sequence with "
                        f"{len(indx)} arrays/numbers."
                    )


class GroupData(FITS_rec):
    """
    Random groups data object.

    Allows structured access to FITS Group data in a manner analogous
    to tables.
    """

    _record_type = Group

    def __new__(
        cls,
        input=None,
        bitpix=None,
        pardata=None,
        parnames=[],
        bscale=None,
        bzero=None,
        parbscales=None,
        parbzeros=None,
    ):
        """
        Parameters
        ----------
        input : array or FITS_rec instance
            input data, either the group data itself (a
            `numpy.ndarray`) or a record array (`FITS_rec`) which will
            contain both group parameter info and the data.  The rest
            of the arguments are used only for the first case.

        bitpix : int
            data type as expressed in FITS ``BITPIX`` value (8, 16, 32,
            64, -32, or -64)

        pardata : sequence of array
            parameter data, as a list of (numeric) arrays.

        parnames : sequence of str
            list of parameter names.

        bscale : int
            ``BSCALE`` of the data

        bzero : int
            ``BZERO`` of the data

        parbscales : sequence of int
            list of bscales for the parameters

        parbzeros : sequence of int
            list of bzeros for the parameters
        """
        if not isinstance(input, FITS_rec):
            if pardata is None:
                npars = 0
            else:
                npars = len(pardata)

            if parbscales is None:
                parbscales = [None] * npars
            if parbzeros is None:
                parbzeros = [None] * npars

            if parnames is None:
                parnames = [f"PAR{idx + 1}" for idx in range(npars)]

            if len(parnames) != npars:
                raise ValueError(
                    "The number of parameter data arrays does "
                    "not match the number of parameters."
                )

            unique_parnames = _unique_parnames(parnames + ["DATA"])

            if bitpix is None:
                bitpix = DTYPE2BITPIX[input.dtype.name]

            fits_fmt = GroupsHDU._bitpix2tform[bitpix]  # -32 -> 'E'
            format = FITS2NUMPY[fits_fmt]  # 'E' -> 'f4'
            data_fmt = f"{input.shape[1:]}{format}"
            formats = ",".join(([format] * npars) + [data_fmt])
            gcount = input.shape[0]

            cols = [
                Column(
                    name=unique_parnames[idx],
                    format=fits_fmt,
                    bscale=parbscales[idx],
                    bzero=parbzeros[idx],
                )
                for idx in range(npars)
            ]
            cols.append(
                Column(
                    name=unique_parnames[-1],
                    format=fits_fmt,
                    bscale=bscale,
                    bzero=bzero,
                )
            )

            coldefs = ColDefs(cols)

            self = FITS_rec.__new__(
                cls,
                np.rec.array(None, formats=formats, names=coldefs.names, shape=gcount),
            )

            # By default the data field will just be 'DATA', but it may be
            # uniquified if 'DATA' is already used by one of the group names
            self._data_field = unique_parnames[-1]

            self._coldefs = coldefs
            self.parnames = parnames

            for idx, name in enumerate(unique_parnames[:-1]):
                column = coldefs[idx]
                # Note: _get_scale_factors is used here and in other cases
                # below to determine whether the column has non-default
                # scale/zero factors.
                # TODO: Find a better way to do this than using this interface
                scale, zero = self._get_scale_factors(column)[3:5]
                if scale or zero:
                    self._cache_field(name, pardata[idx])
                else:
                    np.rec.recarray.field(self, idx)[:] = pardata[idx]

            column = coldefs[self._data_field]
            scale, zero = self._get_scale_factors(column)[3:5]
            if scale or zero:
                self._cache_field(self._data_field, input)
            else:
                np.rec.recarray.field(self, npars)[:] = input
        else:
            self = FITS_rec.__new__(cls, input)
            self.parnames = None
        return self

    def __array_finalize__(self, obj):
        super().__array_finalize__(obj)
        if isinstance(obj, GroupData):
            self.parnames = obj.parnames
        elif isinstance(obj, FITS_rec):
            self.parnames = obj._coldefs.names

    def __getitem__(self, key):
        out = super().__getitem__(key)
        if isinstance(out, GroupData):
            out.parnames = self.parnames
        return out

    @property
    def data(self):
        """
        The raw group data represented as a multi-dimensional `numpy.ndarray`
        array.
        """
        # The last column in the coldefs is the data portion of the group
        return self.field(self._coldefs.names[-1])

    @lazyproperty
    def _unique(self):
        return _par_indices(self.parnames)

    def par(self, parname):
        """
        Get the group parameter values.
        """
        if _is_int(parname):
            result = self.field(parname)
        else:
            indx = self._unique[parname.upper()]
            if len(indx) == 1:
                result = self.field(indx[0])

            # if more than one group parameter have the same name
            else:
                result = self.field(indx[0]).astype("f8")
                for i in indx[1:]:
                    result += self.field(i)

        return result


class GroupsHDU(PrimaryHDU, _TableLikeHDU):
    """
    FITS Random Groups HDU class.

    See the :ref:`astropy:random-groups` section in the Astropy documentation
    for more details on working with this type of HDU.
    """

    _bitpix2tform = {8: "B", 16: "I", 32: "J", 64: "K", -32: "E", -64: "D"}
    _data_type = GroupData
    _data_field = "DATA"
    """
    The name of the table record array field that will contain the group data
    for each group; 'DATA' by default, but may be preceded by any number of
    underscores if 'DATA' is already a parameter name
    """

    def __init__(self, data=None, header=None):
        super().__init__(data=data, header=header)
        if data is not DELAYED:
            self.update_header()

        # Update the axes; GROUPS HDUs should always have at least one axis
        if len(self._axes) <= 0:
            self._axes = [0]
            self._header["NAXIS"] = 1
            self._header.set("NAXIS1", 0, after="NAXIS")

    @classmethod
    def match_header(cls, header):
        keyword = header.cards[0].keyword
        return keyword == "SIMPLE" and "GROUPS" in header and header["GROUPS"] is True

    @lazyproperty
    def data(self):
        """
        The data of a random group FITS file will be like a binary table's
        data.
        """
        if self._axes == [0]:
            return

        data = self._get_tbdata()
        data._coldefs = self.columns
        data.parnames = self.parnames
        del self.columns
        return data

    @lazyproperty
    def parnames(self):
        """The names of the group parameters as described by the header."""
        pcount = self._header["PCOUNT"]
        # The FITS standard doesn't really say what to do if a parname is
        # missing, so for now just assume that won't happen
        return [self._header["PTYPE" + str(idx + 1)] for idx in range(pcount)]

    @lazyproperty
    def columns(self):
        if self._has_data and hasattr(self.data, "_coldefs"):
            return self.data._coldefs

        format = self._bitpix2tform[self._header["BITPIX"]]
        pcount = self._header["PCOUNT"]
        parnames = []
        bscales = []
        bzeros = []

        for idx in range(pcount):
            bscales.append(self._header.get("PSCAL" + str(idx + 1), None))
            bzeros.append(self._header.get("PZERO" + str(idx + 1), None))
            parnames.append(self._header["PTYPE" + str(idx + 1)])

        formats = [format] * len(parnames)
        dim = [None] * len(parnames)

        # Now create columns from collected parameters, but first add the DATA
        # column too, to contain the group data.
        parnames.append("DATA")
        bscales.append(self._header.get("BSCALE"))
        bzeros.append(self._header.get("BZEROS"))
        data_shape = self.shape[:-1]
        formats.append(str(int(np.prod(data_shape))) + format)
        dim.append(data_shape)
        parnames = _unique_parnames(parnames)

        self._data_field = parnames[-1]

        cols = [
            Column(name=name, format=fmt, bscale=bscale, bzero=bzero, dim=dim)
            for name, fmt, bscale, bzero, dim in zip(
                parnames, formats, bscales, bzeros, dim
            )
        ]

        coldefs = ColDefs(cols)
        return coldefs

    @property
    def _nrows(self):
        if not self._data_loaded:
            # The number of 'groups' equates to the number of rows in the table
            # representation of the data
            return self._header.get("GCOUNT", 0)
        else:
            return len(self.data)

    @lazyproperty
    def _theap(self):
        # Only really a lazyproperty for symmetry with _TableBaseHDU
        return 0

    @property
    def is_image(self):
        return False

    @property
    def size(self):
        """
        Returns the size (in bytes) of the HDU's data part.
        """
        size = 0
        naxis = self._header.get("NAXIS", 0)

        # for random group image, NAXIS1 should be 0, so we skip NAXIS1.
        if naxis > 1:
            size = 1
            for idx in range(1, naxis):
                size = size * self._header["NAXIS" + str(idx + 1)]
            bitpix = self._header["BITPIX"]
            gcount = self._header.get("GCOUNT", 1)
            pcount = self._header.get("PCOUNT", 0)
            size = abs(bitpix) * gcount * (pcount + size) // 8
        return size

    def update_header(self):
        old_naxis = self._header.get("NAXIS", 0)

        if self._data_loaded:
            if isinstance(self.data, GroupData):
                self._axes = list(self.data.data.shape)[1:]
                self._axes.reverse()
                self._axes = [0] + self._axes
                field0 = self.data.dtype.names[0]
                field0_code = self.data.dtype.fields[field0][0].name
            elif self.data is None:
                self._axes = [0]
                field0_code = "uint8"  # For lack of a better default
            else:
                raise ValueError("incorrect array type")

            self._header["BITPIX"] = DTYPE2BITPIX[field0_code]

        self._header["NAXIS"] = len(self._axes)

        # add NAXISi if it does not exist
        for idx, axis in enumerate(self._axes):
            if idx == 0:
                after = "NAXIS"
            else:
                after = "NAXIS" + str(idx)

            self._header.set("NAXIS" + str(idx + 1), axis, after=after)

        # delete extra NAXISi's
        for idx in range(len(self._axes) + 1, old_naxis + 1):
            try:
                del self._header["NAXIS" + str(idx)]
            except KeyError:
                pass

        if self._has_data and isinstance(self.data, GroupData):
            self._header.set("GROUPS", True, after="NAXIS" + str(len(self._axes)))
            self._header.set("PCOUNT", len(self.data.parnames), after="GROUPS")
            self._header.set("GCOUNT", len(self.data), after="PCOUNT")

            column = self.data._coldefs[self._data_field]
            scale, zero = self.data._get_scale_factors(column)[3:5]
            if scale:
                self._header.set("BSCALE", column.bscale)
            if zero:
                self._header.set("BZERO", column.bzero)

            for idx, name in enumerate(self.data.parnames):
                self._header.set("PTYPE" + str(idx + 1), name)
                column = self.data._coldefs[idx]
                scale, zero = self.data._get_scale_factors(column)[3:5]
                if scale:
                    self._header.set("PSCAL" + str(idx + 1), column.bscale)
                if zero:
                    self._header.set("PZERO" + str(idx + 1), column.bzero)

        # Update the position of the EXTEND keyword if it already exists
        if "EXTEND" in self._header:
            if len(self._axes):
                after = "NAXIS" + str(len(self._axes))
            else:
                after = "NAXIS"
            self._header.set("EXTEND", after=after)

    def _writedata_internal(self, fileobj):
        """
        Basically copy/pasted from `_ImageBaseHDU._writedata_internal()`, but
        we have to get the data's byte order a different way...

        TODO: Might be nice to store some indication of the data's byte order
        as an attribute or function so that we don't have to do this.
        """
        size = 0

        if self.data is not None:
            self.data._scale_back()

            # Based on the system type, determine the byteorders that
            # would need to be swapped to get to big-endian output
            if sys.byteorder == "little":
                swap_types = ("<", "=")
            else:
                swap_types = ("<",)
            # deal with unsigned integer 16, 32 and 64 data
            if _is_pseudo_integer(self.data.dtype):
                # Convert the unsigned array to signed
                output = np.array(
                    self.data - _pseudo_zero(self.data.dtype),
                    dtype=f">i{self.data.dtype.itemsize}",
                )
                should_swap = False
            else:
                output = self.data
                fname = self.data.dtype.names[0]
                byteorder = self.data.dtype.fields[fname][0].str[0]
                should_swap = byteorder in swap_types

            if should_swap:
                if output.flags.writeable:
                    output.byteswap(True)
                    try:
                        fileobj.writearray(output)
                    finally:
                        output.byteswap(True)
                else:
                    # For read-only arrays, there is no way around making
                    # a byteswapped copy of the data.
                    fileobj.writearray(output.byteswap(False))
            else:
                fileobj.writearray(output)

            size += output.size * output.itemsize
        return size

    def _verify(self, option="warn"):
        errs = super()._verify(option=option)

        # Verify locations and values of mandatory keywords.
        self.req_cards(
            "NAXIS", 2, lambda v: (_is_int(v) and 1 <= v <= 999), 1, option, errs
        )
        self.req_cards("NAXIS1", 3, lambda v: (_is_int(v) and v == 0), 0, option, errs)

        after = self._header["NAXIS"] + 3
        pos = lambda x: x >= after

        self.req_cards("GCOUNT", pos, _is_int, 1, option, errs)
        self.req_cards("PCOUNT", pos, _is_int, 0, option, errs)
        self.req_cards("GROUPS", pos, lambda v: (v is True), True, option, errs)
        return errs

    def _calculate_datasum(self):
        """
        Calculate the value for the ``DATASUM`` card in the HDU.
        """
        if self._has_data:
            # We have the data to be used.

            # Check the byte order of the data.  If it is little endian we
            # must swap it before calculating the datasum.
            # TODO: Maybe check this on a per-field basis instead of assuming
            # that all fields have the same byte order?
            byteorder = self.data.dtype.fields[self.data.dtype.names[0]][0].str[0]

            if byteorder != ">":
                if self.data.flags.writeable:
                    byteswapped = True
                    d = self.data.byteswap(True)
                    d.dtype = d.dtype.newbyteorder(">")
                else:
                    # If the data is not writeable, we just make a byteswapped
                    # copy and don't bother changing it back after
                    d = self.data.byteswap(False)
                    d.dtype = d.dtype.newbyteorder(">")
                    byteswapped = False
            else:
                byteswapped = False
                d = self.data

            byte_data = d.view(type=np.ndarray, dtype=np.ubyte)

            cs = self._compute_checksum(byte_data)

            # If the data was byteswapped in this method then return it to
            # its original little-endian order.
            if byteswapped:
                d.byteswap(True)
                d.dtype = d.dtype.newbyteorder("<")

            return cs
        else:
            # This is the case where the data has not been read from the file
            # yet.  We can handle that in a generic manner so we do it in the
            # base class.  The other possibility is that there is no data at
            # all.  This can also be handled in a generic manner.
            return super()._calculate_datasum()

    def _summary(self):
        summary = super()._summary()
        name, ver, classname, length, shape, format, gcount = summary

        # Drop the first axis from the shape
        if shape:
            shape = shape[1:]

            if shape and all(shape):
                # Update the format
                format = self.columns[0].dtype.name

        # Update the GCOUNT report
        gcount = f"{self._gcount} Groups  {self._pcount} Parameters"
        return (name, ver, classname, length, shape, format, gcount)


def _par_indices(names):
    """
    Given a list of objects, returns a mapping of objects in that list to the
    index or indices at which that object was found in the list.
    """
    unique = {}
    for idx, name in enumerate(names):
        # Case insensitive
        name = name.upper()
        if name in unique:
            unique[name].append(idx)
        else:
            unique[name] = [idx]
    return unique


def _unique_parnames(names):
    """
    Given a list of parnames, including possible duplicates, returns a new list
    of parnames with duplicates prepended by one or more underscores to make
    them unique.  This is also case insensitive.
    """
    upper_names = set()
    unique_names = []

    for name in names:
        name_upper = name.upper()
        while name_upper in upper_names:
            name = "_" + name
            name_upper = "_" + name_upper

        unique_names.append(name)
        upper_names.add(name_upper)

    return unique_names
