# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import partial
import inspect
from typing import Optional
import weakref

import jax
from jax._src import core
from jax import tree_util
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax.errors import UnexpectedTracerError
from jax._src import mesh as mesh_lib
from jax._src import dispatch
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.sharding_impls import _op_sharding_to_pos_sharding
from jax._src import custom_api_util
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial


def _resolve_kwargs(fun, args, kwargs):
  ba = inspect.signature(fun).bind(*args, **kwargs)
  ba.apply_defaults()
  if ba.kwargs:
    raise TypeError("keyword arguments could not be resolved to positions")
  else:
    return ba.args


class _ShardingCallbackInfo:

  def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding,
      in_tree, out_tree, infer_sharding_from_operands, module_context, mesh,
      static_args):
    self.propagate_user_sharding = propagate_user_sharding
    self.partition = partition
    self.to_mesh_pspec_sharding = to_mesh_pspec_sharding
    self.in_tree = in_tree
    self.out_tree = out_tree
    self.infer_sharding_from_operands = infer_sharding_from_operands
    self.module_context = module_context
    self.mesh = mesh
    self.static_args = static_args

  def unflatten_arg_shape(self, s, sharding):
    return _to_jax_sharded_shape(
        s, self.to_mesh_pspec_sharding(sharding, len(s.dimensions()))
    )

  def unflatten_arg_shapes(self, arg_shapes, arg_shardings):
    return self.in_tree.unflatten(
        [
            self.unflatten_arg_shape(s, sharding)
            for s, sharding in zip(arg_shapes, arg_shardings)
        ]
    )


# Registry of live _ShardingCallbackInfo objects, keyed by the utf-8 bytes of
# the custom call's backend_config string. Values are held weakly, so an entry
# disappears once its info object is no longer referenced elsewhere (the
# lowering rule keeps each one alive via the module context's keepalives).
_sharding_callbacks = weakref.WeakValueDictionary()  # type: ignore

# call_target_name of the emitted custom call; the partitioner callbacks are
# registered with XLA under this same name.
_CUSTOM_PARTITIONING_CALL_NAME = "CustomSPMDPartitioning"


def _to_jax_shape(s):
  return core.ShapedArray(s.dimensions(), s.numpy_dtype())


def _to_jax_sharded_shape(s, sharding):
  return jax.ShapeDtypeStruct(
      s.dimensions(), s.numpy_dtype(), sharding=sharding
  )


def _pack_result_sharding(shape, result_shardings):
  if shape.is_tuple():
    return xc.HloSharding.tuple_sharding(shape, result_shardings)
  else:
    return result_shardings[0]


def _flatten_sharding(tree, shardings, shapes):
  """Flatten a pytree of jax Shardings into a list of HloShardings.

  ``shardings`` is flattened up to the structure of ``tree``; each leaf is
  converted with the rank of the matching entry in ``shapes``.
  """
  flat_shardings = tree.flatten_up_to(shardings)
  hlo_shardings = []
  for leaf_sharding, leaf_shape in zip(flat_shardings, shapes):
    rank = len(leaf_shape.dimensions())
    hlo_shardings.append(_to_hlo_sharding(leaf_sharding, rank))
  return hlo_shardings


def _custom_partitioning_propagate_user_sharding(user_sharding, shape,
                                                 backend_string):
  """XLA callback: suggest a sharding for the op from a user's sharding.

  If no ``propagate_user_sharding`` rule was registered, the user's sharding
  is returned unchanged. Otherwise the rule receives the static args, the
  mesh, and the result shapes (as a pytree of sharded ``ShapeDtypeStruct``),
  and its answer is converted back to a single (possibly tuple) HloSharding.
  """
  info = _sharding_callbacks[backend_string]
  if info.propagate_user_sharding is None:
    return user_sharding
  is_tuple = shape.is_tuple()
  user_shapes = shape.tuple_shapes() if is_tuple else (shape,)
  user_shardings = (
      user_sharding.tuple_elements() if is_tuple else (user_sharding,)
  )
  leaf_structs = [
      info.unflatten_arg_shape(leaf, leaf_sharding)
      for leaf, leaf_sharding in zip(user_shapes, user_shardings)
  ]
  suggested = info.propagate_user_sharding(
      *info.static_args, info.mesh, info.out_tree.unflatten(leaf_structs)
  )
  flat = _flatten_sharding(info.out_tree, suggested, user_shapes)
  return _pack_result_sharding(shape, flat)


