import asyncio
import logging
from datetime import timedelta

import dateutil.parser
from botocore.compat import total_seconds
from botocore.exceptions import ClientError, TokenRetrievalError
from botocore.tokens import (
    DeferredRefreshableToken,
    FrozenAuthToken,
    SSOTokenProvider,
    TokenProviderChain,
    _utc_now,
)

# Module-level logger used for token load/refresh diagnostics in this module.
logger = logging.getLogger(__name__)


def create_token_resolver(session):
    """Build the token provider chain for ``session``.

    The chain currently holds a single provider, the asyncio SSO token
    provider.

    :param session: The session handed to each provider in the chain.
    :returns: A ``TokenProviderChain`` wrapping the providers.
    """
    sso_provider = AioSSOTokenProvider(session)
    return TokenProviderChain(providers=[sso_provider])


class AioDeferredRefreshableToken(DeferredRefreshableToken):
    """Asyncio variant of ``DeferredRefreshableToken``.

    The token is produced lazily by awaiting ``refresh_using`` whenever the
    inherited ``_should_refresh`` policy asks for it. That policy yields a
    falsy value (no refresh needed), ``"mandatory"``, or another refresh type
    that is treated as optional. An ``asyncio.Lock`` serialises refreshes
    between tasks.
    """

    def __init__(
        self, method, refresh_using, time_fetcher=_utc_now
    ):  # noqa: E501, lgtm [py/missing-call-to-init]
        # The parent initializer is deliberately not called (see the lgtm
        # suppression); this class sets up its own state and uses an
        # asyncio.Lock so that refreshes can be awaited.
        self.method = method
        self._refresh_using = refresh_using
        self._time_fetcher = time_fetcher

        # Cached token state, guarded by ``_refresh_lock``.
        self._frozen_token = None
        self._next_refresh = None
        self._refresh_lock = asyncio.Lock()

    async def get_frozen_token(self):
        """Return the current frozen token, refreshing it first if needed."""
        await self._refresh()
        return self._frozen_token

    async def _refresh(self):
        """Refresh the token when the refresh policy requests it.

        A mandatory refresh always waits for the lock. Any other refresh is
        skipped outright when another task already holds the lock.
        """
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        must_wait = refresh_type == "mandatory"
        if not must_wait and self._refresh_lock.locked():
            # Optional refresh while another task holds the lock: don't queue.
            return None

        async with self._refresh_lock:
            await self._protected_refresh()
        return None

    async def _protected_refresh(self):
        """Perform the refresh; the caller must already hold the lock.

        The refresh policy is re-checked because another task may have
        refreshed the token while this one was waiting for the lock.

        :raises Exception: Whatever ``refresh_using`` raised, but only when
            the refresh was mandatory; optional failures are logged only.
        :raises TokenRetrievalError: If the token is expired after refreshing.
        """
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        is_mandatory = refresh_type == "mandatory"
        try:
            # Record the next refresh time up front so it is set even when
            # the attempt below fails.
            self._next_refresh = self._time_fetcher() + timedelta(
                seconds=self._attempt_timeout
            )
            self._frozen_token = await self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                refresh_type,
                exc_info=True,
            )
            if is_mandatory:
                # A mandatory refresh cannot be skipped; surface the error.
                raise

        if self._is_expired():
            # A freshly refreshed token must not already be expired.
            raise TokenRetrievalError(
                provider=self.method,
                error_msg="Token has expired and refresh failed",
            )
        return None


class AioSSOTokenProvider(SSOTokenProvider):
    """Asyncio variant of ``SSOTokenProvider``.

    Loads the cached SSO token and, when it is within the refresh window of
    its expiry, tries to refresh it through the SSO OIDC ``CreateToken`` call
    using an awaitable client.
    """

    async def _attempt_create_token(self, token):
        """Exchange a cached refresh token for a new SSO access token.

        :param token: Cached token dict containing ``clientId``,
            ``clientSecret``, ``refreshToken`` and ``registrationExpiresAt``.
        :returns: A new token dict ready to be cached. ``expiresAt`` is a
            datetime derived from the response's ``expiresIn``. The
            registration fields are copied over from ``token``.
        :raises Exception: Any error from the ``CreateToken`` call propagates
            (the caller handles ``ClientError``).
        """
        # The inherited ``_client`` is built with the session's
        # ``create_client``. On an aiobotocore session that returns an async
        # context manager rather than a client, so it must be entered before
        # ``create_token`` exists. Exiting it also closes the client instead
        # of leaking its HTTP session.
        async with self._client as client:
            response = await client.create_token(
                grantType=self._GRANT_TYPE,
                clientId=token["clientId"],
                clientSecret=token["clientSecret"],
                refreshToken=token["refreshToken"],
            )

        expires_in = timedelta(seconds=response["expiresIn"])
        new_token = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": response["accessToken"],
            "expiresAt": self._now() + expires_in,
            # Cache the registration alongside the token
            "clientId": token["clientId"],
            "clientSecret": token["clientSecret"],
            "registrationExpiresAt": token["registrationExpiresAt"],
        }
        # The service may rotate the refresh token; keep the new one if sent.
        if "refreshToken" in response:
            new_token["refreshToken"] = response["refreshToken"]
        logger.info("SSO Token refresh succeeded")
        return new_token

    async def _refresh_access_token(self, token):
        """Try to refresh ``token``, returning the new token dict or ``None``.

        ``None`` is returned (after logging the reason) when:

        - the cached token lacks a field needed for refreshing;
        - its client registration has expired;
        - the ``CreateToken`` call fails with a ``ClientError``.
        """
        keys = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        missing_keys = [k for k in keys if k not in token]
        if missing_keys:
            logger.info(
                "Unable to refresh SSO token: missing keys: %s", missing_keys
            )
            return None

        expiry = dateutil.parser.parse(token["registrationExpiresAt"])
        if total_seconds(expiry - self._now()) <= 0:
            logger.info("SSO token registration expired at %s", expiry)
            return None

        try:
            return await self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None

    async def _refresher(self):
        """Load the cached SSO token, refreshing it if it expires soon.

        A successfully refreshed token is saved back through the token
        loader. If refreshing is not possible, the cached token is returned
        unchanged.

        :returns: A ``FrozenAuthToken`` for the (possibly refreshed) token.
        """
        start_url = self._sso_config["sso_start_url"]
        session_name = self._sso_config["session_name"]
        logger.info("Loading cached SSO token for %s", session_name)
        token_dict = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(token_dict["expiresAt"])
        logger.debug("Cached SSO token expires at %s", expiration)

        remaining = total_seconds(expiration - self._now())
        if remaining < self._REFRESH_WINDOW:
            new_token_dict = await self._refresh_access_token(token_dict)
            if new_token_dict is not None:
                token_dict = new_token_dict
                # Already a datetime (computed in _attempt_create_token).
                expiration = token_dict["expiresAt"]
                self._token_loader.save_token(
                    start_url, token_dict, session_name=session_name
                )

        return FrozenAuthToken(
            token_dict["accessToken"], expiration=expiration
        )

    def load_token(self):
        """Return a lazily refreshed SSO token, or ``None`` without SSO config.

        The returned token awaits ``_refresher`` whenever it needs refreshing.
        """
        if self._sso_config is None:
            return None

        return AioDeferredRefreshableToken(
            self.METHOD, self._refresher, time_fetcher=self._now
        )
