# Licensed under a 3-clause BSD style license - see LICENSE.rst


import abc
import warnings
from collections import OrderedDict

import numpy as np
from matplotlib import rcParams
from matplotlib.lines import Line2D, Path
from matplotlib.patches import PathPatch

from astropy.utils.exceptions import AstropyDeprecationWarning

# Public names exported by ``from <module> import *``. Note that
# ``SpineXAligned`` (defined below) is intentionally or otherwise not listed.
__all__ = [
    "RectangularFrame1D",
    "Spine",
    "BaseFrame",
    "RectangularFrame",
    "EllipticalFrame",
]


class Spine:
    """
    A single side of an axes.

    This does not need to be a straight line, but represents a 'side' when
    determining which part of the frame to put labels and ticks on.

    Parameters
    ----------
    parent_axes : `~astropy.visualization.wcsaxes.WCSAxes`
        The parent axes
    transform : `~matplotlib.transforms.Transform`
        The transform from data to world. Its inverse (``transform.inverted()``)
        is used to convert world coordinates back to data coordinates when the
        ``world`` attribute is assigned.
    data_func : callable
        If not ``None``, it should be a function that returns the appropriate spine
        data when called with this object as the sole argument.  If ``None``, the
        spine data must be manually updated in ``update_spines()``.
    """

    def __init__(self, parent_axes, transform, *, data_func=None):
        self.parent_axes = parent_axes
        self.transform = transform
        self.data_func = data_func

        # ``_data`` and ``_world`` are kept consistent by the property setters.
        self._data = None
        self._world = None

    @property
    def data(self):
        """
        Spine vertices in data coordinates.

        If no data has been set and ``data_func`` is provided, the data is
        populated lazily by calling ``data_func(self)``.
        """
        if self._data is None and self.data_func:
            self.data = self.data_func(self)
        return self._data

    @data.setter
    def data(self, value):
        self._data = value
        if value is None:
            self._world = None
        else:
            # Vertices outside the valid region of the transform may yield
            # NaNs; silence the resulting "invalid value" warnings.
            with np.errstate(invalid="ignore"):
                self._world = self.transform.transform(self._data)
            self._update_normal()

    def _get_pixel(self):
        """Return the spine vertices converted through ``parent_axes.transData``."""
        return self.parent_axes.transData.transform(self._data)

    @property
    def pixel(self):
        warnings.warn(
            "Pixel coordinates cannot be accurately calculated unless "
            "Matplotlib is currently drawing a figure, so the .pixel "
            "attribute is deprecated and will be removed in a future "
            "astropy release.",
            AstropyDeprecationWarning,
        )
        return self._get_pixel()

    @pixel.setter
    def pixel(self, value):
        warnings.warn(
            "Manually setting pixel values of a Spine can lead to incorrect results "
            "as these can only be accurately calculated when Matplotlib is drawing "
            "a figure. As such the .pixel setter now does nothing, is deprecated, "
            "and will be removed in a future astropy release.",
            AstropyDeprecationWarning,
        )

    @property
    def world(self):
        """Spine vertices in world coordinates (``transform`` applied to ``data``)."""
        return self._world

    @world.setter
    def world(self, value):
        self._world = value
        if value is None:
            self._data = None
        else:
            # ``transform`` maps data -> world, so going from world back to
            # data requires its inverse. Pixel positions are intentionally not
            # cached here: they are only reliable while a figure is drawing and
            # are recomputed on demand by ``_get_pixel``.
            with np.errstate(invalid="ignore"):
                self._data = self.transform.inverted().transform(value)
            self._update_normal()

    def _update_normal(self):
        """Store in ``normal_angle`` the per-segment normal angle (degrees)."""
        pixel = self._get_pixel()
        # Find angle normal to border and inwards, in display coordinate
        dx = pixel[1:, 0] - pixel[:-1, 0]
        dy = pixel[1:, 1] - pixel[:-1, 1]
        self.normal_angle = np.degrees(np.arctan2(dx, -dy))

    def _halfway_x_y_angle(self):
        """
        Return the x, y, normal_angle values halfway along the spine.

        ``x`` and ``y`` are in the coordinates returned by ``_get_pixel``; the
        angle is the stored normal of the segment containing the midpoint,
        rotated by 180 degrees so that it faces outwards.
        """
        pixel = self._get_pixel()
        x_disp, y_disp = pixel[:, 0], pixel[:, 1]
        # Get distance along the path
        d = np.hstack(
            [0.0, np.cumsum(np.sqrt(np.diff(x_disp) ** 2 + np.diff(y_disp) ** 2))]
        )
        xcen = np.interp(d[-1] / 2.0, d, x_disp)
        ycen = np.interp(d[-1] / 2.0, d, y_disp)

        # Find segment along which the mid-point lies
        imin = np.searchsorted(d, d[-1] / 2.0) - 1

        # Find normal of the axis label facing outwards on that segment
        normal_angle = self.normal_angle[imin] + 180.0
        return xcen, ycen, normal_angle


