# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Mapping, Sequence
import functools

from jax._src.util import safe_zip, use_cpp_class, cache
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.op_shardings import (
    are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated,
    op_sharding_to_indices)

# Type aliases shared by the sharding helpers below.
Shape = tuple[int, ...]  # a concrete array shape
Index = tuple[slice, ...]  # tuple of slices selecting one shard of an array
Device = xc.Device
# Ordered sequence of devices, as returned by `Sharding._device_assignment`.
XLADeviceAssignment = Sequence[Device]


@cache(max_size=4096, trace_context_in_key=False)
def _addressable_devices_indices_map(
    sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]:
  """Restricts ``sharding.devices_indices_map`` to this process's devices.

  Args:
    sharding: the sharding whose device -> index map is being filtered.
    global_shape: logical shape of the full (global) array.

  Returns:
    The full map unchanged when ``sharding`` is fully addressable. Otherwise,
    a new dict holding only the entries for devices addressable by the
    current process.
  """
  all_indices = sharding.devices_indices_map(global_shape)
  if sharding.is_fully_addressable:
    # Every device is local, so there is nothing to filter out.
    return all_indices
  if not hasattr(sharding, '_internal_device_list'):
    # Generic fallback: keep devices whose owning process is this one.
    return {dev: idx for dev, idx in all_indices.items()
            if dev.process_index == dev.client.process_index()}
  # Shardings carrying a device list already know their addressable subset.
  local_devices = sharding._internal_device_list.addressable_device_list
  return {dev: all_indices[dev] for dev in local_devices}

@cache(max_size=4096, trace_context_in_key=False)
def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]:
  """Builds the device -> array-slice mapping for sharding ``s``.

  Args:
    s: a sharding providing ``shard_shape``, ``_to_xla_hlo_sharding`` and
      ``_device_assignment``.
    global_shape: logical shape of the full (global) array.

  Returns:
    A dict pairing each device of ``s._device_assignment`` (in order) with the
    index computed for it by ``op_sharding_to_indices``.
  """
  # Called only for validation: shard_shape raises a descriptive error when
  # the shape cannot be partitioned by this sharding.
  s.shard_shape(global_shape)
  ndim = len(global_shape)
  hlo = s._to_xla_hlo_sharding(ndim)
  devices = s._device_assignment
  per_device_indices = op_sharding_to_indices(hlo, global_shape, len(devices))
  # safe_zip (not zip) so a device/index count mismatch is reported.
  return dict(safe_zip(devices, per_device_indices))


@cache(max_size=4096, trace_context_in_key=False)
def _common_shard_shape(self, global_shape: Shape) -> Shape:
  """Computes the per-device shard shape implied by a sharding.

  This is a module-level function (rather than a method) so that results can
  be memoized with ``cache``; ``self`` is the sharding.

  Args:
    self: the sharding; must implement ``_to_xla_hlo_sharding``.
    global_shape: logical shape of the full (global) array.

  Returns:
    ``global_shape`` itself if the HLO sharding is fully replicated,
    otherwise a tuple with each dimension divided by its tiling factor.

  Raises:
    NotImplementedError: if a dimension size cannot be divided with
      ``divmod`` (e.g. a dynamic dimension).
    ValueError: if a dimension is not evenly divisible by its tiling factor.
  """
  hlo_sharding = self._to_xla_hlo_sharding(len(global_shape))
  if is_op_sharding_replicated(hlo_sharding):
    return global_shape
  partitions, _ = get_num_ways_dim_sharded(hlo_sharding)
  # Internal invariant: one tiling factor per array dimension.
  assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
  out = []
  for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
    try:
      quotient, remainder = divmod(s, p)
    except TypeError as e:
      # TODO Figure out how to partition dynamic shapes
      raise NotImplementedError(
          f"Sharding {self} cannot partition array axis {dim} of size {s!r} "
          f"into {p} pieces: sharding of dynamic dimension sizes is not "
          f"implemented (full shape: {global_shape})") from e
    if remainder != 0:
      raise ValueError(
          f"Sharding {self} implies that array axis {dim} is partitioned "
          f"{p} times, but the dimension size is {s} "
          f"(full shape: {global_shape}, "
          f"per-dimension tiling factors: {partitions} should evenly divide "
          "the shape)")
    out.append(quotient)
  return tuple(out)


