# Copyright (C) 2012 Anaconda, Inc
# SPDX-License-Identifier: BSD-3-Clause
"""JLAP consumer."""
from __future__ import annotations

import io
import json
import logging
import pathlib
import pprint
import re
import time
from contextlib import contextmanager
from hashlib import blake2b
from typing import Iterator

import jsonpatch
import zstandard
from requests import HTTPError

from conda.base.context import context
from conda.gateways.connection import Response, Session
from conda.gateways.repodata import (
    ETAG_KEY,
    LAST_MODIFIED_KEY,
    RepodataCache,
    RepodataState,
)

from .core import JLAP

# Module-level logger shared by every function in this file.
log = logging.getLogger(__name__)


# Digest length, in bytes, passed to blake2b() by hash().
DIGEST_SIZE = 32  # 256 bits

# String keys used in the cached repodata state and in the jlap footer.
JLAP_KEY = "jlap"  # state entry holding the saved jlap fetch state
HEADERS = "headers"  # saved response headers inside the jlap state
NOMINAL_HASH = "blake2_256_nominal"  # hash of the equivalent upstream json
ON_DISK_HASH = "blake2_256"  # hash of the json as serialized on disk
LATEST = "latest"  # footer field naming the hash of the newest revision
JLAP_UNAVAILABLE = "jlap_unavailable"  # not referenced in the visible code
# state key stamped with time.time_ns() when the .json.zst request returns 404
ZSTD_UNAVAILABLE = "zstd_unavailable"

# save these headers. at least etag, last-modified, cache-control plus a few
# useful extras.
# Names are lowercase because response header names are lowercased before
# being compared against this set.
STORE_HEADERS = {
    "etag",
    "last-modified",
    "cache-control",
    "content-range",
    "content-length",
    "date",
    "content-type",
    "content-encoding",
}


def hash():
    """Return a fresh BLAKE2b hasher that produces ``DIGEST_SIZE``-byte digests.

    Note that inside this module the name shadows the builtin ``hash``.
    """
    hasher = blake2b(digest_size=DIGEST_SIZE)
    return hasher


class Jlap304NotModified(Exception):
    """Raised when a ``.jlap`` response has status 304 (Not Modified)."""


class JlapSkipZst(Exception):
    """Control-flow signal: skip the ``.json.zst`` download and use ``.json``."""


class JlapPatchNotFound(LookupError):
    """Raised when no chain of patches leads from the local revision."""


def process_jlap_response(response: Response, pos=0, iv=b""):
    """Parse a ``.jlap`` HTTP response into a buffer and the state to persist.

    :param response: Response for the (possibly ranged) ``.jlap`` request.
    :param pos: Offset the request started from; forwarded to
        ``JLAP.from_lines``.
    :param iv: Initial checksum value forwarded to ``JLAP.from_lines``.
    :return: ``(buffer, new_state)`` where ``new_state`` holds the stored
        response headers, the next ``iv``, the footer's ``pos`` and the
        decoded ``footer``.
    :raises Jlap304NotModified: If the response status is 304.
    """
    # A 304 has nothing to parse; a buffer containing only the cached footer
    # could be returned here instead of raising.
    if response.status_code == 304:
        raise Jlap304NotModified()

    def body_lines() -> Iterator[bytes]:
        yield from response.iter_lines(delimiter=b"\n")  # type: ignore

    buffer = JLAP.from_lines(body_lines(), iv, pos)

    # The second-to-last row is the footer; its position is where the next
    # ranged request resumes.
    pos, raw_footer, _ = buffer[-2]
    footer = json.loads(raw_footer)

    # we need to save etag, last-modified, cache-control
    saved_headers = {}
    for name, value in response.headers.items():
        lowered = name.lower()
        if lowered in STORE_HEADERS:
            saved_headers[lowered] = value

    # new iv == initial iv if nothing changed
    next_iv = buffer[-3][-1]

    new_state = {
        "headers": saved_headers,
        "iv": next_iv,
        "pos": pos,
        "footer": footer,
    }
    return buffer, new_state


