# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
import logging
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

logger = logging.getLogger(__name__)

class ClusterEnv:
  """Interface for defining a cluster environment.

  To enable auto bootstrapping (aka :func:`jax.distributed.initialize()`),
  cluster environments need to derive from :class:`ClusterEnv` and implement
  :func:`is_env_present`, :func:`get_coordinator_address`,
  :func:`get_process_count`, and :func:`get_process_id`.
  :class:`ClusterEnv` subclasses are automatically detected when imported.
  """

  # Registry of subclasses in definition order. It is filled by
  # ``__init_subclass__`` and scanned front to back during auto-detection, so
  # the first registered environment that reports itself present wins.
  _cluster_types: list[type[ClusterEnv]] = []

  def __init_subclass__(cls, **kwargs):
    """Records every newly defined subclass in ``_cluster_types``."""
    super().__init_subclass__(**kwargs)
    cls._cluster_types.append(cls)

  @classmethod
  # pytype: disable=bad-return-type
  def auto_detect_unset_distributed_params(cls,
                                           coordinator_address: str | None,
                                           num_processes: int | None,
                                           process_id: int | None,
                                           local_device_ids: Sequence[int] | None,
                                           initialization_timeout: int | None,
                                          ) -> tuple[str | None, int | None, int | None,
                                                     Sequence[int] | None]:
    """Fills in the distributed parameters the caller left as ``None``.

    Arguments that are already set are always returned unchanged. If all four
    are set, no environment is queried at all. Otherwise the first registered
    :class:`ClusterEnv` subclass whose :func:`is_env_present` returns True is
    asked for the missing values.

    Args:
      coordinator_address: "hostname:port" of the coordinator, or None to
        query the detected environment.
      num_processes: number of processes, or None to query the environment.
      process_id: id of this process, or None to query the environment.
      local_device_ids: device ids visible to this process, or None. When
        None, it is set to ``[local_process_id]`` only if
        ``running_in_cloud_tpu_vm`` is falsy and the environment reports a
        local process id; otherwise it stays None.
      initialization_timeout: forwarded to :func:`get_coordinator_address` as
        ``timeout_secs``.

    Returns:
      ``(coordinator_address, num_processes, process_id, local_device_ids)``.
      Entries may still be None when no known environment is detected.
    """
    if all(p is not None for p in (coordinator_address, num_processes,
                                   process_id, local_device_ids)):
      return (coordinator_address, num_processes, process_id,
              local_device_ids)

    detected = next(
        (candidate for candidate in cls._cluster_types
         if candidate.is_env_present()),
        None)
    if detected is None:
      logger.debug('Could not find a known environment for initializing distributed JAX. '
        'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
      return (coordinator_address, num_processes, process_id, local_device_ids)

    logger.debug('Initializing distributed JAX environment via %s', detected.__name__)
    if coordinator_address is None:
      coordinator_address = detected.get_coordinator_address(timeout_secs=initialization_timeout)
    if num_processes is None:
      num_processes = detected.get_process_count()
    if process_id is None:
      process_id = detected.get_process_id()
    # Never automatically set local_device_ids on TPUs.
    # Defaults to a single process per device if a local process id is
    # available. This only runs in a managed distributed environment;
    # otherwise local_device_ids remains unset, which will default to all
    # devices being visible.
    # NOTE(review): running_in_cloud_tpu_vm is bound by name at import time;
    # presumably a bool flag that is final before this module is imported --
    # confirm.
    if local_device_ids is None and not running_in_cloud_tpu_vm:
      # Query once so the value that is checked is the value that is used.
      local_process_id = detected.get_local_process_id()
      if local_process_id is not None:
        local_device_ids = [local_process_id]
    return (coordinator_address, num_processes, process_id, local_device_ids)
  # pytype: enable=bad-return-type

  @classmethod
  def is_env_present(cls) -> bool:
    """Returns True if process is running in this cluster environment.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None) -> str:
    """Returns address and port used by JAX to bootstrap.

    Process id 0 will open a tcp socket at "hostname:port" where
    all the processes will connect to initialize the distributed JAX service.
    The selected port needs to be free.
    :func:`get_coordinator_address` needs to return the same hostname and port on all the processes.

    Args:
      timeout_secs: the ``initialization_timeout`` value forwarded by
        :func:`auto_detect_unset_distributed_params`; may be None. How it is
        used is up to the subclass.

    Returns:
      "hostname:port"

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_coordinator_address")

  @classmethod
  def get_process_count(cls) -> int:
    """Returns the value used for ``num_processes`` when the caller gave none.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_count")

  @classmethod
  def get_process_id(cls) -> int:
    """Returns the value used for ``process_id`` when the caller gave none.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_id")

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Get index of current process inside a host.

    The method is only useful to support single device per process.
    In that case, each process will see a local device whose ID is
    the same as its local process ID.
    If None, JAX will not restrict the visible devices.

    Returns:
      None in this base class; subclasses may override.
    """
    return None
