from __future__ import annotations

import contextlib
import logging
import pathlib
import subprocess
import time
import uuid
from collections.abc import Iterator, Sequence
from typing import Any, Literal
from urllib.parse import quote

from toolz.itertoolz import partition

from distributed import get_client
from distributed.worker import Worker

# memray is required by everything in this module, so fail at import time
# with an actionable message instead of at first use.
try:
    import memray
except ImportError:
    raise ImportError("You have to install memray to use this module.")

# Module-level logger; used to report failures while downloading reports.
logger = logging.getLogger(__name__)


def _start_memray(dask_worker: Worker, filename: str, **kwargs: Any) -> bool:
    """Start a memray ``Tracker`` on a server (worker or scheduler).

    A tracker previously started through this function on the same server is
    closed first. The capture is written to
    ``<local_directory>/<filename><server id>``.

    Parameters
    ----------
    dask_worker
        The server to profile. The open tracker is kept alive by storing an
        ``ExitStack`` holding it in the ``_memray`` attribute of this object.
    filename
        Prefix of the capture file name; the server id is appended to it.
    **kwargs
        Passed through unchanged to ``memray.Tracker``.

    Returns
    -------
    bool
        Always ``True``; callers use it as a per-server acknowledgement.
    """
    if hasattr(dask_worker, "_memray"):
        dask_worker._memray.close()

    path = pathlib.Path(dask_worker.local_directory) / (filename + str(dask_worker.id))
    # The capture is a regular file, so a leftover one must be removed with
    # unlink(); rmdir() only removes directories and raises when given a file.
    if path.exists():
        path.unlink()

    dask_worker._memray = contextlib.ExitStack()  # type: ignore[attr-defined]
    dask_worker._memray.enter_context(  # type: ignore[attr-defined]
        memray.Tracker(path, **kwargs)
    )

    return True


def _fetch_memray_profile(
    dask_worker: Worker, filename: str, report_args: Sequence[str] | Literal[False]
) -> bytes:
    """Stop the memray tracker on a server and return its profile.

    Parameters
    ----------
    dask_worker
        The server whose tracker is stored in its ``_memray`` attribute.
    filename
        The capture file name prefix that was used when the tracker was
        started; the server id is appended to it.
    report_args
        Arguments for the ``memray`` command line used to render a report.
        A leading ``"memray"`` is added when missing. If falsy, no report is
        rendered and the raw capture is returned instead.

    Returns
    -------
    bytes
        The rendered report, the raw capture if ``report_args`` is falsy, or
        ``b""`` if no tracker is registered on this server.

    Raises
    ------
    AssertionError
        If ``report_args`` already contains ``-f`` or ``-o``.
    subprocess.CalledProcessError
        If the ``memray`` command exits with a non-zero status.
    """
    if not hasattr(dask_worker, "_memray"):
        return b""
    path = pathlib.Path(dask_worker.local_directory) / (filename + str(dask_worker.id))
    # Closing the ExitStack exits the Tracker context before the file is read.
    dask_worker._memray.close()
    del dask_worker._memray

    if not report_args:
        return path.read_bytes()

    report_filename = path.with_suffix(".html")
    command = list(report_args)
    if command[0] != "memray":
        command.insert(0, "memray")
    assert "-f" not in command, "Cannot provide filename for report generation"
    assert (
        "-o" not in command
    ), "Cannot provide output filename for report generation"
    command += ["-f", str(path), "-o", str(report_filename)]
    # check=True surfaces a failed report generation as CalledProcessError
    # instead of a confusing FileNotFoundError on the missing report below.
    subprocess.run(command, check=True)
    return report_filename.read_bytes()