class SpineXAligned(Spine):
    """
    A single side of an axes, aligned with the X data axis.

    This does not need to be a straight line, but represents a 'side' when
    determining which part of the frame to put labels and ticks on.

    Unlike `Spine`, only the first (x) column of the data is passed through
    ``transform`` when computing world coordinates.
    """

    @property
    def data(self):
        # Populate lazily from ``data_func`` exactly like ``Spine.data`` does;
        # overriding the property must not drop that part of the contract.
        if self._data is None and self.data_func:
            self.data = self.data_func(self)
        return self._data

    @data.setter
    def data(self, value):
        self._data = value
        if value is None:
            self._world = None
        else:
            # Only the x column is transformed; ``0:1`` keeps it 2-D.
            with np.errstate(invalid="ignore"):
                self._world = self.transform.transform(self._data[:, 0:1])
            self._update_normal()


class BaseFrame(OrderedDict, metaclass=abc.ABCMeta):
    """
    Base class for frames, which are collections of
    :class:`~astropy.visualization.wcsaxes.frame.Spine` instances.

    Subclasses must provide a ``spine_names`` attribute (the keys under which
    spines are stored) and typically override ``update_spines`` to refresh
    each spine's data from the current axes limits.
    """

    spine_class = Spine

    def __init__(self, parent_axes, transform, path=None):
        super().__init__()

        self.parent_axes = parent_axes
        self._transform = transform
        self._path = path

        # Default styling comes from the active Matplotlib configuration.
        self._linewidth = rcParams["axes.linewidth"]
        self._color = rcParams["axes.edgecolor"]

        for name in self.spine_names:
            self[name] = self.spine_class(parent_axes, transform)

    @property
    def origin(self):
        """``"lower"`` if the y limits increase, otherwise ``"upper"``."""
        bottom, top = self.parent_axes.get_ylim()
        if bottom < top:
            return "lower"
        return "upper"

    @property
    def transform(self):
        return self._transform

    @transform.setter
    def transform(self, value):
        # Keep every spine pointing at the same transform as the frame.
        self._transform = value
        for spine in self.values():
            spine.transform = value

    def _update_patch_path(self):
        """Rebuild ``self._path`` from the concatenated spine vertices."""
        self.update_spines()

        spine_data = [self[name].data for name in self.spine_names]
        xs = np.hstack([block[:, 0] for block in spine_data])
        ys = np.hstack([block[:, 1] for block in spine_data])
        vertices = np.column_stack([xs, ys])

        if self._path is not None:
            self._path.vertices = vertices
        else:
            self._path = Path(vertices)

    @property
    def patch(self):
        """A fresh `~matplotlib.patches.PathPatch` outlining the frame."""
        self._update_patch_path()
        style = {
            "transform": self.parent_axes.transData,
            "facecolor": rcParams["axes.facecolor"],
            "edgecolor": "white",
        }
        return PathPatch(self._path, **style)

    def draw(self, renderer):
        """Draw every spine as a line using the frame colour and linewidth."""
        for spine in self.values():
            xy = spine._get_pixel()
            outline = Line2D(
                xy[:, 0],
                xy[:, 1],
                linewidth=self._linewidth,
                color=self._color,
                zorder=1000,
            )
            outline.draw(renderer)

    def sample(self, n_samples):
        """
        Return new spines resampled to ``n_samples`` evenly spaced points.

        Each spine's data is linearly interpolated against a uniform
        parameter in [0, 1]; empty spines are copied through unchanged.
        """
        self.update_spines()

        resampled = OrderedDict()
        for name in self:
            source = self[name].data
            new_spine = self.spine_class(self.parent_axes, self.transform)
            resampled[name] = new_spine

            if source.size == 0:
                new_spine.data = source
                continue

            old_grid = np.linspace(0.0, 1.0, source.shape[0])
            new_grid = np.linspace(0.0, 1.0, n_samples)
            new_spine.data = np.column_stack(
                [np.interp(new_grid, old_grid, column) for column in source.T]
            )

        return resampled

    def set_color(self, color):
        """
        Sets the color of the frame.

        Parameters
        ----------
        color : str
            The color of the frame.
        """
        self._color = color

    def get_color(self):
        """Return the color used when drawing the frame."""
        return self._color

    def set_linewidth(self, linewidth):
        """
        Sets the linewidth of the frame.

        Parameters
        ----------
        linewidth : float
            The linewidth of the frame in points.
        """
        self._linewidth = linewidth

    def get_linewidth(self):
        """Return the linewidth (in points) used when drawing the frame."""
        return self._linewidth

    def update_spines(self):
        """Refresh every spine that has a ``data_func`` by calling it."""
        for spine in self.values():
            if not spine.data_func:
                continue
            spine.data = spine.data_func(spine)


