import aioitertools
import jmespath
from botocore.exceptions import PaginationError
from botocore.paginate import PageIterator, Paginator
from botocore.utils import merge_dicts, set_value_from_jmespath


class AioPageIterator(PageIterator):
    def __aiter__(self):
        return self.__anext__()

    async def __anext__(self):
        current_kwargs = self._op_kwargs
        previous_next_token = None
        next_token = {key: None for key in self._input_token}
        if self._starting_token is not None:
            # If the starting token exists, populate the next_token with the
            # values inside it. This ensures that we have the service's
            # pagination token on hand if we need to truncate after the
            # first response.
            next_token = self._parse_starting_token()[0]
        # The number of items from result_key we've seen so far.
        total_items = 0
        first_request = True
        primary_result_key = self.result_keys[0]
        starting_truncation = 0
        self._inject_starting_params(current_kwargs)

        while True:
            response = await self._make_request(current_kwargs)
            parsed = self._extract_parsed_response(response)
            if first_request:
                # The first request is handled differently.  We could
                # possibly have a resume/starting token that tells us where
                # to index into the retrieved page.
                if self._starting_token is not None:
                    starting_truncation = self._handle_first_request(
                        parsed, primary_result_key, starting_truncation
                    )
                first_request = False
                self._record_non_aggregate_key_values(parsed)
            else:
                # If this isn't the first request, we have already sliced into
                # the first request and had to make additional requests after.
                # We no longer need to add this to truncation.
                starting_truncation = 0
            current_response = primary_result_key.search(parsed)
            if current_response is None:
                current_response = []
            num_current_response = len(current_response)
            truncate_amount = 0
            if self._max_items is not None:
                truncate_amount = (
                    total_items + num_current_response - self._max_items
                )

            if truncate_amount > 0:
                self._truncate_response(
                    parsed,
                    primary_result_key,
                    truncate_amount,
                    starting_truncation,
                    next_token,
                )
                yield response
                break
            else:
                yield response
                total_items += num_current_response
                next_token = self._get_next_token(parsed)
                if all(t is None for t in next_token.values()):
                    break
                if (
                    self._max_items is not None
                    and total_items == self._max_items
                ):
                    # We're on a page boundary so we can set the current
                    # next token to be the resume token.
                    self.resume_token = next_token
                    break
                if (
                    previous_next_token is not None
                    and previous_next_token == next_token
                ):
                    message = (
                        f"The same next token was received "
                        f"twice: {next_token}"
                    )
                    raise PaginationError(message=message)
                self._inject_token_into_kwargs(current_kwargs, next_token)
                previous_next_token = next_token

    async def search(self, expression):
        compiled = jmespath.compile(expression)
        async for page in self:
            results = compiled.search(page)
            if isinstance(results, list):
                for element in results:
                    yield element  # unfortunately yield from not avail from async f
            else:
                yield results

    def result_key_iters(self):
        teed_results = aioitertools.tee(self, len(self.result_keys))
        return [
            ResultKeyIterator(i, result_key)
            for i, result_key in zip(teed_results, self.result_keys)
        ]

    async def build_full_result(self):
        complete_result = {}
        async for response in self:
            page = response
            # We want to try to catch operation object pagination
            # and format correctly for those. They come in the form
            # of a tuple of two elements: (http_response, parsed_responsed).
            # We want the parsed_response as that is what the page iterator
            # uses. We can remove it though once operation objects are removed.
            if isinstance(response, tuple) and len(response) == 2:
                page = response[1]
            # We're incrementally building the full response page
            # by page.  For each page in the response we need to
            # inject the necessary components from the page
            # into the complete_result.
            for result_expression in self.result_keys:
                # In order to incrementally update a result key
                # we need to search the existing value from complete_result,
                # then we need to search the _current_ page for the
                # current result key value.  Then we append the current
                # value onto the existing value, and re-set that value
                # as the new value.
                result_value = result_expression.search(page)
                if result_value is None:
                    continue
                existing_value = result_expression.search(complete_result)
                if existing_value is None:
                    # Set the initial result
                    set_value_from_jmespath(
                        complete_result,
                        result_expression.expression,
                        result_value,
                    )
                    continue
                # Now both result_value and existing_value contain something
                if isinstance(result_value, list):
                    existing_value.extend(result_value)
                elif isinstance(result_value, (int, float, str)):
                    # Modify the existing result with the sum or concatenation
                    set_value_from_jmespath(
                        complete_result,
                        result_expression.expression,
                        existing_value + result_value,
                    )
        merge_dicts(complete_result, self.non_aggregate_part)
        if self.resume_token is not None:
            complete_result['NextToken'] = self.resume_token
        return complete_result


class AioPaginator(Paginator):
    PAGE_ITERATOR_CLS = AioPageIterator


class ResultKeyIterator:
    """Iterates over the results of paginated responses.

    Each iterator is associated with a single result key.
    Iterating over this object will give you each element in
    the result key list.

    :param pages_iterator: An iterator that will give you
        pages of results (a ``PageIterator`` class).
    :param result_key: The JMESPath expression representing
        the result key.

    """

    def __init__(self, pages_iterator, result_key):
        self._pages_iterator = pages_iterator
        self.result_key = result_key

    def __aiter__(self):
        return self.__anext__()

    async def __anext__(self):
        async for page in self._pages_iterator:
            results = self.result_key.search(page)
            if results is None:
                results = []
            for result in results:
                yield result  # yield from not avail from async func
