# Licensed under a 3-clause BSD style license - see LICENSE.rst


import numpy as np

from matplotlib import rcParams
from matplotlib.text import Text
import matplotlib.transforms as mtransforms

from .frame import RectangularFrame


class AxisLabels(Text):
    """
    Text artist that draws the axis label of a coordinate on a frame.

    The same label text is drawn once for every visible frame axis. It is
    rotated to follow the frame at the mid-point of that axis. It is then
    offset outwards, past the tick labels when there are any, or past the
    frame itself otherwise.

    Parameters
    ----------
    frame : frame object
        Frame the label is attached to. It is indexed by axis name
        (``frame[axis]``), supports ``in`` and ``keys()``, and each item must
        provide ``_halfway_x_y_angle()``.
    minpad : float or mapping, optional
        Padding between the label and the tick labels (or frame), in units
        of the label font size. Either a single number used for every axis,
        or a mapping from axis name to number.
    *args, **kwargs
        Passed on to `matplotlib.text.Text`. If not given explicitly,
        ``weight``, ``size`` and ``color`` default to the
        ``axes.labelweight``, ``axes.labelsize`` and ``axes.labelcolor``
        rcParams respectively.
    """

    def __init__(self, frame, minpad=1, *args, **kwargs):

        # Use rcParams if the following parameters were not specified explicitly
        if 'weight' not in kwargs:
            kwargs['weight'] = rcParams['axes.labelweight']
        if 'size' not in kwargs:
            kwargs['size'] = rcParams['axes.labelsize']
        if 'color' not in kwargs:
            kwargs['color'] = rcParams['axes.labelcolor']

        self._frame = frame
        super().__init__(*args, **kwargs)
        self.set_clip_on(True)
        self.set_visible_axes('all')
        self.set_ha('center')
        self.set_va('center')
        self._minpad = minpad
        self._visibility_rule = 'labels'

    def get_minpad(self, axis):
        """
        Return the padding for ``axis``, in units of the label font size.

        If ``minpad`` was given as a mapping, the entry for ``axis`` is
        returned. If it was given as a single number, that number is
        returned for every axis.
        """
        try:
            return self._minpad[axis]
        except TypeError:
            # A scalar minpad is not subscriptable: it applies to every axis.
            return self._minpad

    def set_visible_axes(self, visible_axes):
        """
        Set the frame axes on which the label may be drawn.

        Parameters
        ----------
        visible_axes : str or iterable of str
            ``'all'`` for every axis of the frame, or an iterable of axis
            names (a string such as ``'lb'`` is iterated character by
            character).
        """
        self._visible_axes = visible_axes

    def get_visible_axes(self):
        """
        Return the axis names on which the label may be drawn.

        For ``'all'`` these are the frame's keys. Otherwise they are the
        requested axes that actually exist in the frame.
        """
        if self._visible_axes == 'all':
            return self._frame.keys()
        else:
            return [x for x in self._visible_axes if x in self._frame]

    def set_minpad(self, minpad):
        """
        Set the label padding, either as one number or as a mapping per axis.
        """
        self._minpad = minpad

    def set_visibility_rule(self, value):
        """
        Set when the label is drawn on a visible axis.

        Parameters
        ----------
        value : {'always', 'labels', 'ticks'}
            ``'always'`` draws on every visible axis. ``'ticks'`` draws only
            on axes whose tick locations are non-empty. ``'labels'`` draws
            only if this coordinate has any tick-label bounding boxes.

        Raises
        ------
        ValueError
            If ``value`` is not one of the allowed rules.
        """
        allowed = ['always', 'labels', 'ticks']
        if value not in allowed:
            raise ValueError(f"Axis label visibility rule must be one of {' / '.join(allowed)}")

        self._visibility_rule = value

    def get_visibility_rule(self):
        """Return the current visibility rule (``'always'``, ``'labels'`` or ``'ticks'``)."""
        return self._visibility_rule

    def draw(self, renderer, bboxes, ticklabels_bbox,
             coord_ticklabels_bbox, ticks_locs, visible_ticks):
        """
        Draw the axis label on each visible axis of the frame.

        Parameters
        ----------
        renderer : Matplotlib renderer
            Renderer used for drawing and for converting points to pixels.
        bboxes : list
            The window extent of every label drawn is appended to this list.
        ticklabels_bbox : dict
            Two-level mapping whose innermost values are lists of tick-label
            bounding boxes. Every box from every entry is taken into account
            so the label clears all tick labels.
        coord_ticklabels_bbox : dict
            Mapping from axis name to a list of tick-label bounding boxes for
            this coordinate. It is used by the ``'labels'`` visibility rule.
            For a `RectangularFrame` it is updated in place: each handled axis
            is set to ``[union_bbox]``, or to ``[None]`` when there are no
            tick labels.
        ticks_locs : dict
            Mapping from axis name to tick locations. It is used by the
            ``'ticks'`` visibility rule.
        visible_ticks : container
            Axis names on which tick labels are shown. On the other axes the
            label is positioned relative to the frame instead.
        """
        if not self.get_visible():
            return

        text_size = renderer.points_to_pixels(self.get_size())
        rule = self.get_visibility_rule()

        # Flatten the bboxes for all coords and all axes, so the label is
        # placed outside *every* tick label. Entries can be None (this method
        # stores ``[None]`` into ``coord_ticklabels_bbox``, which the caller
        # may share with ``ticklabels_bbox``). Bbox.union cannot handle None,
        # so such entries are dropped here rather than only checking the
        # first element.
        ticklabels_bbox_list = [bb
                                for bbcoord in ticklabels_bbox.values()
                                for bbaxis in bbcoord.values()
                                for bb in bbaxis
                                if bb is not None]

        for axis in self.get_visible_axes():
            if rule == 'ticks' and not ticks_locs[axis]:
                continue
            if rule == 'labels' and not coord_ticklabels_bbox:
                continue

            padding = text_size * self.get_minpad(axis)

            # Find position of the axis label. For now we pick the mid-point
            # along the path but in future we could allow this to be a
            # parameter.
            x, y, normal_angle = self._frame[axis]._halfway_x_y_angle()

            # Run the label along the frame, flipping it when it would
            # otherwise be drawn upside down.
            label_angle = (normal_angle - 90.) % 360.
            if 135 < label_angle < 225:
                label_angle += 180
            self.set_rotation(label_angle)

            # Find label position by looking at the bounding box of ticks'
            # labels and the image. It sets the default padding at 1 times the
            # axis label font size which can also be changed by setting
            # the minpad parameter.

            if isinstance(self._frame, RectangularFrame):

                if ticklabels_bbox_list:
                    union_bbox = mtransforms.Bbox.union(ticklabels_bbox_list)
                else:
                    union_bbox = None
                coord_ticklabels_bbox[axis] = [union_bbox]

                # Without visible tick labels, pad from the frame position
                # returned by _halfway_x_y_angle() instead.
                visible = axis in visible_ticks and union_bbox is not None

                if axis == 'l':
                    if visible:
                        x = union_bbox.xmin
                    x = x - padding

                elif axis == 'r':
                    if visible:
                        x = union_bbox.x1
                    x = x + padding

                elif axis == 'b':
                    if visible:
                        y = union_bbox.ymin
                    y = y - padding

                elif axis == 't':
                    if visible:
                        y = union_bbox.y1
                    y = y + padding

            else:  # arbitrary axis
                # Move outwards along the axis normal by the padding plus an
                # extra 1.5 label heights.
                offset = padding + text_size * 1.5
                x = x + np.cos(np.radians(normal_angle)) * offset
                y = y + np.sin(np.radians(normal_angle)) * offset

            self.set_position((x, y))
            super().draw(renderer)

            bb = super().get_window_extent(renderer)
            bboxes.append(bb)
