"""An async reimplementation of the blocking elements from botocore.retries.adaptive."""
import asyncio
import logging

from botocore.retries import standard, throttling

# The RateClocker from botocore uses a threading.Lock, but in a single-threaded asyncio
# program, the lock will be acquired then released by the same coroutine without
# blocking.
from botocore.retries.adaptive import RateClocker

from . import bucket

logger = logging.getLogger(__name__)


def register_retry_handler(client):
    clock = bucket.Clock()
    rate_adjustor = throttling.CubicCalculator(
        starting_max_rate=0, start_time=clock.current_time()
    )
    token_bucket = bucket.AsyncTokenBucket(max_rate=1, clock=clock)
    rate_clocker = RateClocker(clock)
    throttling_detector = standard.ThrottlingErrorDetector(
        retry_event_adapter=standard.RetryEventAdapter(),
    )
    limiter = AsyncClientRateLimiter(
        rate_adjustor=rate_adjustor,
        rate_clocker=rate_clocker,
        token_bucket=token_bucket,
        throttling_detector=throttling_detector,
        clock=clock,
    )
    client.meta.events.register(
        'before-send',
        limiter.on_sending_request,
    )
    client.meta.events.register(
        'needs-retry',
        limiter.on_receiving_response,
    )
    return limiter


class AsyncClientRateLimiter:
    """An async reimplementation of ClientRateLimiter."""

    # Most of the code here comes directly from botocore. The main change is making the
    # callbacks async.
    # This doesn't inherit from the botocore ClientRateLimiter for two reasons:
    # * the interface is slightly changed (methods are now async)
    # * we rewrote the entirety of the class anyway

    _MAX_RATE_ADJUST_SCALE = 2.0

    def __init__(
        self,
        rate_adjustor,
        rate_clocker,
        token_bucket,
        throttling_detector,
        clock,
    ):
        self._rate_adjustor = rate_adjustor
        self._rate_clocker = rate_clocker
        self._token_bucket = token_bucket
        self._throttling_detector = throttling_detector
        self._clock = clock
        self._enabled = False
        self._lock = asyncio.Lock()

    async def on_sending_request(self, request, **kwargs):
        if self._enabled:
            await self._token_bucket.acquire()

    # Hooked up to needs-retry.
    async def on_receiving_response(self, **kwargs):
        measured_rate = self._rate_clocker.record()
        timestamp = self._clock.current_time()
        async with self._lock:
            if not self._throttling_detector.is_throttling_error(**kwargs):
                new_rate = self._rate_adjustor.success_received(timestamp)
            else:
                if not self._enabled:
                    rate_to_use = measured_rate
                else:
                    rate_to_use = min(
                        measured_rate, self._token_bucket.max_rate
                    )
                new_rate = self._rate_adjustor.error_received(
                    rate_to_use, timestamp
                )
                logger.debug(
                    "Throttling response received, new send rate: %s "
                    "measured rate: %s, token bucket capacity "
                    "available: %s",
                    new_rate,
                    measured_rate,
                    self._token_bucket.available_capacity,
                )
                self._enabled = True
            await self._token_bucket.set_max_rate(
                min(new_rate, self._MAX_RATE_ADJUST_SCALE * measured_rate)
            )