def fetch_jlap(url, pos=0, etag=None, iv=b"", ignore_etag=True, session=None):
    """Request ``url`` with ``request_jlap`` and parse the response.

    ``pos``, ``etag``, ``ignore_etag`` and ``session`` shape the HTTP request;
    ``pos`` and ``iv`` seed the parsing done by ``process_jlap_response``.
    Returns whatever ``process_jlap_response`` returns.
    """
    request_options = {
        "pos": pos,
        "etag": etag,
        "ignore_etag": ignore_etag,
        "session": session,
    }
    jlap_response = request_jlap(url, **request_options)
    return process_jlap_response(jlap_response, pos=pos, iv=iv)


def request_jlap(
    url, pos=0, etag=None, ignore_etag=True, session: Session | None = None
):
    """Return the part of the remote .jlap file we are interested in.

    :param url: Address of the ``.jlap`` file.
    :param pos: When truthy, only bytes from this offset onward are requested
        through a ``range`` header.
    :param etag: Value for ``if-none-match``; sent only when ``ignore_etag``
        is false.
    :param ignore_etag: Suppress the conditional header even if ``etag`` is set.
    :param session: Session used for the GET request; must not be ``None``.
    :return: The streaming response.
    :raises HTTPError: From ``raise_for_status()``, or when a range request
        comes back with a status outside 206/304/404/416.
    """
    headers = {}
    if pos:
        headers["range"] = f"bytes={pos}-"
    if etag and not ignore_etag:
        headers["if-none-match"] = etag

    log.debug("%s %s", url, headers)

    assert session is not None

    timeout = (
        context.remote_connect_timeout_secs,
        context.remote_read_timeout_secs,
    )
    response = session.get(url, stream=True, headers=headers, timeout=timeout)
    response.raise_for_status()

    if not response.request:
        log.debug("response without request.")
    else:
        log.debug("request headers: %s", pprint.pformat(response.request.headers))

    # Only log the headers that would also be stored.
    interesting = {}
    for name, value in response.headers.items():
        if name.lower() in STORE_HEADERS:
            interesting[name] = value
    log.debug("response headers: %s", pprint.pformat(interesting))
    log.debug("status: %d", response.status_code)

    # 200 is also a possibility that we'd rather not deal with; if the
    # server can't do range requests, also mark jlap as unavailable. Which
    # status codes mean 'try again' instead of 'it will never work'?
    range_statuses = (206, 304, 404, 416)
    if "range" in headers and response.status_code not in range_statuses:
        raise HTTPError(
            f"Unexpected response code for range request {response.status_code}",
            response=response,
        )

    log.info("%s", response)

    return response


def format_hash(hash):
    """Shorten a digest string to its first 16 characters plus an ellipsis."""
    prefix = hash[:16]
    ellipsis = "\N{HORIZONTAL ELLIPSIS}"
    return prefix + ellipsis


def find_patches(patches, have, want):
    """Select the patches that lead from revision ``have`` to revision ``want``.

    ``patches`` is scanned from the end. Each patch whose ``"to"`` matches the
    revision currently being sought is collected, and the search continues
    from that patch's ``"from"``. The result is therefore ordered from the
    last patch to apply to the first.

    :raises JlapPatchNotFound: If the chain never reaches ``have``.
    """
    chain = []
    target = want
    for candidate in reversed(patches):
        if have == target:
            break
        if candidate["to"] != target:
            continue
        chain.append(candidate)
        target = candidate["from"]

    if have != target:
        message = f"No patch from local revision {format_hash(have)}"
        log.debug(message)
        raise JlapPatchNotFound(message)

    return chain


def apply_patches(data, apply):
    """Apply JSON patches to ``data`` in place, consuming ``apply``.

    Patches are popped from the end of ``apply``, so the list is empty when
    the function returns. Each entry supplies ``"from"`` and ``"to"`` (used
    for logging) and ``"patch"`` (the JSON Patch operations). Nothing is
    returned; callers see the result through the mutated ``data`` object.
    """
    while apply:
        patch = apply.pop()
        summary = (
            f"{format_hash(patch['from'])} \N{RIGHTWARDS ARROW} {format_hash(patch['to'])}, "
            f"{len(patch['patch'])} steps"
        )
        log.debug(summary)
        operations = jsonpatch.JsonPatch(patch["patch"])
        data = operations.apply(data, in_place=True)


