# Licensed under a 3-clause BSD style license - see LICENSE.rst

import abc
import contextlib
import re
import warnings
from collections import OrderedDict
from operator import itemgetter

import numpy as np

__all__ = ['IORegistryError']


class IORegistryError(Exception):
    """Error raised by the Unified I/O registry.

    Used when a registration clashes with an existing one, when a requested
    registration cannot be found, or when a data format cannot be identified
    (or is ambiguous).
    """


# -----------------------------------------------------------------------------

class _UnifiedIORegistryBase(metaclass=abc.ABCMeta):
    """Base class for registries in Astropy's Unified IO.

    This base class provides identification functions and miscellaneous
    utilities. For an example how to build a registry subclass we suggest
    :class:`~astropy.io.registry.UnifiedInputRegistry`, which enables
    read-only registries. These higher-level subclasses will probably serve
    better as a baseclass, for instance
    :class:`~astropy.io.registry.UnifiedIORegistry` subclasses both
    :class:`~astropy.io.registry.UnifiedInputRegistry` and
    :class:`~astropy.io.registry.UnifiedOutputRegistry` to enable both
    reading from and writing to files.

    .. versionadded:: 5.0

    """

    def __init__(self):
        # registry of identifier functions, keyed by ``(data_format, data_class)``
        self._identifiers = OrderedDict()

        # What this registry can do: e.g. 'read' and/or 'write'. Each entry maps
        # a registry name to the attribute that stores it ("attr") and to the
        # column title used for it in ``get_formats`` ("column").
        self._registries = dict()
        self._registries["identify"] = dict(attr="_identifiers", column="Auto-identify")
        self._registries_order = ("identify", )  # match keys in `_registries`

        # If multiple formats are added to one class the update of the docs is quite
        # expensive. Classes for which the doc update is temporarily delayed are
        # added to this set.
        self._delayed_docs_classes = set()

    @property
    def available_registries(self):
        """Available registries.

        Returns
        -------
        ``dict_keys``
        """
        return self._registries.keys()

    def get_formats(self, data_class=None, filter_on=None):
        """
        Get the list of registered formats as a `~astropy.table.Table`.

        Parameters
        ----------
        data_class : class or None, optional
            Filter readers/writers to match data class (default = all classes).
        filter_on : str or None, optional
            Which registry to show, e.g. "identify" (case-insensitive).
            If None, formats from all registries are shown.  Default is None.

        Returns
        -------
        format_table : :class:`~astropy.table.Table`
            Table of available I/O formats.

        Raises
        ------
        ValueError
            If ``filter_on`` is not None nor a registry name.
        """
        from astropy.table import Table

        # set up the column names
        colnames = (
            "Data class", "Format",
            *[self._registries[k]["column"] for k in self._registries_order],
            "Deprecated")
        i_dataclass = colnames.index("Data class")
        i_format = colnames.index("Format")
        i_regstart = colnames.index(self._registries[self._registries_order[0]]["column"])
        i_deprecated = colnames.index("Deprecated")

        # the (format, class) pairs from all registries except "identify"
        regs = set()
        for k in self._registries.keys() - {"identify"}:
            regs |= set(getattr(self, self._registries[k]["attr"]))
        format_classes = sorted(regs, key=itemgetter(0))

        rows = []
        for (fmt, cls) in format_classes:
            # see if can skip, else need to document in row
            if (data_class is not None and not self._is_best_match(
                    data_class, cls, format_classes)):
                continue

            # flags for each registry
            has_ = {k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No"
                    for k, v in self._registries.items()}

            # Check if this is a short name (e.g. 'rdb') which is deprecated in
            # favor of the full 'ascii.rdb'. ``regs`` holds the same pairs as
            # ``format_classes`` but as a set, so the lookup is O(1).
            deprecated = "Yes" if ('ascii.' + fmt, cls) in regs else ""

            # add to rows
            rows.append((cls.__name__, fmt,
                         *[has_[n] for n in self._registries_order], deprecated))

        # filter_on can be in self._registries_order or None
        filter_key = str(filter_on).lower()
        if filter_key in self._registries_order:
            index = self._registries_order.index(filter_key)
            rows = [row for row in rows if row[i_regstart + index] == 'Yes']
        elif filter_on is not None:
            raise ValueError(f'unrecognized value for "filter_on": {filter_on}.\n'
                             f'Allowed are {self._registries_order} and None.')

        # Sorting the list of tuples is much faster than sorting it after the
        # table is created. (#5262)
        if rows:
            # Indices represent "Data Class", "Deprecated" and "Format".
            data = list(zip(*sorted(
                rows, key=itemgetter(i_dataclass, i_deprecated, i_format))))
        else:
            data = None

        # make table
        # need to filter elementwise comparison failure issue
        # https://github.com/numpy/numpy/issues/6784
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=FutureWarning)

            format_table = Table(data, names=colnames)
            if not np.any(format_table['Deprecated'].data == 'Yes'):
                format_table.remove_column('Deprecated')

        return format_table

    @contextlib.contextmanager
    def delay_doc_updates(self, cls):
        """Contextmanager to disable documentation updates when registering
        reader and writer. The documentation is only built once when the
        contextmanager exits.

        .. versionadded:: 1.3

        Parameters
        ----------
        cls : class
            Class for which the documentation updates should be delayed.

        Notes
        -----
        Registering multiple readers and writers can cause significant overhead
        because the documentation of the corresponding ``read`` and ``write``
        methods are build every time.

        The class is removed from the delayed set and its documentation is
        rebuilt even if the body of the ``with`` statement raises.

        Examples
        --------
        see for example the source code of ``astropy.table.__init__``.
        """
        self._delayed_docs_classes.add(cls)
        try:
            yield
        finally:
            # Without ``finally`` an exception in the ``with`` body would leave
            # ``cls`` in ``_delayed_docs_classes``, keeping its doc updates
            # suppressed for the rest of the session.
            self._delayed_docs_classes.discard(cls)
            for method in self._registries.keys() - {"identify"}:
                self._update__doc__(cls, method)

    # =========================================================================
    # Identifier methods

    def register_identifier(self, data_format, data_class, identifier, force=False):
        """
        Associate an identifier function with a specific data type.

        Parameters
        ----------
        data_format : str
            The data format identifier. This is the string that is used to
            specify the data type when reading/writing.
        data_class : class
            The class of the object that can be read or written.
        identifier : function
            A function that checks the argument specified to `read` or `write` to
            determine whether the input can be interpreted as a table of type
            ``data_format``. This function should take the following arguments:

               - ``origin``: A string ``"read"`` or ``"write"`` identifying whether
                 the file is to be opened for reading or writing.
               - ``path``: The path to the file.
               - ``fileobj``: An open file object to read the file's contents, or
                 `None` if the file could not be opened.
               - ``*args``: Positional arguments for the `read` or `write`
                 function.
               - ``**kwargs``: Keyword arguments for the `read` or `write`
                 function.

            One or both of ``path`` or ``fileobj`` may be `None`.  If they are
            both `None`, the identifier will need to work from ``args[0]``.

            The function should return True if the input can be identified
            as being of format ``data_format``, and False otherwise.
        force : bool, optional
            Whether to override any existing function if already present.
            Default is ``False``.

        Raises
        ------
        IORegistryError
            If an identifier is already registered for this format and class
            and ``force`` is False.

        Examples
        --------
        To set the identifier based on extensions, for formats that take a
        filename as a first argument, you can do for example

        .. code-block:: python

            from astropy.io.registry import register_identifier, unregister_identifier
            from astropy.table import Table
            def my_identifier(*args, **kwargs):
                return isinstance(args[0], str) and args[0].endswith('.tbl')
            register_identifier('ipac', Table, my_identifier)
            unregister_identifier('ipac', Table)
        """
        if (data_format, data_class) not in self._identifiers or force:
            self._identifiers[(data_format, data_class)] = identifier
        else:
            raise IORegistryError("Identifier for format '{}' and class '{}' is "
                                  'already defined'.format(data_format,
                                                           data_class.__name__))

    def unregister_identifier(self, data_format, data_class):
        """
        Unregister an identifier function.

        Parameters
        ----------
        data_format : str
            The data format identifier.
        data_class : class
            The class of the object that can be read/written.

        Raises
        ------
        IORegistryError
            If no identifier is registered for this format and class.
        """
        if (data_format, data_class) in self._identifiers:
            self._identifiers.pop((data_format, data_class))
        else:
            raise IORegistryError("No identifier defined for format '{}' and class"
                                  " '{}'".format(data_format, data_class.__name__))

    def identify_format(self, origin, data_class_required, path, fileobj, args, kwargs):
        """Loop through identifiers to see which formats match.

        Parameters
        ----------
        origin : str
            A string ``"read"`` or ``"write"`` identifying whether the file is to be
            opened for reading or writing.
        data_class_required : object
            The specified class for the result of `read` or the class that is to be
            written.
        path : str or path-like or None
            The path to the file or None.
        fileobj : file-like or None.
            An open file object to read the file's contents, or ``None`` if the
            file could not be opened.
        args : sequence
            Positional arguments for the `read` or `write` function. Note that
            these must be provided as sequence.
        kwargs : dict-like
            Keyword arguments for the `read` or `write` function. Note that this
            parameter must be `dict`-like.

        Returns
        -------
        valid_formats : list
            List of matching formats, in identifier registration order.
        """
        valid_formats = []
        for data_format, data_class in self._identifiers:
            # Only consult identifiers registered for the closest registered
            # ancestor of the requested class.
            if self._is_best_match(data_class_required, data_class, self._identifiers):
                if self._identifiers[(data_format, data_class)](
                        origin, path, fileobj, *args, **kwargs):
                    valid_formats.append(data_format)

        return valid_formats

    # =========================================================================
    # Utils

    def _get_format_table_str(self, data_class, filter_on):
        """``get_formats()``, without column "Data class", as a str."""
        format_table = self.get_formats(data_class, filter_on)
        format_table.remove_column('Data class')
        format_table_str = '\n'.join(format_table.pformat(max_lines=-1))
        return format_table_str

    def _is_best_match(self, class1, class2, format_classes):
        """
        Determine if class2 is the "best" match for class1 among the classes
        in ``format_classes`` (an iterable of ``(format, class)`` pairs).
        It is assumed that class2 appears in ``format_classes``.
        class2 is the best match if:

        - ``class1`` is a subclass of ``class2`` AND
        - ``class2`` is the nearest ancestor of ``class1`` that is in classes
          (which includes the case that ``class1 is class2``)
        """
        if issubclass(class1, class2):
            classes = {cls for fmt, cls in format_classes}
            for parent in class1.__mro__:
                if parent is class2:  # class2 is closest registered ancestor
                    return True
                if parent in classes:  # class2 was superseded
                    return False
        return False

    def _get_valid_format(self, mode, cls, path, fileobj, args, kwargs):
        """
        Returns the first valid format that can be used to read/write the data in
        question.  Mode can be either 'read' or 'write'.

        Raises
        ------
        IORegistryError
            If no format could be identified, or if several formats tie for
            the highest priority.
        """
        valid_formats = self.identify_format(mode, cls, path, fileobj, args, kwargs)

        if not valid_formats:
            format_table_str = self._get_format_table_str(cls, mode.capitalize())
            raise IORegistryError("Format could not be identified based on the"
                                  " file name or contents, please provide a"
                                  " 'format' argument.\n"
                                  "The available formats are:\n"
                                  "{}".format(format_table_str))
        elif len(valid_formats) > 1:
            return self._get_highest_priority_format(mode, cls, valid_formats)

        return valid_formats[0]

    def _get_highest_priority_format(self, mode, cls, valid_formats):
        """
        Returns the reader or writer with the highest priority. If it is a tie,
        error.

        Parameters
        ----------
        mode : {'read', 'write'}
            Selects whether reader or writer priorities are compared.
        cls : class
            The data class the formats are registered for.
        valid_formats : list of str
            Candidate formats (non-empty).

        Raises
        ------
        ValueError
            If ``mode`` is neither 'read' nor 'write'.
        IORegistryError
            If more than one format shares the highest priority.
        """
        # ``_readers`` / ``_writers`` are not created by this base class;
        # presumably the input/output registry subclasses provide them.
        if mode == "read":
            format_dict = self._readers
        elif mode == "write":
            format_dict = self._writers
        else:
            raise ValueError(f"mode must be 'read' or 'write', not {mode!r}")

        best_formats = []
        current_priority = - np.inf
        for fmt in valid_formats:
            try:
                _, priority = format_dict[(fmt, cls)]
            except KeyError:
                # We could throw an exception here, but get_reader/get_writer handle
                # this case better, instead maximally deprioritise the format.
                priority = - np.inf

            if priority == current_priority:
                best_formats.append(fmt)
            elif priority > current_priority:
                best_formats = [fmt]
                current_priority = priority

        if len(best_formats) > 1:
            # Sort by the full format name (not just its first character) so
            # the message is deterministic and alphabetical.
            raise IORegistryError("Format is ambiguous - options are: {}".format(
                ', '.join(sorted(valid_formats))
            ))
        return best_formats[0]

    def _update__doc__(self, data_class, readwrite):
        """
        Update the docstring to include all the available readers / writers for
        the ``data_class.read``/``data_class.write`` functions (respectively).
        Don't update if the data_class does not have the relevant method.

        Parameters
        ----------
        data_class : class
            Class whose ``read`` or ``write`` docstring is updated.
        readwrite : str
            Name of the method to update, e.g. ``"read"`` or ``"write"``.
        """
        # abort if method "readwrite" isn't on data_class
        if not hasattr(data_class, readwrite):
            return

        from .interface import UnifiedReadWrite

        FORMATS_TEXT = 'The available built-in formats are:'

        # Get the existing read or write method and its docstring
        class_readwrite_func = getattr(data_class, readwrite)

        if not isinstance(class_readwrite_func.__doc__, str):
            # No docstring--could just be test code, or possibly code compiled
            # without docstrings
            return

        lines = class_readwrite_func.__doc__.splitlines()

        # Find the location of the existing formats table if it exists
        sep_indices = [ii for ii, line in enumerate(lines) if FORMATS_TEXT in line]
        if sep_indices:
            # Chop off the existing formats table: everything from the first
            # FORMATS_TEXT line onward.
            chop_index = sep_indices[0]
            lines = lines[:chop_index]

        # Find the minimum indent, skipping the first line because it might be odd.
        # ``default=0`` covers docstrings with no non-blank line after the first
        # (e.g. one-line docstrings), where ``min`` would otherwise raise.
        matches = [re.search(r'(\S)', line) for line in lines[1:]]
        left_indent = ' ' * min((match.start() for match in matches if match),
                                default=0)

        # Get the available unified I/O formats for this class
        # Include only formats supported for ``readwrite``, and drop the
        # 'Data class' column
        format_table = self.get_formats(data_class, readwrite.capitalize())
        format_table.remove_column('Data class')

        # Get the available formats as a table, then munge the output of pformat()
        # a bit and put it into the docstring.
        new_lines = format_table.pformat(max_lines=-1, max_width=80)
        table_rst_sep = re.sub('-', '=', new_lines[1])
        new_lines[1] = table_rst_sep
        new_lines.insert(0, table_rst_sep)
        new_lines.append(table_rst_sep)

        # Check for deprecated names and include a warning at the end.
        if 'Deprecated' in format_table.colnames:
            new_lines.extend(['',
                              'Deprecated format names like ``aastex`` will be '
                              'removed in a future version. Use the full ',
                              'name (e.g. ``ascii.aastex``) instead.'])

        new_lines = [FORMATS_TEXT, ''] + new_lines
        lines.extend([left_indent + line for line in new_lines])

        # A ``UnifiedReadWrite`` instance takes its docstring from its class.
        # Otherwise set ``__doc__`` on the function directly, falling back to
        # ``__func__`` for method objects whose ``__doc__`` cannot be set.
        if isinstance(class_readwrite_func, UnifiedReadWrite):
            class_readwrite_func.__class__.__doc__ = '\n'.join(lines)
        else:
            try:
                class_readwrite_func.__doc__ = '\n'.join(lines)
            except AttributeError:
                class_readwrite_func.__func__.__doc__ = '\n'.join(lines)
