import io

from botocore.httpchecksum import (
    _CHECKSUM_CLS,
    AwsChunkedWrapper,
    FlexibleChecksumError,
    _apply_request_header_checksum,
    _handle_streaming_response,
    base64,
    conditionally_calculate_md5,
    determine_content_length,
    logger,
)

from aiobotocore._helpers import resolve_awaitable


class AioAwsChunkedWrapper(AwsChunkedWrapper):
    async def _make_chunk(self):
        # NOTE: Chunk size is not deterministic as read could return less. This
        # means we cannot know the content length of the encoded aws-chunked
        # stream ahead of time without ensuring a consistent chunk size

        raw_chunk = await resolve_awaitable(self._raw.read(self._chunk_size))
        hex_len = hex(len(raw_chunk))[2:].encode("ascii")
        self._complete = not raw_chunk

        if self._checksum:
            self._checksum.update(raw_chunk)

        if self._checksum and self._complete:
            name = self._checksum_name.encode("ascii")
            checksum = self._checksum.b64digest().encode("ascii")
            return b"0\r\n%s:%s\r\n\r\n" % (name, checksum)

        return b"%s\r\n%s\r\n" % (hex_len, raw_chunk)

    def __aiter__(self):
        return self

    async def __anext__(self):
        while not self._complete:
            return await self._make_chunk()
        raise StopAsyncIteration()


async def handle_checksum_body(
    http_response, response, context, operation_model
):
    headers = response["headers"]
    checksum_context = context.get("checksum", {})
    algorithms = checksum_context.get("response_algorithms")

    if not algorithms:
        return

    for algorithm in algorithms:
        header_name = "x-amz-checksum-%s" % algorithm
        # If the header is not found, check the next algorithm
        if header_name not in headers:
            continue

        # If a - is in the checksum this is not valid Base64. S3 returns
        # checksums that include a -# suffix to indicate a checksum derived
        # from the hash of all part checksums. We cannot wrap this response
        if "-" in headers[header_name]:
            continue

        if operation_model.has_streaming_output:
            response["body"] = _handle_streaming_response(
                http_response, response, algorithm
            )
        else:
            response["body"] = await _handle_bytes_response(
                http_response, response, algorithm
            )

        # Expose metadata that the checksum check actually occurred
        checksum_context = response["context"].get("checksum", {})
        checksum_context["response_algorithm"] = algorithm
        response["context"]["checksum"] = checksum_context
        return

    logger.info(
        f'Skipping checksum validation. Response did not contain one of the '
        f'following algorithms: {algorithms}.'
    )


async def _handle_bytes_response(http_response, response, algorithm):
    body = await http_response.content
    header_name = "x-amz-checksum-%s" % algorithm
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    checksum = checksum_cls()
    checksum.update(body)
    expected = response["headers"][header_name]
    if checksum.digest() != base64.b64decode(expected):
        error_msg = (
            "Expected checksum %s did not match calculated checksum: %s"
            % (
                expected,
                checksum.b64digest(),
            )
        )
        raise FlexibleChecksumError(error_msg=error_msg)
    return body


def apply_request_checksum(request):
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")

    if not algorithm:
        return

    if algorithm == "conditional-md5":
        # Special case to handle the http checksum required trait
        conditionally_calculate_md5(request)
    elif algorithm["in"] == "header":
        _apply_request_header_checksum(request)
    elif algorithm["in"] == "trailer":
        _apply_request_trailer_checksum(request)
    else:
        raise FlexibleChecksumError(
            error_msg="Unknown checksum variant: %s" % algorithm["in"]
        )


def _apply_request_trailer_checksum(request):
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])

    headers = request["headers"]
    body = request["body"]

    if location_name in headers:
        # If the header is already set by the customer, skip calculation
        return

    # Cannot set this as aiohttp complains
    headers["Transfer-Encoding"] = "chunked"
    if "Content-Encoding" in headers:
        # We need to preserve the existing content encoding and add
        # aws-chunked as a new content encoding.
        headers["Content-Encoding"] += ",aws-chunked"
    else:
        headers["Content-Encoding"] = "aws-chunked"
    headers["X-Amz-Trailer"] = location_name

    content_length = determine_content_length(body)
    if content_length is not None:
        # Send the decoded content length if we can determine it. Some
        # services such as S3 may require the decoded content length
        headers["X-Amz-Decoded-Content-Length"] = str(content_length)

    if isinstance(body, (bytes, bytearray)):
        body = io.BytesIO(body)

    request["body"] = AioAwsChunkedWrapper(
        body,
        checksum_cls=checksum_cls,
        checksum_name=location_name,
    )