def _to_hlo_sharding(sharding, num_dimensions):
  if not isinstance(sharding, jax.sharding.Sharding):
    raise ValueError("Custom Partitioning rules must return Sharding.")
  return sharding._to_xla_hlo_sharding(num_dimensions)


def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
                                   result_sharding, backend_string):
  """XLA SPMD callback: build the per-shard computation for the custom call.

  Calls the user's ``partition`` rule with the operand and result shapes (as
  pytrees of sharded ``ShapeDtypeStruct``), traces the returned ``lower_fn``
  on the tiled operand avals implied by the chosen shardings, and lowers that
  jaxpr to an MLIR module with all mesh axes marked manual.

  Args:
    arg_shapes: XLA operand shapes, one per flattened dynamic operand.
    arg_shardings: XLA ``HloSharding`` proposed for each operand.
    result_shape: XLA result shape (a tuple shape when there are several
      results).
    result_sharding: XLA ``HloSharding`` proposed for the result.
    backend_string: key into ``_sharding_callbacks`` identifying this op.

  Returns:
    ``(module_bytecode, arg_shardings, result_sharding)``: the MLIR module as
    bytecode, the final list of operand ``HloSharding``s, and the final
    (possibly tuple) result ``HloSharding``.

  Raises:
    ValueError: if ``lower_fn``'s output avals differ from the tiled result
      avals, or if a rule returns something that is not a
      ``jax.sharding.Sharding``.
  """
  info = _sharding_callbacks[backend_string]
  # Split XLA's (possibly tuple) result shape/sharding into per-result leaves.
  if result_shape.is_tuple():
    result_shapes = result_shape.tuple_shapes()
    result_shardings = result_sharding.tuple_elements()
  else:
    result_shapes = (result_shape,)
    result_shardings = (result_sharding,)
  # Ask the user rule for the mesh, the per-shard lowering function and the
  # final shardings. Note: `result_sharding` and `arg_shardings` are rebound
  # here to the user's pytrees of jax Shardings.
  mesh, lower_fn, result_sharding, arg_shardings = info.partition(
      *info.static_args,
      info.mesh,
      info.unflatten_arg_shapes(arg_shapes, arg_shardings),
      info.out_tree.unflatten(
          [
              info.unflatten_arg_shape(s, sharding)
              for s, sharding in zip(result_shapes, result_shardings)
          ]
      ),
  )
  module_context = info.module_context

  # Convert the user's shardings back to flat lists of HloSharding.
  result_shardings = _flatten_sharding(
      info.out_tree, result_sharding, result_shapes)
  arg_shardings = _flatten_sharding(info.in_tree, arg_shardings, arg_shapes)
  # Per-shard (tiled) avals implied by the final shardings.
  tiled_args = [
      _to_jax_shape(sharding.tile(s))
      for sharding, s in zip(arg_shardings, arg_shapes)
  ]
  tiled_results = [
      _to_jax_shape(sharding.tile(s))
      for sharding, s in zip(result_shardings, result_shapes)
  ]
  # Trace lower_fn on the tiled operands. axis_env exposes the mesh axis
  # names/sizes so collectives inside lower_fn can refer to them.
  closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
      *tiled_args
  )
  # The traced per-shard outputs must agree with the chosen result shardings.
  if closed_jaxpr.out_avals != tiled_results:
    raise ValueError(
        "Mismatch in result shapes. %s vs %s"
        % (repr(closed_jaxpr.out_avals), repr(tiled_results))
    )
  axis_context = sharding_impls.SPMDAxisContext(mesh)
  # Lower with every mesh axis marked manual (presumably so XLA treats the
  # body as already-partitioned per-device code).
  with core.extend_axis_env_nd(mesh.shape.items()):
    module = mlir.build_mlir_module_helper(
        closed_jaxpr,
        name="tmp_xla_computation",
        platforms=module_context.platforms,
        backend_or_name=module_context.backend_or_name,
        axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
    )
  result_sharding = _pack_result_sharding(result_shape, result_shardings)
  return mlir.module_to_bytecode(module), arg_shardings, result_sharding