def withext(url, ext):
    """Return ``url`` with its final ``.suffix`` replaced by ``ext``.

    A ``url`` that does not end in a dot followed by word characters is
    returned unchanged.
    """
    trailing_suffix = re.compile(r"(\.\w+)$")
    return trailing_suffix.sub(ext, url)


@contextmanager
def timeme(message):
    """Log, at debug level, how long the managed block took.

    ``message`` is used as the prefix of the log line. Nothing is logged if
    the managed block raises.
    """
    started = time.monotonic()
    yield
    elapsed = time.monotonic() - started
    log.debug("%sTook %0.02fs", message, elapsed)


def build_headers(json_path: pathlib.Path, state: RepodataState):
    """Caching headers for a path and state.

    Returns ``{"if-none-match": <etag>}`` when ``json_path`` exists and
    ``state`` holds a truthy ``"_etag"``; otherwise an empty dict.
    """
    # simplify if we require state to be empty when json_path is missing.
    if not json_path.exists():
        return {}
    etag = state.get("_etag")
    return {"if-none-match": etag} if etag else {}


class HashWriter(io.RawIOBase):
    """Write-only stream that hashes every chunk before forwarding it.

    Each ``write()`` feeds the bytes to ``hasher.update()`` and then writes
    them to ``backing``. Closing the writer closes ``backing`` as well.
    """

    def __init__(self, backing, hasher):
        """Wrap ``backing`` (a writable binary stream) and ``hasher``."""
        super().__init__()
        self.backing = backing
        self.hasher = hasher

    def writable(self) -> bool:
        """Report that this stream accepts writes."""
        return True

    def write(self, b: bytes):
        """Hash ``b``, write it to the backing stream, and return its result."""
        self.hasher.update(b)
        return self.backing.write(b)

    def close(self):
        """Close the backing stream and mark this writer as closed.

        Calling the base ``close()`` sets ``closed``, so the ``IOBase``
        finalizer does not close the backing stream a second time. Repeated
        calls are no-ops.
        """
        if self.closed:
            return
        try:
            self.backing.close()
        finally:
            # Mark closed even if the backing stream failed to close.
            super().close()


def download_and_hash(
    hasher,
    url,
    json_path,
    session: Session,
    state: RepodataState | None,
    is_zst=False,
    dest_path: pathlib.Path | None = None,
):
    """Download url if it doesn't exist, passing bytes through hasher.update().

    A conditional request is made using the etag in ``state`` when
    ``json_path`` exists. The body is written only for a 200 response; with
    ``is_zst`` it is zstd-decompressed first, so the hasher sees the
    decompressed bytes.

    :param hasher: Object with ``update()`` that receives the written bytes.
    :param url: Address to download.
    :param json_path: Path of old cached data (ignore etag if not exists).
    :param session: Session used for the GET request.
    :param state: Cache state supplying the etag; a falsy value is replaced
        by an empty ``RepodataState``.
    :param is_zst: Treat the response body as zstd-compressed.
    :param dest_path: Path to write new data; defaults to ``json_path``.
    :return: The response, which can be 304 not modified.
    :raises HTTPError: From ``raise_for_status()`` on an error status.
    """
    if dest_path is None:
        dest_path = json_path
    state = state or RepodataState()
    headers = build_headers(json_path, state)
    timeout = context.remote_connect_timeout_secs, context.remote_read_timeout_secs
    response = session.get(url, stream=True, timeout=timeout, headers=headers)
    log.debug("%s %s", url, response.headers)
    response.raise_for_status()
    # Number of bytes received from the server (compressed size for zst).
    length = 0
    # is there a status code for which we must clear the file?
    if response.status_code == 200:
        if is_zst:
            decompressor = zstandard.ZstdDecompressor()
            # closefd=True makes closing the zstd writer close the HashWriter,
            # which in turn closes the destination file.
            writer = decompressor.stream_writer(
                HashWriter(dest_path.open("wb"), hasher), closefd=True  # type: ignore
            )
        else:
            writer = HashWriter(dest_path.open("wb"), hasher)
        with writer as repodata:
            for block in response.iter_content(chunk_size=1 << 14):
                length += len(block)
                repodata.write(block)
    if response.request:
        log.info("Download %d bytes %r", length, response.request.headers)
    return response  # can be 304 not modified