@contextlib.contextmanager
def memray_workers(
    directory: str | pathlib.Path = "memray-profiles",
    workers: int | None | list[str] = None,
    report_args: Sequence[str] | Literal[False] = (
        "flamegraph",
        "--temporal",
        "--leaks",
    ),
    fetch_reports_parallel: bool | int = True,
    **memray_kwargs: Any,
) -> Iterator[None]:
    """Generate a Memray profile on the workers and download the generated report.

    Example::

        with memray_workers():
            client.submit(my_function).result()

        # Or even while the computation is already running

        fut = client.submit(my_function)

        with memray_workers():
            time.sleep(10)

        fut.result()

    Parameters
    ----------
    directory : str
        The directory to save the reports to. It is created, including
        missing parent directories, if it does not exist.
    workers : int | None | list[str]
        The workers to profile. If int, the first n workers will be used.
        If None, all workers will be used.
        If list[str], the workers with the given addresses will be used.
    report_args : tuple[str]
        Particularly for native_traces=True, the reports have to be
        generated on the same host using the same Python interpreter as the
        profile was generated. Otherwise, native traces will yield unusable
        results. Therefore, we're generating the reports on the workers and
        download them afterwards. You can modify the report generation by
        providing additional arguments and we will generate the reports as::

            memray *report_args -f <filename> -o <filename>.html

        If the raw data should be fetched instead of the report, set this to
        False.

    fetch_reports_parallel : bool | int
        Fetching results is sometimes slow and it's sometimes not desired to
        wait for all workers to finish before receiving the first reports.
        This controls how many workers are fetched concurrently.

            int: Number of workers to fetch concurrently
            True: All workers concurrently
            False: One worker at a time

    **memray_kwargs
        Keyword arguments to be passed to memray.Tracker, e.g.
        {"native_traces": True}
    """
    directory = pathlib.Path(directory)
    client = get_client()
    known_workers = client.scheduler_info()["workers"]
    worker_names = {addr: winfo["name"] for addr, winfo in known_workers.items()}
    if not workers or isinstance(workers, int):
        nworkers = len(known_workers)
        if isinstance(workers, int):
            nworkers = workers
        workers = list(known_workers)[:nworkers]
    workers = list(workers)
    filename = uuid.uuid4().hex
    # Start the tracker only on the selected workers: nothing stops a tracker
    # on a worker whose report is never fetched. The call is kept outside of
    # the assert so that it is not stripped when running with ``python -O``.
    started = client.run(
        _start_memray, filename=filename, workers=workers, **memray_kwargs
    )
    assert all(started.values())
    # Sleep for a brief moment such that we get
    # a clear profiling signal when everything starts
    time.sleep(0.1)
    try:
        yield
    finally:
        directory.mkdir(parents=True, exist_ok=True)

        client = get_client()
        if fetch_reports_parallel is True:
            fetch_parallel = len(workers)
        elif fetch_reports_parallel is False:
            fetch_parallel = 1
        else:
            fetch_parallel = fetch_reports_parallel
        # Guard against a zero or negative batch size (e.g. no workers).
        fetch_parallel = max(1, fetch_parallel)
        suffix = ".html" if report_args else ".memray"

        # Slice into batches by hand so that a trailing, incomplete batch is
        # still fetched; toolz.partition silently drops it.
        for start in range(0, len(workers), fetch_parallel):
            batch = workers[start : start + fetch_parallel]
            try:
                profiles = client.run(
                    _fetch_memray_profile,
                    filename=filename,
                    report_args=report_args,
                    workers=batch,
                )
                for addr, profile in profiles.items():
                    # Worker names are percent-encoded so that they are
                    # always usable as a single file name component.
                    path = directory / quote(str(worker_names[addr]), safe="")
                    with open(str(path) + suffix, "wb") as fd:
                        fd.write(profile)

            except Exception:
                # Best effort: a failing batch must not prevent the
                # remaining workers from being fetched.
                logger.exception(
                    "Exception during report downloading from worker %s", batch
                )


@contextlib.contextmanager
def memray_scheduler(
    directory: str | pathlib.Path = "memray-profiles",
    report_args: Sequence[str] | Literal[False] = (
        "flamegraph",
        "--temporal",
        "--leaks",
    ),
    **memray_kwargs: Any,
) -> Iterator[None]:
    """Generate a Memray profile on the Scheduler and download the generated report.

    Example::

        with memray_scheduler():
            client.submit(my_function).result()

        # Or even while the computation is already running

        fut = client.submit(my_function)

        with memray_scheduler():
            time.sleep(10)

        fut.result()

    Parameters
    ----------
    directory : str
        The directory to save the reports to. It is created, including
        missing parent directories, if it does not exist.
    report_args : tuple[str]
        Particularly for native_traces=True, the reports have to be
        generated on the same host using the same Python interpreter as the
        profile was generated. Otherwise, native traces will yield unusable
        results. Therefore, we're generating the reports on the Scheduler and
        download them afterwards. You can modify the report generation by
        providing additional arguments and we will generate the reports as::

            memray *report_args -f <filename> -o <filename>.html

        If the raw data should be fetched instead of the report, set this to
        False.
    **memray_kwargs
        Keyword arguments to be passed to memray.Tracker, e.g.
        {"native_traces": True}
    """
    directory = pathlib.Path(directory)
    client = get_client()
    filename = uuid.uuid4().hex
    # Keep the call outside of the assert: with ``python -O`` assert
    # statements are stripped and the tracker would never be started.
    started = client.run_on_scheduler(
        _start_memray, filename=filename, **memray_kwargs
    )
    assert started
    # Sleep for a brief moment such that we get
    # a clear profiling signal when everything starts
    time.sleep(0.1)
    try:
        yield
    finally:
        directory.mkdir(parents=True, exist_ok=True)

        client = get_client()

        profile = client.run_on_scheduler(
            _fetch_memray_profile,
            filename=filename,
            report_args=report_args,
        )
        suffix = ".html" if report_args else ".memray"
        path = directory / "scheduler"
        with open(str(path) + suffix, "wb") as fd:
            fd.write(profile)