def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,
                                                      result_shape,
                                                      backend_string):
  """XLA callback: compute the result sharding from the operand shardings.

  The user's ``infer_sharding_from_operands`` rule receives the static args,
  the mesh, the operands as a pytree of sharded ``ShapeDtypeStruct``, and the
  results as a pytree of unsharded ``ShapedArray``. Its answer is converted
  back to a single (possibly tuple) HloSharding.
  """
  info = _sharding_callbacks[backend_string]
  leaf_shapes = (
      result_shape.tuple_shapes() if result_shape.is_tuple()
      else (result_shape,)
  )
  operands = info.unflatten_arg_shapes(arg_shapes, arg_shardings)
  results = info.out_tree.unflatten(
      [_to_jax_shape(leaf) for leaf in leaf_shapes])
  inferred = info.infer_sharding_from_operands(
      *info.static_args, info.mesh, operands, results)
  flat = _flatten_sharding(info.out_tree, inferred, leaf_shapes)
  return _pack_result_sharding(result_shape, flat)


# Primitive bound by calling a `custom_partitioning`-decorated function. It
# produces a list of outputs (one per flattened result leaf).
custom_partitioning_p = core.Primitive("custom_partitioning")
custom_partitioning_p.multiple_results = True
# The lowering rule inspects the device assignment of the axis context, so
# flag this primitive as needing devices at lowering time.
dispatch.prim_requires_devices_during_lowering.add(custom_partitioning_p)


def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree,
                                       propagate_user_sharding, partition,
                                       infer_sharding_from_operands,
                                       decode_shardings,
                                       static_args):
  del in_tree, out_tree, propagate_user_sharding, partition
  del infer_sharding_from_operands, decode_shardings, static_args
  return call.out_avals


def _custom_partitioning_impl(*args, call, in_tree, out_tree,
                              propagate_user_sharding,
                              partition, infer_sharding_from_operands,
                              decode_shardings, static_args):
  del in_tree, out_tree, propagate_user_sharding, partition
  del infer_sharding_from_operands, decode_shardings, static_args
  return core.jaxpr_as_fun(call)(*args)


# Abstract evaluation reports the traced body's output avals; eager evaluation
# runs the traced body directly, without consulting the sharding rules.
custom_partitioning_p.def_abstract_eval(_custom_partitioning_abstract_eval)
custom_partitioning_p.def_impl(_custom_partitioning_impl)


def _check_for_tracers(x):
  for leaf in tree_util.tree_leaves(x):
    if isinstance(x, core.Tracer):
      msg = (
          "Found a JAX Tracer object passed as an argument to a"
          "custom_partitioning function in a position indicated as static by"
          "static_argnums. "
      )
      raise UnexpectedTracerError(msg)


