from __future__ import annotations

import uuid

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils import TupleComparable


class GraphLayout(SchedulerPlugin):
    """Dynamic graph layout during computation

    This assigns (x, y) locations to all tasks quickly and dynamically as new
    tasks are added.  This scales to a few thousand nodes.

    It is commonly used with distributed/dashboard/components/scheduler.py::TaskGraph, which
    is rendered at /graph on the diagnostic dashboard.

    The plugin keeps two kinds of state:

    * the layout itself (``x``, ``y``, ``collision``) together with stable
      integer indices for nodes (``index``) and edges (``index_edge``);
    * append-only change logs (``new``, ``new_edges``, ``state_updates``,
      ``visible_updates``, ``visible_edge_updates``) that record what changed.
      Nothing in this class drains those logs except ``reset_index``; the
      consumer (presumably TaskGraph) is expected to do so.
    """

    def __init__(self, scheduler):
        """Create the plugin and lay out any tasks the scheduler already has.

        Parameters
        ----------
        scheduler
            The scheduler this plugin observes.  Only its ``tasks`` mapping
            (key -> task state with ``key``, ``priority``, ``dependencies``
            and ``dependents``) is used by this class.
        """
        self.name = f"graph-layout-{uuid.uuid4()}"
        # key -> horizontal / vertical position
        self.x = {}
        self.y = {}
        # requested (x, y) -> most recent position handed out for that request
        self.collision = {}
        self.scheduler = scheduler
        # key -> node index, (dependency key, dependent key) -> edge index
        self.index = {}
        self.index_edge = {}
        # next free row for tasks without dependencies
        self.next_y = 0
        self.next_index = 0
        self.next_edge_index = 0
        # change logs, see class docstring
        self.new = []
        self.new_edges = []
        self.state_updates = []
        self.visible_updates = []
        self.visible_edge_updates = []

        if self.scheduler.tasks:
            dependencies = {
                k: [ds.key for ds in ts.dependencies]
                for k, ts in scheduler.tasks.items()
            }
            priority = {k: ts.priority for k, ts in scheduler.tasks.items()}
            self.update_graph(
                self.scheduler,
                tasks=self.scheduler.tasks,
                dependencies=dependencies,
                priority=priority,
            )

    def update_graph(
        self, scheduler, *, dependencies=None, priority=None, tasks=None, **kwargs
    ):
        """Assign positions and indices to newly submitted tasks.

        Parameters
        ----------
        scheduler
            Scheduler whose ``tasks`` mapping is consulted; keys that are not
            in it are ignored.
        dependencies
            Mapping of key -> iterable of dependency keys.  ``None`` is
            treated as "no dependencies known".
        priority
            Mapping of key -> priority.  Keys without an entry sort with
            priority ``0``.  ``None`` is treated as an empty mapping.
        tasks
            Iterable of keys to lay out.  ``None`` is treated as empty.
        **kwargs
            Ignored; accepted so callers may pass additional information.

        Notes
        -----
        A task without dependencies is placed at ``x = 0`` on the next free
        row.  A task with dependencies is placed one column to the right of
        its right-most dependency, at the average ``y`` of its dependencies
        weighted by each dependency's number of dependents.  Dependencies that
        are not known to the scheduler are ignored: they can never be placed,
        so waiting for them would never terminate.
        """
        dependencies = {} if dependencies is None else dependencies
        priority = {} if priority is None else priority
        tasks = () if tasks is None else tasks

        def by_priority(keys):
            # Reverse order because the stack is popped from its end.
            return sorted(
                keys, key=lambda k: TupleComparable(priority.get(k, 0)), reverse=True
            )

        stack = by_priority(tasks)
        while stack:
            key = stack.pop()
            if key in self.x or key not in scheduler.tasks:
                continue

            # Only dependencies the scheduler knows about can ever be placed.
            deps = [dep for dep in dependencies.get(key, ()) if dep in scheduler.tasks]
            pending = [dep for dep in deps if dep not in self.y]
            if pending:
                # Place the missing dependencies first, then revisit ``key``.
                stack.append(key)
                stack.extend(by_priority(pending))
                continue

            if deps:
                weights = [len(scheduler.tasks[dep].dependents) for dep in deps]
                total_deps = sum(weights)
                if total_deps:
                    y = sum(
                        self.y[dep] * weight / total_deps
                        for dep, weight in zip(deps, weights)
                    )
                else:
                    # No dependency reports any dependents; fall back to the
                    # unweighted mean rather than dividing by zero.
                    y = sum(self.y[dep] for dep in deps) / len(deps)
                x = max(self.x[dep] for dep in deps) + 1
            else:
                x = 0
                y = self.next_y
                self.next_y += 1

            if (x, y) in self.collision:
                # Spot already requested: nudge below the last position handed
                # out for it and remember the new one for the next collision.
                old_x, old_y = x, y
                x, y = self.collision[(x, y)]
                y += 0.1
                self.collision[old_x, old_y] = (x, y)
            else:
                self.collision[(x, y)] = (x, y)

            self.x[key] = x
            self.y[key] = y
            self.index[key] = self.next_index
            self.next_index += 1
            self.new.append(key)
            for dep in deps:
                edge = (dep, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)

    def transition(self, key, start, finish, *args, **kwargs):
        """Record a task state change.

        Parameters
        ----------
        key
            Key of the task that changed state.
        start
            Previous state (unused).
        finish
            New state.  ``"forgotten"`` hides the node and its edges and
            drops the task from the layout; any other value is logged in
            ``state_updates``.
        *args, **kwargs
            Ignored.

        Notes
        -----
        Keys that were never laid out (or were already forgotten) are
        ignored instead of raising ``KeyError``.
        """
        node_index = self.index.get(key)
        if node_index is None:
            return

        if finish != "forgotten":
            self.state_updates.append((node_index, finish))
            return

        self.visible_updates.append((node_index, "False"))

        task = self.scheduler.tasks.get(key)
        if task is not None:
            edges = [(key, dts.key) for dts in task.dependents]
            edges.extend((dts.key, key) for dts in task.dependencies)
            for edge in edges:
                # Tolerate edges that were never registered so that the
                # cleanup below always runs and state stays consistent.
                edge_index = self.index_edge.pop(edge, None)
                if edge_index is not None:
                    self.visible_edge_updates.append((edge_index, "False"))

        del self.index[key]
        position = (self.x.pop(key, None), self.y.pop(key, None))
        # The entry is keyed by the originally requested position, so it may
        # legitimately be absent for a task that was nudged by a collision.
        self.collision.pop(position, None)

    def reset_index(self):
        """Reset the index and refill new and new_edges

        From time to time TaskGraph wants to remove invisible nodes and reset
        all of its indices.  This helps.

        All change logs are cleared, node and edge indices are renumbered
        from zero for every task that still has a position, and ``new`` /
        ``new_edges`` are repopulated with those tasks and their dependency
        edges as reported by ``self.scheduler.tasks``.
        """
        self.new = []
        self.new_edges = []
        self.visible_updates = []
        self.state_updates = []
        self.visible_edge_updates = []

        self.index = {}
        self.next_index = 0
        self.index_edge = {}
        self.next_edge_index = 0

        for key in self.x:
            self.index[key] = self.next_index
            self.next_index += 1
            self.new.append(key)
            for dep in self.scheduler.tasks[key].dependencies:
                edge = (dep.key, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)