@use_cpp_class(xc.Sharding)
class Sharding:
  """Describes how a :class:`jax.Array` is laid out across devices.

  Subclasses implement the abstract members in the first section below
  (``device_set``, ``is_fully_replicated``, ``is_fully_addressable``,
  ``memory_kind``, ``with_memory_kind``, ``_device_assignment`` and
  ``_to_xla_hlo_sharding``). The remaining members are default
  implementations built on top of them.
  """

  # Abstract methods below that subclasses should implement.
  @property
  def device_set(self) -> set[Device]:
    """The set of devices that this :class:`Sharding` spans.

    In multi-controller JAX, the set of devices is global, i.e., includes
    non-addressable devices from other processes.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def is_fully_replicated(self) -> bool:
    """Is this sharding fully replicated?

    A sharding is fully replicated if each device has a complete copy of the
    entire data.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def is_fully_addressable(self) -> bool:
    """Is this sharding fully addressable?

    A sharding is fully addressable if the current process can address all of
    the devices named in the :class:`Sharding`. ``is_fully_addressable`` is
    equivalent to "is_local" in multi-process JAX.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def memory_kind(self) -> str | None:
    """Returns the memory kind of the sharding."""
    raise NotImplementedError('Subclasses should implement this method.')

  def with_memory_kind(self, kind: str) -> Sharding:
    """Returns a new Sharding instance with the specified memory kind."""
    raise NotImplementedError('Subclasses should implement this method')

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    """Ordered sequence of all devices used by this sharding."""
    raise NotImplementedError('Subclasses should implement this method.')

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    """Lowers this sharding to an XLA HloSharding for a rank-N array."""
    raise NotImplementedError('Subclasses should implement this method.')


  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @functools.cached_property
  def addressable_devices(self) -> set[Device]:
    """The set of devices in the :class:`Sharding` that are addressable by the
       current process.
    """
    # Add a fast path for single controller runtimes: with one process, every
    # device in the sharding is addressable.
    if xb.process_count() == 1:
      return self.device_set
    return {d for d in self.device_set
            if d.process_index == d.client.process_index()}

  def addressable_devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index | None]:
    """A mapping from addressable devices to the slice of array data each contains.

    ``addressable_devices_indices_map`` contains that part of
    ``device_indices_map`` that applies to the addressable devices.
    """
    return _addressable_devices_indices_map(self, global_shape)

  def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
    """Returns a mapping from devices to the array slices each contains.

    The mapping includes all global devices, i.e., including
    non-addressable devices from other processes.
    """
    return common_devices_indices_map(self, global_shape)

  @functools.cached_property
  def _addressable_device_assignment(self) -> XLADeviceAssignment:
    """The subset of ``_device_assignment`` addressable by this process."""
    if self.is_fully_addressable:
      return self._device_assignment
    # Prefer the precomputed addressable list when the subclass provides one.
    if hasattr(self, '_internal_device_list'):
      return tuple(self._internal_device_list.addressable_device_list)
    return tuple(d for d in self._device_assignment
                 if d.process_index == d.client.process_index())

  def shard_shape(self, global_shape: Shape) -> Shape:
    """Returns the shape of the data on each device.

    The shard shape returned by this function is calculated from
    ``global_shape`` and the properties of the sharding.
    """
    return _common_shard_shape(self, global_shape)

  def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool:
    """Returns ``True`` if two shardings are equivalent.

    Two shardings are equivalent if they place the same logical array shards on
    the same devices.

    For example, a :class:`NamedSharding` may be equivalent
    to a :class:`PositionalSharding` if both place the same shards of the array
    on the same devices.

    The comparison checks, in order: equal HLO shardings for an ``ndim``-rank
    array, the same devices, and the same memory kind.
    """
    try:
      if not are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
                                    other._to_xla_hlo_sharding(ndim)):
        return False
      # Not every subclass defines `_internal_device_list` (see the hasattr
      # guards elsewhere in this class); fall back to comparing the ordered
      # device assignments instead of raising AttributeError.
      if (hasattr(self, '_internal_device_list') and
          hasattr(other, '_internal_device_list')):
        same_devices = (self._internal_device_list ==  # type: ignore
                        other._internal_device_list)  # type: ignore
      else:
        same_devices = (tuple(self._device_assignment) ==
                        tuple(other._device_assignment))
      return bool(same_devices) and self.memory_kind == other.memory_kind
    # NotImplementedError is raised by PmapSharding because it can't lower
    # to OpSharding. So if `other` is a PmapSharding, default to a strict
    # equality check.
    except NotImplementedError:
      return self == other