@custom_api_util.register_custom_decorator_type
class custom_partitioning:
  """Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.

  .. code-block:: python

    @custom_partitioning
    def f(*args):
      return ...

    def propagate_user_sharding(mesh, user_shape):
      '''Update the sharding of the op from a user's shape.sharding.'''
      user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)
      # user_sharding may optionally be modified before it is returned.
      return user_sharding

    def partition(mesh, arg_shapes, result_shape):
      def lower_fn(*args):
        ... builds computation on per-device shapes ...
      result_sharding = jax.tree.map(lambda x: x.sharding, result_shape)
      arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
      # result_sharding and arg_shardings may optionally be modified and the
      # partitioner will insert collectives to reshape.
      return mesh, lower_fn, result_sharding, arg_shardings

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
      '''Compute the result sharding from the sharding of the operands.'''
      arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
      result_sharding = ...  # derived from arg_shardings
      return result_sharding


    f.def_partition(
        partition=partition,
        infer_sharding_from_operands=infer_sharding_from_operands,
        propagate_user_sharding=propagate_user_sharding)

  The args to ``def_partition`` are as follows:

  * ``propagate_user_sharding``: Callable which takes the sharding of a user (in the dag)
    and returns a suggestion for a new `NamedSharding`. The default
    implementation is just to return the suggested sharding.
  * ``partition``: Callable which takes the SPMD suggested partition shapes and
    partition specs and returns the mesh, a per-shard lowering function, and the final
    input and output sharding specs (the SPMD partitioner will repartition the
    inputs to match). The mesh is returned to allow configuring axis_names for
    collectives when no mesh is provided.
  * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding``
    from the ``NamedSharding`` chosen for each argument.
  * ``decode_shardings``: When set to True (the default), convert input
    ``GSPMDSharding`` instances to ``NamedSharding`` if possible. This may not
    be possible if the user does not provide a contextual mesh.

  Positional arguments can be specified as static using static_argnums (an int
  or a sequence of ints). JAX uses :code:`inspect.signature(fun)` to resolve
  these positional arguments. The values of the static arguments are passed,
  in ``static_argnums`` order, as leading positional arguments to each of the
  callables above (before ``mesh``).

  Example:

    As an example, assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes
    the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched
    along the first N-1 dimensions.
    By default, however, it will ignore the sharding of the input and gather the input on all devices.
    However, since ``jax.numpy.fft.fft`` is batched along the first N-1 dimensions,
    this is unnecessary. We will create a new ``my_fft`` op that, instead, does not alter the sharding
    along the first `N-1` dimensions, and only gathers the input along the last dimension if needed.

    .. code-block:: python

      import jax
      from jax.sharding import NamedSharding
      from jax.experimental.custom_partitioning import custom_partitioning
      from jax.experimental.pjit import pjit
      from jax.sharding import PartitionSpec as P
      from jax.sharding import Mesh
      from jax.numpy.fft import fft
      import regex as re
      import numpy as np

      # Pattern to detect all-gather or dynamic-slice in the generated HLO
      _PATTERN = '(dynamic-slice|all-gather)'

      # For an N-D input, keeps sharding along the first N-1 dimensions
      # but replicate along the last dimension
      def supported_sharding(sharding, shape):
          rank = len(shape.shape)
          max_shared_dims = min(len(sharding.spec), rank-1)
          names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
          return NamedSharding(sharding.mesh, P(*names))

      def partition(mesh, arg_shapes, result_shape):
          result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
          arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
          return (mesh, fft,
                  supported_sharding(arg_shardings[0], arg_shapes[0]),
                  (supported_sharding(arg_shardings[0], arg_shapes[0]),))

      def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
          arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
          return supported_sharding(arg_shardings[0], arg_shapes[0])

      @custom_partitioning
      def my_fft(x):
          return fft(x)

      my_fft.def_partition(
          infer_sharding_from_operands=infer_sharding_from_operands,
          partition=partition)

    Now create a 2D array sharded along the first axis, pass it through ``my_fft``
    and notice how it is still sharded as expected, and identical to the output
    of ``fft``. However, inspecting the HLO
    (using ``lower(x).compile().runtime_executable().hlo_modules()``) reveals that
    ``my_fft`` does not create any all-gather or dynamic-slice, while ``fft`` does.

    .. code-block::

      with Mesh(np.array(jax.devices()), ('x',)):
        x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
        y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
        pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
        pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
        print(pjit_my_fft(y))
        print(pjit_fft(y))
        # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
        assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
        # dynamic-slice or all-gather are present in the HLO for fft
        assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)

    .. code-block::

      # my_fft
      [[-38.840824   +0.j        -40.649452  +11.845365j
      ...
        -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

      # jax.numpy.fft.fft
      [[-38.840824   +0.j        -40.649452  +11.845365j
        ...
        -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

    Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays.
    However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension
    is the dimension along which FFTs are calculated and needs to be replicated on all devices before
    the computation can be done.

    .. code-block::

      with Mesh(np.array(jax.devices()), ('x',)):
        x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
        y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
        pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
        pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
        print(pjit_my_fft(y))
        print(pjit_fft(y))
        # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
        assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
        # dynamic-slice or all-gather are present in the HLO for fft
        assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)

    .. code-block::

      # my_fft
      [    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
      ...  1422.4502  +7271.4297j  -405.84033 -3042.983j
      -3012.4963  -4287.6343j]

      # jax.numpy.fft.fft
      [    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
      ...  1422.4502  +7271.4297j  -405.84033 -3042.983j
      -3012.4963  -4287.6343j]

  """

  def __init__(self, fun, static_argnums=()):
    self.fun = fun
    self.partition = None
    # Accept a bare int as well as a sequence of ints.
    if isinstance(static_argnums, int):
      static_argnums = (static_argnums,)
    self.static_argnums = static_argnums
    self.propagate_user_sharding = None
    self.infer_sharding_from_operands = None
    # Same default as def_partition. Setting it here means calling the
    # function before def_partition does not hit an AttributeError (which
    # would otherwise come from __getattr__ forwarding to `fun`).
    self.decode_shardings = True

  __getattr__ = custom_api_util.forward_attr

  def def_partition(self, partition, infer_sharding_from_operands,
                    propagate_user_sharding=None, decode_shardings=True):
    """Register the sharding rules for this function.

    See the class docstring for the contract of each argument.

    Returns:
      ``partition``, unchanged.
    """
    self.partition = partition
    self.propagate_user_sharding = propagate_user_sharding
    self.infer_sharding_from_operands = infer_sharding_from_operands
    self.decode_shardings = decode_shardings
    return partition

  def __call__(self, *args, **kwargs):
    """Trace ``fun`` and bind it as a ``custom_partitioning_p`` primitive.

    Arguments at ``static_argnums`` are removed from the traced operands and
    forwarded to the sharding rules as ``static_args``. The remaining
    arguments are flattened and traced to a closed jaxpr.

    Raises:
      TypeError: if an argument can only be resolved as a keyword.
      UnexpectedTracerError: if a static argument contains a tracer.
    """
    # Normalize keywords/defaults into a purely positional tuple so that
    # static_argnums can refer to positions.
    args = _resolve_kwargs(self.fun, args, kwargs)
    if self.static_argnums:
      static_argnums = set(self.static_argnums)
      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
      f_, dyn_args = argnums_partial(
          lu.wrap_init(self.fun),
          dyn_argnums,
          args,
          require_static_args_hashable=False,
      )
      # Static values are forwarded to the rules in static_argnums order.
      static_args = [args[i] for i in self.static_argnums]
      _check_for_tracers(static_args)
    else:
      static_args = []
      f_, dyn_args = lu.wrap_init(self.fun), args
    args_flat, in_tree = tree_util.tree_flatten(dyn_args)
    flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    debug = pe.debug_info(self.fun, in_tree, out_tree, False,
                          "custom_partitioning")
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
    assert not len(consts)
    closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
    out_flat = custom_partitioning_p.bind(
        *consts,
        *args_flat,
        call=closed_call,
        partition=self.partition,
        propagate_user_sharding=self.propagate_user_sharding,
        infer_sharding_from_operands=self.infer_sharding_from_operands,
        decode_shardings=self.decode_shardings,
        in_tree=in_tree,
        out_tree=out_tree(),
        static_args=static_args
    )
    return tree_util.tree_unflatten(out_tree(), out_flat)