def request_url_jlap_state(
    url,
    state: RepodataState,
    full_download=False,
    *,
    session: Session,
    cache: RepodataCache,
    temp_path: pathlib.Path,
) -> dict | None:
    """Bring cached repodata up to date, via ``.jlap`` patches when possible.

    One of two strategies is used:

    * Full download: taken when ``full_download`` is set, when there is no
      usable cached json and nominal hash, or when ``state`` says jlap should
      not be checked. ``.json.zst`` is tried first if ``state`` allows it,
      with a fallback to ``.json``. The data is written to ``temp_path`` and
      hashed.
    * Incremental: the ``.jlap`` file is fetched from the saved position and
      refetched from the start if that fails. Any patches leading from the
      local revision to the latest one are applied to the cached json, and
      the result is written to ``temp_path``.

    ``state`` is updated in place with hashes, caching headers and jlap state.

    :param url: Address of the repodata; its suffix is swapped to derive the
        ``.json``, ``.json.zst`` and ``.jlap`` addresses.
    :param state: Mutable cache state for this repodata.
    :param full_download: Skip jlap and download the complete file.
    :param session: Session used for all HTTP requests.
    :param cache: Cache providing the json path, stored content and state.
    :param temp_path: Where new or patched repodata is written.
    :return: The parsed, patched repodata when patches were applied;
        otherwise ``None`` (no patches needed, or the cache changed during
        the fetch), or the result of the recursive full download.
    :raises HTTPError: For download errors other than a 404 on ``.json.zst``,
        or from the unguarded whole-file ``.jlap`` retry.
    """
    jlap_state = state.get(JLAP_KEY, {})
    headers = jlap_state.get(HEADERS, {})
    json_path = cache.cache_path_json

    buffer = JLAP()  # type checks

    if (
        full_download
        or not (NOMINAL_HASH in state and json_path.exists())
        or not state.should_check_format("jlap")
    ):
        hasher = hash()
        with timeme(f"Download complete {url} "):
            # Don't deal with 304 Not Modified if hash unavailable e.g. if
            # cached without jlap
            if NOMINAL_HASH not in state:
                state.pop(ETAG_KEY, None)
                state.pop(LAST_MODIFIED_KEY, None)

            try:
                if state.should_check_format("zst"):
                    response = download_and_hash(
                        hasher,
                        withext(url, ".json.zst"),
                        json_path,  # makes conditional request if exists
                        dest_path=temp_path,  # writes to
                        session=session,
                        state=state,
                        is_zst=True,
                    )
                else:
                    raise JlapSkipZst()
            except (JlapSkipZst, HTTPError) as e:
                # Only a 404 on .json.zst falls back to .json; any other HTTP
                # error propagates.
                if isinstance(e, HTTPError) and e.response.status_code != 404:
                    raise
                if not isinstance(e, JlapSkipZst):
                    # don't update last-checked timestamp on skip
                    state.set_has_format("zst", False)
                    state[ZSTD_UNAVAILABLE] = time.time_ns()  # alternate method
                response = download_and_hash(
                    hasher,
                    withext(url, ".json"),
                    json_path,
                    dest_path=temp_path,
                    session=session,
                    state=state,
                )

            # will we use state['headers'] for caching against
            state["_mod"] = response.headers.get("last-modified")
            state["_etag"] = response.headers.get("etag")
            state["_cache_control"] = response.headers.get("cache-control")

        # was not re-hashed if 304 not modified
        if response.status_code == 200:
            state[NOMINAL_HASH] = state[ON_DISK_HASH] = hasher.hexdigest()

        have = state[NOMINAL_HASH]

        # a jlap buffer with zero patches. the buffer format is (position,
        # payload, checksum) where position is the offset from the beginning of
        # the file; payload is the leading or trailing checksum or other data;
        # and checksum is the running checksum for the file up to that point.
        buffer = JLAP([[-1, "", ""], [0, json.dumps({LATEST: have}), ""], [1, "", ""]])

    else:
        have = state[NOMINAL_HASH]
        # have_hash = state.get(ON_DISK_HASH)

        need_jlap = True
        try:
            iv_hex = jlap_state.get("iv", "")
            pos = jlap_state.get("pos", 0)
            # NOTE(review): process_jlap_response stores headers under
            # lowercase HTTP names (e.g. "etag"); confirm ETAG_KEY matches
            # that key, otherwise this lookup is always None.
            etag = headers.get(ETAG_KEY, None)
            jlap_url = withext(url, ".jlap")
            log.debug("Fetch %s from iv=%s, pos=%s", jlap_url, iv_hex, pos)
            # wrong to read state outside of function, and totally rebuild inside
            buffer, jlap_state = fetch_jlap(
                jlap_url,
                pos=pos,
                etag=etag,
                iv=bytes.fromhex(iv_hex),
                session=session,
                ignore_etag=False,
            )
            state.set_has_format("jlap", True)
            need_jlap = False
        except ValueError:
            log.info("Checksum not OK")
        except IndexError as e:
            log.info("Incomplete file?", exc_info=e)
        except HTTPError as e:
            # If we get a 416 Requested range not satisfiable, the server-side
            # file may have been truncated and we need to fetch from 0
            if e.response.status_code == 404:
                # No .jlap on the server: remember that and download in full.
                state.set_has_format("jlap", False)
                return request_url_jlap_state(
                    url,
                    state,
                    full_download=True,
                    session=session,
                    cache=cache,
                    temp_path=temp_path,
                )
            log.exception("Requests error")

        if need_jlap:  # retry whole file, if range failed
            try:
                buffer, jlap_state = fetch_jlap(withext(url, ".jlap"), session=session)
            except (ValueError, IndexError) as e:
                log.exception("Error parsing jlap", exc_info=e)
                # a 'latest' hash that we can't achieve, triggering later error handling
                buffer = JLAP(
                    [[-1, "", ""], [0, json.dumps({LATEST: "0" * 32}), ""], [1, "", ""]]
                )
                state.set_has_format("jlap", False)

        state[JLAP_KEY] = jlap_state

    with timeme("Apply Patches "):
        # buffer[0] == previous iv
        # buffer[1:-2] == patches
        # buffer[-2] == footer = new_state["footer"]
        # buffer[-1] == trailing checksum

        patches = [json.loads(patch) for _, patch, _ in buffer.body]
        _, footer, _ = buffer.penultimate
        want = json.loads(footer)[LATEST]

        try:
            apply = find_patches(patches, have, want)
            log.info(
                f"Apply {len(apply)} patches "
                f"{format_hash(have)} \N{RIGHTWARDS ARROW} {format_hash(want)}"
            )

            if apply:
                with timeme("Load "):
                    # we haven't loaded repodata yet; it could fail to parse, or
                    # have the wrong hash.
                    # if this fails, then we also need to fetch again from 0
                    repodata_json = json.loads(cache.load())
                    # XXX cache.state must equal what we started with, otherwise
                    # bail with 'repodata on disk' (indicating another process
                    # downloaded repodata.json in parallel with us)
                    if have != cache.state.get(NOMINAL_HASH):  # or check mtime_ns?
                        log.warning("repodata cache changed during jlap fetch.")
                        return None

                apply_patches(repodata_json, apply)

                with timeme("Write changed "), temp_path.open("wb") as repodata:
                    hasher = hash()
                    HashWriter(repodata, hasher).write(
                        json.dumps(repodata_json).encode("utf-8")
                    )

                    # actual hash of serialized json
                    state[ON_DISK_HASH] = hasher.hexdigest()

                    # hash of equivalent upstream json
                    state[NOMINAL_HASH] = want

                    # avoid duplicate parsing
                    return repodata_json
            else:
                assert state[NOMINAL_HASH] == want

        except (JlapPatchNotFound, json.JSONDecodeError) as e:
            if isinstance(e, JlapPatchNotFound):
                # 'have' hash not mentioned in patchset
                #
                # XXX or skip jlap at top of fn; make sure it is not
                # possible to download the complete json twice
                log.info(
                    "Current repodata.json %s not found in patchset. Re-download repodata.json",
                    format_hash(have),
                )

            assert not full_download, "Recursion error"  # pragma: no cover

            return request_url_jlap_state(
                url,
                state,
                full_download=True,
                session=session,
                cache=cache,
                temp_path=temp_path,
            )