class RectangularFrame1D(BaseFrame):
    """
    A classic rectangular frame.

    Only bottom (``"b"``) and top (``"t"``) spines are tracked, but the patch
    and the drawn outline cover the full rectangle of the axes limits.
    """

    spine_names = "bt"
    _spine_auto_position_order = "bt"
    spine_class = SpineXAligned

    def _closed_rectangle(self):
        """Return x and y lists tracing the axes limits as a closed loop."""
        x0, x1 = self.parent_axes.get_xlim()
        y0, y1 = self.parent_axes.get_ylim()
        return [x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0]

    def update_spines(self):
        x0, x1 = self.parent_axes.get_xlim()
        y0, y1 = self.parent_axes.get_ylim()

        bottom = np.array(([x0, y0], [x1, y0]))
        top = np.array(([x1, y1], [x0, y1]))
        self["b"].data = bottom
        self["t"].data = top

        super().update_spines()

    def _update_patch_path(self):
        self.update_spines()

        xs, ys = self._closed_rectangle()
        vertices = np.column_stack([xs, ys])

        if self._path is not None:
            self._path.vertices = vertices
        else:
            self._path = Path(vertices)

    def draw(self, renderer):
        # The outline is given in data coordinates, so draw through transData.
        xs, ys = self._closed_rectangle()
        outline = Line2D(
            xs,
            ys,
            linewidth=self._linewidth,
            color=self._color,
            zorder=1000,
            transform=self.parent_axes.transData,
        )
        outline.draw(renderer)


class RectangularFrame(BaseFrame):
    """
    A classic rectangular frame.

    Spines are bottom (``"b"``), right (``"r"``), top (``"t"``) and left
    (``"l"``), each a two-point segment between corners of the axes limits.
    """

    spine_names = "brtl"
    _spine_auto_position_order = "bltr"

    def update_spines(self):
        left, right = self.parent_axes.get_xlim()
        bottom, top = self.parent_axes.get_ylim()

        lower_left = [left, bottom]
        lower_right = [right, bottom]
        upper_right = [right, top]
        upper_left = [left, top]

        # Each spine runs from its start corner to its end corner, chaining
        # b -> r -> t -> l around the rectangle.
        segments = (
            ("b", lower_left, lower_right),
            ("r", lower_right, upper_right),
            ("t", upper_right, upper_left),
            ("l", upper_left, lower_left),
        )
        for name, start, end in segments:
            self[name].data = np.array((start, end))

        super().update_spines()


class EllipticalFrame(BaseFrame):
    """
    An elliptical frame.

    Spines are the ellipse inscribed in the axes limits (``"c"``) plus its
    horizontal (``"h"``) and vertical (``"v"``) centre lines.
    """

    spine_names = "chv"
    _spine_auto_position_order = "chv"

    def update_spines(self):
        xlo, xhi = self.parent_axes.get_xlim()
        ylo, yhi = self.parent_axes.get_ylim()

        cx = 0.5 * (xhi + xlo)
        cy = 0.5 * (yhi + ylo)

        rx = cx - xlo
        ry = cy - ylo

        n_points = 1000
        theta = np.linspace(0.0, 2 * np.pi, n_points)

        ellipse = np.column_stack([cx + rx * np.cos(theta), cy + ry * np.sin(theta)])
        horizontal = np.column_stack(
            [np.linspace(xlo, xhi, n_points), np.full(n_points, cy)]
        )
        vertical = np.column_stack(
            [np.full(n_points, cx), np.linspace(ylo, yhi, n_points)]
        )

        self["c"].data = ellipse
        self["h"].data = horizontal
        self["v"].data = vertical

        super().update_spines()

    def _update_patch_path(self):
        """Build the patch path from the outer ellipse only, excluding the
        centre lines.
        """
        self.update_spines()
        outline = self["c"].data

        if self._path is not None:
            self._path.vertices = outline
        else:
            self._path = Path(outline)

    def draw(self, renderer):
        """Draw only the outer ellipse, not the centre lines.

        FIXME: we may want to add a general method to give the user control
        over which spines are drawn.
        """
        xy = self["c"]._get_pixel()
        ellipse = Line2D(
            xy[:, 0],
            xy[:, 1],
            linewidth=self._linewidth,
            color=self._color,
            zorder=1000,
        )
        ellipse.draw(renderer)
