from botocore.retryhandler import (
    ChecksumError,
    CRC32Checker,
    ExceptionRaiser,
    HTTPStatusCodeChecker,
    MaxAttemptsDecorator,
    MultiChecker,
    RetryHandler,
    ServiceErrorCodeChecker,
    _extract_retryable_exception,
    crc32,
    create_retry_action_from_config,
    logger,
)

from ._helpers import resolve_awaitable


def create_retry_handler(config, operation_name=None):
    checker = create_checker_from_retry_config(
        config, operation_name=operation_name
    )
    action = create_retry_action_from_config(
        config, operation_name=operation_name
    )
    return AioRetryHandler(checker=checker, action=action)


def create_checker_from_retry_config(config, operation_name=None):
    checkers = []
    max_attempts = None
    retryable_exceptions = []
    if '__default__' in config:
        policies = config['__default__'].get('policies', [])
        max_attempts = config['__default__']['max_attempts']
        for key in policies:
            current_config = policies[key]
            checkers.append(_create_single_checker(current_config))
            retry_exception = _extract_retryable_exception(current_config)
            if retry_exception is not None:
                retryable_exceptions.extend(retry_exception)
    if operation_name is not None and config.get(operation_name) is not None:
        operation_policies = config[operation_name]['policies']
        for key in operation_policies:
            checkers.append(_create_single_checker(operation_policies[key]))
            retry_exception = _extract_retryable_exception(
                operation_policies[key]
            )
            if retry_exception is not None:
                retryable_exceptions.extend(retry_exception)
    if len(checkers) == 1:
        # Don't need to use a MultiChecker
        return AioMaxAttemptsDecorator(checkers[0], max_attempts=max_attempts)
    else:
        multi_checker = AioMultiChecker(checkers)
        return AioMaxAttemptsDecorator(
            multi_checker,
            max_attempts=max_attempts,
            retryable_exceptions=tuple(retryable_exceptions),
        )


def _create_single_checker(config):
    if 'response' in config['applies_when']:
        return _create_single_response_checker(
            config['applies_when']['response']
        )
    elif 'socket_errors' in config['applies_when']:
        return ExceptionRaiser()


def _create_single_response_checker(response):
    if 'service_error_code' in response:
        checker = ServiceErrorCodeChecker(
            status_code=response['http_status_code'],
            error_code=response['service_error_code'],
        )
    elif 'http_status_code' in response:
        checker = HTTPStatusCodeChecker(
            status_code=response['http_status_code']
        )
    elif 'crc32body' in response:
        checker = AioCRC32Checker(header=response['crc32body'])
    else:
        # TODO: send a signal.
        raise ValueError("Unknown retry policy")
    return checker


class AioRetryHandler(RetryHandler):
    async def _call(self, attempts, response, caught_exception, **kwargs):
        """Handler for a retry.

        Intended to be hooked up to an event handler (hence the **kwargs),
        this will process retries appropriately.

        """
        checker_kwargs = {
            'attempt_number': attempts,
            'response': response,
            'caught_exception': caught_exception,
        }
        if isinstance(self._checker, MaxAttemptsDecorator):
            retries_context = kwargs['request_dict']['context'].get('retries')
            checker_kwargs.update({'retries_context': retries_context})

        if await resolve_awaitable(self._checker(**checker_kwargs)):
            result = self._action(attempts=attempts)
            logger.debug("Retry needed, action of: %s", result)
            return result
        logger.debug("No retry needed.")

    def __call__(self, *args, **kwargs):
        return self._call(*args, **kwargs)  # return awaitable


class AioMaxAttemptsDecorator(MaxAttemptsDecorator):
    async def _call(
        self, attempt_number, response, caught_exception, retries_context
    ):
        if retries_context:
            retries_context['max'] = max(
                retries_context.get('max', 0), self._max_attempts
            )

        should_retry = await self._should_retry(
            attempt_number, response, caught_exception
        )
        if should_retry:
            if attempt_number >= self._max_attempts:
                # explicitly set MaxAttemptsReached
                if response is not None and 'ResponseMetadata' in response[1]:
                    response[1]['ResponseMetadata'][
                        'MaxAttemptsReached'
                    ] = True
                logger.debug(
                    "Reached the maximum number of retry attempts: %s",
                    attempt_number,
                )
                return False
            else:
                return should_retry
        else:
            return False

    def __call__(self, *args, **kwargs):
        return self._call(*args, **kwargs)

    async def _should_retry(self, attempt_number, response, caught_exception):
        if self._retryable_exceptions and attempt_number < self._max_attempts:
            try:
                return await resolve_awaitable(
                    self._checker(attempt_number, response, caught_exception)
                )
            except self._retryable_exceptions as e:
                logger.debug(
                    "retry needed, retryable exception caught: %s",
                    e,
                    exc_info=True,
                )
                return True
        else:
            # If we've exceeded the max attempts we just let the exception
            # propagate if one has occurred.
            return await resolve_awaitable(
                self._checker(attempt_number, response, caught_exception)
            )


class AioMultiChecker(MultiChecker):
    async def _call(self, attempt_number, response, caught_exception):
        for checker in self._checkers:
            checker_response = await resolve_awaitable(
                checker(attempt_number, response, caught_exception)
            )
            if checker_response:
                return checker_response
        return False

    def __call__(self, *args, **kwargs):
        return self._call(*args, **kwargs)


class AioCRC32Checker(CRC32Checker):
    async def _call(self, attempt_number, response, caught_exception):
        if response is not None:
            return await self._check_response(attempt_number, response)
        elif caught_exception is not None:
            return self._check_caught_exception(
                attempt_number, caught_exception
            )
        else:
            raise ValueError("Both response and caught_exception are None.")

    def __call__(self, *args, **kwargs):
        return self._call(*args, **kwargs)

    async def _check_response(self, attempt_number, response):
        http_response = response[0]
        expected_crc = http_response.headers.get(self._header_name)
        if expected_crc is None:
            logger.debug(
                "crc32 check skipped, the %s header is not "
                "in the http response.",
                self._header_name,
            )
        else:
            actual_crc32 = crc32(await response[0].content) & 0xFFFFFFFF
            if not actual_crc32 == int(expected_crc):
                logger.debug(
                    "retry needed: crc32 check failed, expected != actual: "
                    "%s != %s",
                    int(expected_crc),
                    actual_crc32,
                )
                raise ChecksumError(
                    checksum_type='crc32',
                    expected_checksum=int(expected_crc),
                    actual_checksum=actual_crc32,
                )