def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                       call, in_tree, out_tree,
                                       propagate_user_sharding, partition,
                                       infer_sharding_from_operands,
                                       decode_shardings,
                                       static_args):
  """Lower ``custom_partitioning_p`` to a ``CustomSPMDPartitioning`` custom call.

  If all mesh axes are already manual, or if there are no devices or a single
  device, the traced body ``call`` is inlined instead. Otherwise a
  ``_ShardingCallbackInfo`` is stored in ``_sharding_callbacks`` under a
  string key. The emitted custom call carries that key as its
  ``backend_config``, so the registered partitioner callbacks can find the
  info later (presumably XLA passes it back as the ``backend_string``
  argument; the dict key is the utf-8 bytes of the string).

  Returns:
    The MLIR result values of the custom call (or of the inlined body).
  """
  # Contextual mesh from the thread-local resource env; it may be empty
  # (checked via `mesh.empty` below).
  mesh = mesh_lib.thread_resources.env.physical_mesh
  axis_context = ctx.module_context.axis_context
  # Already fully manual (e.g. nested inside another partitioned body):
  # nothing left to partition, so just inline the body.
  if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
      set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
    return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)

  # Work out the devices this computation is being lowered for.
  if isinstance(axis_context, sharding_impls.ShardingContext):
    devices = axis_context.device_assignment
    if devices is None:
      raise AssertionError(
          'Please file a bug at https://github.com/google/jax/issues')
  elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
    devices = axis_context.mesh._flat_devices_tuple
  else:
    devices = None

  # With no devices or a single device there is nothing to partition.
  if not devices or len(devices) == 1:
    return mlir.lower_fun(
        core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)

  # Converts an XLA HloSharding into a jax Sharding of rank `ndim`. Without a
  # contextual mesh (or with decode_shardings=False) a device-position-based
  # sharding is produced; otherwise a NamedSharding on `mesh` whose
  # PartitionSpec is padded with None up to `ndim` entries.
  def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim):
    if hlo_sharding is None:
      return hlo_sharding
    if mesh.empty or not decode_shardings:
      assert devices is not None
      return _op_sharding_to_pos_sharding(hlo_sharding, devices)
    pspec = sharding_impls.parse_flatten_op_sharding(
        hlo_sharding, mesh)[0].get_partition_spec()
    pspec = jax.sharding.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec))))
    return jax.sharding.NamedSharding(mesh, pspec)

  sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding,
      partition, to_mesh_pspec_sharding, in_tree, out_tree,
      infer_sharding_from_operands, ctx.module_context, mesh, static_args)
  # The object's id (as a string) is the lookup key stored in backend_config.
  key = str(id(sharding_callback_info))
  _sharding_callbacks[bytes(key, 'utf8')] = sharding_callback_info
  # We need to make sure `sharding_callback_info` is still alive when the SPMD
  # partitioner runs so we keep it alive by attaching it to the executable.
  ctx.module_context.add_keepalive(sharding_callback_info)

  result_types = [mlir.aval_to_ir_type(s) for s in call.out_avals]
  out = hlo.CustomCallOp(
      result_types,
      list(values),
      call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME),
      has_side_effect=ir.BoolAttr.get(False),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(key),
      operand_layouts=None,
      result_layouts=None)
  return out.results

# Lower the primitive with _custom_partitioning_lowering_rule.
mlir.register_lowering(custom_partitioning_p,
                       _custom_partitioning_lowering_rule)

# Register the Python sharding callbacks with XLA under the custom-call target
# name. The trailing positional `True` is
# `can_side_effecting_have_replicated_sharding` (spelled out as a keyword in the
# plugin registration below).
xc.register_custom_call_partitioner(
    _CUSTOM_PARTITIONING_CALL_NAME,
    _custom_partitioning_propagate_user_sharding,
    _custom_partitioning_partition,
    _custom_partitioning_infer_sharding_from_operands, True)  # type: ignore
# Register the same callbacks through xla_bridge's plugin-callback hook.
# NOTE(review): presumably this re-runs the registration for backend plugins
# that load later; confirm against xla_bridge.register_plugin_callbacks.
xb.register_plugin_callbacks(
    partial(
        xc.register_custom_call_partitioner,
        name=_CUSTOM_PARTITIONING_CALL_NAME,
        prop_user_sharding=_custom_partitioning_propagate_user_sharding,
        partition=_custom_partitioning_partition,
        infer_sharding_from_operands=_custom_partitioning_infer_sharding_from_operands,
        can_side_effecting_have_replicated_sharding=True,
    )
)
