# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lowering and execution path that converts jaxprs into MLIR.
from __future__ import annotations

import collections
from collections.abc import Iterator, Sequence
import dataclasses
import functools
from functools import partial
import io
import itertools
import operator
import os
import re
import types
import typing
from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast
import warnings

import numpy as np

from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import linear_util as lu
from jax._src import path
from jax._src import pickle_util
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import register_jax_dialects
from jax._src.state.types import AbstractRef

# Within this module `map` and `zip` refer to util.safe_map / util.safe_zip;
# the Python builtins stay reachable as `unsafe_map` / `unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

# Generic type variable for helpers in this module.
T = typing.TypeVar("T")

# Alias used in annotations where an MLIR value is expected.
Value = Any  # = ir.Value

# mypy implicitly sets this variable to true when type checking.
MYPY = False

# Flag: directory where emitted IR is dumped. Its default is read from the
# JAX_DUMP_IR_TO environment variable; an empty value disables dumping.
_JAX_DUMP_IR_TO = config.DEFINE_string(
    'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
    help="Path to which the IR that is emitted by JAX should be dumped as "
         "text files. If omitted, JAX will not dump IR. "
         "Supports the special value 'sponge' to pick the path from the "
         "environment variable TEST_UNDECLARED_OUTPUTS_DIR.")

# Flag: whether dumps keep debug/location information. It is a string flag;
# module_to_string treats the values 'false' and '0' (case-insensitively) as
# "disabled" and anything else as "enabled".
_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string(
  'jax_include_debug_info_in_dumps',
  os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"),
  help="Determine whether or not to keep debug symbols and location information "
       "when dumping IR code. By default, debug information will be preserved in "
       "the IR dump. To avoid exposing source code and potentially sensitive "
       "information, set to false")
# Re-export of effects_lib.lowerable_effects under this module's namespace.
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects


# IR Helpers

def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
  """Builds a DenseIntElementsAttr from `xs` after converting it to int64."""
  as_int64 = np.asarray(xs, np.int64)
  attr = ir.DenseIntElementsAttr.get(as_int64)
  return type_cast(ir.DenseIntElementsAttr, attr)

def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
  """Builds a DenseI64ArrayAttr from `xs` after converting it to int64."""
  as_int64 = np.asarray(xs, np.int64)
  return ir.DenseI64ArrayAttr.get(as_int64)  # type: ignore

def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
  """Builds a 1-D dense elements attribute of i1 values from `xs`.

  The booleans are bit-packed (little bit order) before being handed to MLIR.
  """
  packed = np.packbits(np.array(xs, np.bool_), bitorder='little')
  if len(xs) == 1:
    # TODO(b/209005197): Work around for MLIR crash for non-splat single element
    # buffers.
    splat_byte = 0xff if packed.item() != 0 else 0
    packed = np.array(splat_byte, np.uint8)
  i1_type = ir.IntegerType.get_signless(1)
  return ir.DenseElementsAttr.get(packed, type=i1_type, shape=[len(xs)])

def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
  """Builds a DenseBoolArrayAttr directly from the booleans in `xs`."""
  attr = ir.DenseBoolArrayAttr.get(xs)  # type: ignore
  return attr

def i32_attr(i):
  """Returns an IntegerAttr holding `i` with a signless 32-bit integer type."""
  i32 = ir.IntegerType.get_signless(32)
  return ir.IntegerAttr.get(i32, i)
def i64_attr(i):
  """Returns an IntegerAttr holding `i` with a signless 64-bit integer type."""
  i64 = ir.IntegerType.get_signless(64)
  return ir.IntegerAttr.get(i64, i)

def shape_tensor(sizes: Sequence[int | ir.RankedTensorType]
                 ) -> ir.RankedTensorType:
  """Assembles a 1-D int32 tensor value from a mix of static and IR dims.

  Python ints become one-element int32 constants. IR values are converted to
  int32 when their type differs and reshaped to one element. The pieces are
  then concatenated along dimension 0 (an empty `sizes` yields an empty int32
  constant, and a single piece is returned as is).
  """
  int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32))
  i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))

  def lower_dim(d):
    if type(d) is int:
      return ir_constant(np.array([d], np.int32))
    as_i32 = hlo.convert(i32_type, d) if d.type != i32_type else d
    return hlo.reshape(int1d, as_i32)

  pieces = map(lower_dim, sizes)
  if not pieces:
    empty = ir_constant(np.array([], np.int32))
    return type_cast(ir.RankedTensorType, empty)
  if len(pieces) == 1:
    return pieces[0]
  return hlo.concatenate(pieces, i64_attr(0))


def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
  """Runs `lowering_fun` under a modified copy of `ctx`.

  The copy is created with `ctx.replace(**ctx_override_kwargs)`. After the
  call, the copy's `tokens_out` is written back to `ctx` via
  `ctx.set_tokens_out`, so this function has side effects on `ctx`.

  Returns:
    Whatever `lowering_fun` returned.
  """
  delegated_ctx = ctx.replace(**ctx_override_kwargs)
  result = lowering_fun(delegated_ctx, *args)
  ctx.set_tokens_out(delegated_ctx.tokens_out)
  return result


# IR Types

# Non-canonicalized dtype to IR type mapping.
# Each value is a zero-argument factory rather than an IR type instance; the
# factory is invoked at lookup time by dtype_to_ir_type.
# Note that float0 and bool both map to a signless 1-bit integer, signed numpy
# integers map to signless IR integers, and unsigned ones to unsigned IR
# integers.
_dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
  np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
  np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
  np.dtype(dtypes.int4): partial(ir.IntegerType.get_signless, 4),
  np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
  np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
  np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
  np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
  np.dtype(dtypes.uint4): partial(ir.IntegerType.get_unsigned, 4),
  np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
  np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
  np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
  np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
  np.dtype(dtypes.float8_e4m3b11fnuz): ir.Float8E4M3B11FNUZType.get,
  np.dtype(dtypes.float8_e4m3fn): ir.Float8E4M3FNType.get,
  np.dtype(dtypes.float8_e4m3fnuz): ir.Float8E4M3FNUZType.get,
  np.dtype(dtypes.float8_e5m2): ir.Float8E5M2Type.get,
  np.dtype(dtypes.float8_e5m2fnuz): ir.Float8E5M2FNUZType.get,
  np.dtype(dtypes.bfloat16): ir.BF16Type.get,
  np.dtype(np.float16): ir.F16Type.get,
  np.dtype(np.float32): ir.F32Type.get,
  np.dtype(np.float64): ir.F64Type.get,
  np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
  np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}

def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
  """Returns the MLIR element type registered for `dtype`.

  `core.bint` dtypes are represented as int32.

  Raises:
    TypeError: if no entry exists for the dtype in `_dtype_to_ir_type`.
  """
  if isinstance(dtype, core.bint):
    # TODO Support different-size underlying dtypes to take advantage of the
    # bound for packing?
    dtype = np.dtype(np.int32)
  assert isinstance(dtype, (np.dtype, np.generic)), type(dtype)
  dtype = np.dtype(dtype)
  try:
    make_ir_type = _dtype_to_ir_type[dtype]
  except KeyError as err:
    message = f"No dtype_to_ir_type handler for dtype: {dtype}"
    raise TypeError(message) from err
  else:
    return make_ir_type()

def _array_ir_types(aval: core.ShapedArray | core.DShapedArray
                    ) -> Sequence[ir.Type]:
  """Returns the single ranked tensor type for an array aval.

  The aval is first mapped to its physical aval; shapes that are not constant
  are delegated to `_dynamic_array_ir_types`.
  """
  physical = core.physical_aval(aval)  # type: ignore
  if core.is_constant_shape(physical.shape):
    tensor_type = ir.RankedTensorType.get(
        physical.shape, dtype_to_ir_type(physical.dtype))  # type: ignore
    return (tensor_type,)
  return _dynamic_array_ir_types(physical)  # type: ignore

def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
  """Returns a ranked tensor type in which non-int dimensions are dynamic."""
  dynamic = ir.ShapedType.get_dynamic_size()
  ir_shape = []
  for dim in aval.shape:
    # Only exact Python ints count as static sizes.
    ir_shape.append(dim if type(dim) is int else dynamic)
  element_type = dtype_to_ir_type(aval.dtype)
  return (ir.RankedTensorType.get(ir_shape, element_type),)

# Registry mapping an abstract value class to the function that produces the
# IR types used to represent values of that class.
ir_type_handlers: dict[type[core.AbstractValue],
                        Callable[[Any], Sequence[ir.Type]]] = {}

def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
  """Converts a JAX aval to zero or more MLIR IR types.

  In general, a JAX value may be represented by multiple IR values, so this
  function returns multiple types.

  The handler is looked up by the exact type of `aval` in `ir_type_handlers`.

  Raises:
    TypeError: if no handler is registered for `type(aval)`.
  """
  # Guard only the registry lookup: a KeyError raised *inside* a handler must
  # propagate unchanged instead of being misreported as a missing handler.
  try:
    handler = ir_type_handlers[type(aval)]
  except KeyError as err:
    raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
  return handler(aval)

# Built-in handlers. Array avals map to a single ranked tensor type, and
# tokens map to a single HLO token type.
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [hlo.TokenType.get()]
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types

def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
  """Returns the one IR type that represents `aval`.

  This is a convenience wrapper around aval_to_ir_types for avals, such as
  dense arrays, that are known to map to exactly one IR value.

  Raises:
    TypeError: if `aval` does not map to exactly one IR type.
  """
  ir_types = aval_to_ir_types(aval)
  if len(ir_types) == 1:
    return ir_types[0]
  raise TypeError(f"aval_to_ir_type called on {aval} which corresponds to "
                  f"multiple IR types {ir_types}")


# Constants

class ConstantHandler(Protocol):
  # Call signature of the functions stored in the constant-handler registry
  # (see register_constant_handler): one value in, a sequence of IR values out.
  def __call__(self, val: Any) -> Sequence[ir.Value]:
    """Builds an IR representation for a constant `val`.

    A JAX value is represented by zero or more IR values."""

# Registry of constant handlers, keyed by the Python type they handle.
_constant_handlers : dict[type, ConstantHandler] = {}


def register_constant_handler(type_: type, handler_fun: ConstantHandler):
  """Registers `handler_fun` for `type_`, replacing any earlier handler."""
  _constant_handlers[type_] = handler_fun


def get_constant_handler(type_: type) -> ConstantHandler:
  """Returns the handler registered for exactly `type_`.

  Raises:
    KeyError: if nothing is registered for `type_`.
  """
  handler = _constant_handlers[type_]
  return handler

def ir_constants(val: Any) -> Sequence[ir.Value]:
  """Translates a Python `val` to IR constants.

  The handler is chosen by walking the method resolution order of `type(val)`
  and taking the first class with a registered constant handler. When no class
  matches and `val` defines `__jax_array__`, the result of calling it is
  translated instead.

  Args:
    val: a Python value to be translated to a constant.

  Returns:
    A representation of the constant as a sequence of IR values.

  Raises:
    TypeError: if no handler applies to `val`.
  """
  for klass in type(val).__mro__:
    handler = _constant_handlers.get(klass)
    if not handler:
      continue
    values = handler(val)
    assert all(isinstance(v, ir.Value) for v in values), (type(val), values)
    return values
  if hasattr(val, '__jax_array__'):
    return ir_constants(val.__jax_array__())
  raise TypeError(f"No constant handler for type: {type(val)}")

def ir_constant(val: Any) -> ir.Value:
  """Returns the single IR value for `val`.

  Convenience wrapper around ir_constants for values that translate to
  exactly one IR value.

  Raises:
    TypeError: if `val` does not translate to exactly one IR value.
  """
  values = ir_constants(val)
  if len(values) == 1:
    return values[0]
  raise TypeError(f"ir_constant called on {val} which corresponds to "
                  f"multiple IR values {values}")


def _numpy_array_constant(x: np.ndarray | np.generic) -> Sequence[ir.Value]:
  """Emits an hlo.constant holding the contents of `x`.

  Boolean data is bit-packed (little bit order) before the dense attribute is
  built; the attribute keeps the original shape of `x`.

  Returns:
    A one-element tuple with the constant's IR value.
  """
  element_type = dtype_to_ir_type(x.dtype)
  logical_shape = x.shape
  data = np.packbits(x, bitorder='little') if x.dtype == np.bool_ else x  # type: ignore
  data = np.ascontiguousarray(data)
  attr = ir.DenseElementsAttr.get(
      data, type=element_type, shape=logical_shape)  # type: ignore
  return (hlo.constant(attr),)


def _masked_array_constant_handler(*args, **kwargs):
  raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                   "Use arr.filled() to convert the value to a standard numpy array.")

# Masked arrays get a handler that always raises (see the handler's message).
register_constant_handler(np.ma.MaskedArray, _masked_array_constant_handler)

def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value]:
  """Constant handler for ndarray literals, with special care for zero strides.

  Non-empty arrays that have at least one zero stride (np.broadcast_to
  produces such arrays) are not staged in full. The zero-stride axes are
  collapsed, the smaller array is emitted as a constant, and that constant is
  expanded back to the full shape with a broadcast_in_dim op. float0 arrays
  are emitted as all-False boolean constants. Every other array goes through
  _numpy_array_constant unchanged.

  Args:
    val: an ndarray or numpy scalar.

  Returns:
    A one-element tuple holding the IR value of the constant.
  """
  if val.dtype == dtypes.float0:
    return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_))

  has_zero_stride = np.any(np.equal(0, val.strides))
  if not (has_zero_stride and val.size > 0):
    return _numpy_array_constant(val)

  zero_stride_axes, = np.where(np.equal(0, val.strides))
  kept_axes, = np.where(np.not_equal(0, val.strides))
  # Take index 0 along every zero-stride axis and keep the other axes whole.
  collapse_index = tuple(
      0 if axis in zero_stride_axes else slice(None)  # type: ignore
      for axis in range(val.ndim))
  collapsed = val[collapse_index]
  result_type = ir.RankedTensorType.get(
      val.shape, dtype_to_ir_type(collapsed.dtype))  # type: ignore
  operand = _numpy_array_constant(collapsed)[0]
  broadcasted = hlo.broadcast_in_dim(
      result_type, operand, dense_int_array(kept_axes))
  return (broadcasted,)

register_constant_handler(np.ndarray, _ndarray_constant_handler)

# numpy scalar types (and bfloat16) share the ndarray handler.
for _scalar_type in [np.int8, np.int16, np.int32, np.int64,
                     np.uint8, np.uint16, np.uint32, np.uint64,
                     np.float16, np.float32, np.float64,
                     np.complex64, np.complex128,
                     np.bool_, np.longlong, dtypes.bfloat16]:
  register_constant_handler(_scalar_type, _ndarray_constant_handler)  # type: ignore

def _python_scalar_handler(dtype, val):
  """Lowers the Python scalar `val` as a numpy array constant of `dtype`."""
  as_array = np.array(val, dtype)
  return _numpy_array_constant(as_array)

# One handler per Python scalar type, each bound to the dtype listed for it in
# dtypes.python_scalar_dtypes.
for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

def _token_constant_handler(val):
  """Creates a fresh HLO token; `val` itself is not inspected."""
  token = hlo.create_token()
  return [token]
register_constant_handler(core.Token, _token_constant_handler)

# Source locations

def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
  """Returns the canonical form of `file_name`, memoized in `caches`.

  When `config.hlo_source_file_canonicalization_regex` is set, every match of
  that pattern is removed from `file_name`; otherwise the name is returned
  unchanged.

  Args:
    file_name: the source file name as recorded in a frame.
    caches: holder of `canonical_name_cache`, a dict mapping raw file names to
      their canonical form. It is updated on a cache miss.

  Returns:
    The canonicalized file name.
  """
  cached = caches.canonical_name_cache.get(file_name, None)
  if cached is not None:
    return cached

  pattern = config.hlo_source_file_canonicalization_regex.value
  canonical_file_name = re.sub(pattern, '', file_name) if pattern else file_name
  # The cache must be keyed by the *raw* name, because that is what the lookup
  # above uses. Keying by the canonical name would make every lookup miss
  # whenever canonicalization changes the name.
  caches.canonical_name_cache[file_name] = canonical_file_name
  return canonical_file_name

def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
  is_user = ctx.traceback_caches.is_user_file_cache.get(file_name, None)
  if is_user is not None:
    return is_user
  out = source_info_util.is_user_filename(file_name)
  ctx.traceback_caches.is_user_file_cache[file_name] = out
  return out

def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
  """Converts a full traceback to a callsite() MLIR location.

  Only frames from user files are kept, up to the limit given by
  `config.traceback_in_locations_limit` (a negative limit means 1000). Both
  per-frame locations and the final result are memoized in
  `ctx.traceback_caches`.

  Returns:
    An unknown location when no user frame is found, the frame's location when
    exactly one is found, and otherwise a callsite location whose callee is
    the first kept frame and whose callers are the remaining ones.
  """
  caches = ctx.traceback_caches
  cached_loc = caches.traceback_cache.get(tb, None)
  if cached_loc is not None:
    return cached_loc

  limit = config.traceback_in_locations_limit.value
  if limit < 0:
    limit = 1000

  user_frame_locs = []
  codes, lastis = tb.raw_frames()
  for idx, code in enumerate(codes):
    if not _is_user_file(ctx, code.co_filename):
      continue

    lasti = lastis[idx]
    cache_key = (code, lasti)
    frame_loc = caches.location_cache.get(cache_key, None)
    if frame_loc is None:
      frame = source_info_util.raw_frame_to_frame(code, lasti)
      file_loc = ir.Location.file(
          get_canonical_source_file(frame.file_name, caches),
          frame.start_line,
          frame.start_column,
      )
      frame_loc = ir.Location.name(frame.function_name, childLoc=file_loc)
      caches.location_cache[cache_key] = frame_loc
    user_frame_locs.append(frame_loc)
    if len(user_frame_locs) >= limit:
      break

  if not user_frame_locs:
    result = ir.Location.unknown()
  elif len(user_frame_locs) == 1:
    result = user_frame_locs[0]
  else:
    callee, *callers = user_frame_locs
    result = ir.Location.callsite(callee, callers)
  caches.traceback_cache[tb] = result
  return result

def _source_info_to_location(
    ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
    source_info: source_info_util.SourceInfo) -> ir.Location:
  """Builds the MLIR location attached to the lowering of one equation.

  The result is a name location whose name is "<name stack>/<compact equation
  string>". Its child is the full traceback location when
  `config.include_full_tracebacks_in_locations` is set, and otherwise the
  file location of the user frame. The child is an unknown location when no
  traceback or user frame is available.
  """
  eqn_str = (f'{source_info.name_stack}/'
             f'{core.str_eqn_compact(primitive.name, params)}')
  if config.include_full_tracebacks_in_locations.value:
    traceback = source_info.traceback
    child_loc = (ir.Location.unknown() if traceback is None
                 else _traceback_to_location(ctx, traceback))
  else:
    frame = source_info_util.user_frame(source_info)
    if frame is None:
      child_loc = ir.Location.unknown()
    else:
      canonical_file = get_canonical_source_file(frame.file_name,
                                                 ctx.traceback_caches)
      child_loc = ir.Location.file(canonical_file, frame.start_line,
                                   frame.start_column)
  # TODO(phawkins): also include primitive.name as the operator type.
  return ir.Location.name(eqn_str, childLoc=child_loc)

# Dialect registry appended to every context created by make_ir_context. It is
# populated only when `register_jax_dialects` is truthy.
upstream_dialects = ir.DialectRegistry()
if register_jax_dialects:
  register_jax_dialects.register_dialects(upstream_dialects)

# Dumping MLIR modules
# Module-level counter that gives each dumped IR file a distinct index.
_ir_dump_counter = itertools.count()

def dump_module_to_file(module: ir.Module, stage_name: str) -> str | None:
  """Writes the textual IR of `module` to a file, if IR dumping is enabled.

  The target directory comes from the `jax_dump_ir_to` flag. The special value
  "sponge" selects the directory named by the environment variable
  TEST_UNDECLARED_OUTPUTS_DIR. The directory is created if needed.

  Args:
    module: The module to dump.
    stage_name: A name to distinguish different stages of a module; it is
      appended to the module's symbol name in the file name.

  Returns:
    The name of the dump file (without its directory) if the module was
    dumped, `None` if dumping is disabled.

  Raises:
    ValueError: if "sponge" is requested but TEST_UNDECLARED_OUTPUTS_DIR is
      unset or empty.
  """
  dump_dir = _JAX_DUMP_IR_TO.value
  if not dump_dir:
    return None
  if dump_dir == "sponge":
    dump_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "")
    if not dump_dir:
      raise ValueError("JAX_DUMP_IR_TO='sponge' but "
                       "TEST_UNDECLARED_OUTPUTS_DIR is not defined")

  dump_index = next(_ir_dump_counter)
  module_name = ir.StringAttr(module.operation.attributes['sym_name']).value
  safe_module_name = _make_string_safe_for_filename(module_name)
  name = f"jax_ir{dump_index}_{safe_module_name}_{stage_name}.mlir"

  out_dir = path.Path(dump_dir)
  out_dir.mkdir(parents=True, exist_ok=True)
  (out_dir / name).write_text(module_to_string(module))
  return name

def dump_module_message(module: ir.Module, stage_name: str) -> str:
  """Dumps `module` if dumping is enabled and describes the outcome.

  Returns:
    A sentence naming the dump file, or a hint on how to enable dumping when
    nothing was written.
  """
  dumped_to = dump_module_to_file(module, stage_name)
  if not dumped_to:
    return "Define JAX_DUMP_IR_TO to dump the module."
  return f"The module was dumped to {dumped_to}."

def _make_string_safe_for_filename(s: str) -> str:
  return re.sub(r'[^\w.)( -]', '', s)

def module_to_string(module: ir.Module, enable_debug_info=None) -> str:
  """Returns the textual form of `module`.

  Args:
    module: the module to print.
    enable_debug_info: whether to include debug information. When None, the
      `jax_include_debug_info_in_dumps` flag decides: the values 'false' and
      '0' (case-insensitive) disable it and anything else enables it.
  """
  if enable_debug_info is None:
    flag_value = str.lower(_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS.value)
    enable_debug_info = flag_value not in ('false', '0')
  buffer = io.StringIO()
  module.operation.print(file=buffer, enable_debug_info=enable_debug_info)
  return buffer.getvalue()

def module_to_bytecode(module: ir.Module) -> bytes:
  """Returns `module` serialized with the operation's `write_bytecode`."""
  buffer = io.BytesIO()
  module.operation.write_bytecode(file=buffer)
  serialized = buffer.getvalue()
  return serialized

# Translation rules
def make_ir_context() -> ir.Context:
  """Creates an MLIR context suitable for JAX IR.

  Returns:
    A new context with the `upstream_dialects` registry appended, all
    available dialects loaded, multithreading disabled, and the mhlo, chlo and
    hlo dialects registered.
  """
  context = ir.Context()
  context.append_dialect_registry(upstream_dialects)
  context.load_all_available_dialects()

  # If threading is enabled, each MLIR context will keep alive a thread pool.
  # Since we cache MLIR modules (and hence contexts), this means we might keep
  # several threads alive for each cache entry. This is a terrible idea. However
  # we don't do any heavy computation on MLIR modules from Python anyway, so we
  # just disable threading.
  context.enable_multithreading(False)

  # Register the HLO-family dialects on this context.
  dialects.mhlo.register_mhlo_dialect(context)
  dialects.chlo.register_dialect(context)
  dialects.hlo.register_dialect(context)
  return context


# The kinds of axis context a ModuleContext can carry (see
# ModuleContext.axis_context).
AxisContext = Union[
    sharding_impls.SPMDAxisContext,
    sharding_impls.ReplicaAxisContext,
    sharding_impls.ShardingContext,
]

class ShapePolyLoweringState:
  """State kept while lowering with shape polymorphism."""
  # The names of the dimension variables, sorted by name. This is the order in
  # which they are passed to the IR functions that need them. This is only
  # used for native serialization with polymorphic shapes when
  # --jax_dynamic_shapes is off.
  # TODO: for multi-platform lowering we prepend to the regular dimension
  # variables a fake dimension variable "_platform_index". This is a
  # temporary abuse, taking advantage that for platform index we need the
  # same lowering strategy as for dimension variables: add it as argument to
  # inner functions, and pass the values along at the call sites.
  dim_vars: tuple[str, ...]
  # Whether the module uses dimension variables, either in its inputs or
  # from an inner call to Exported modules that uses dimension variables.
  # This includes the case when the called Exported module uses a platform
  # index argument.
  uses_dim_vars: bool

  # If the first dimension variable is a platform index argument
  has_platform_index_argument: bool

  def __init__(self,
               dim_vars: tuple[str, ...],
               lowering_platforms: tuple[str, ...] | None):
    # A platform index is needed only when lowering for several platforms.
    multi_platform = (lowering_platforms is not None
                      and len(lowering_platforms) > 1)
    if multi_platform:
      dim_vars = ("_platform_index",) + tuple(dim_vars)
    self.has_platform_index_argument = multi_platform
    self.uses_dim_vars = (len(dim_vars) > 0)
    self.dim_vars = dim_vars

@dataclasses.dataclass(frozen=True)
class LoweringParameters:
  """Immutable options that adjust how a computation is lowered."""
  # A mapping between primitives and user-defined LoweringRules.
  # When lowering a primitive, give priority to the rule in this map over
  # existing Jax rules.
  override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule], ...] | None = None

  # Signals that the entire computation being lowered operates on global
  # constants. This will result in adding jax.global_constant attributes
  # to the arguments of all functions that are created, e.g., floor_divide.
  # This is used only in export and jax2tf in presence of shape polymorphism
  # or multi-platform lowering.
  global_constant_computation: bool = False

  # Signals that we are lowering for exporting.
  for_export: bool = False


@dataclasses.dataclass
class TracebackCaches:
  """Memoization tables used while turning tracebacks into MLIR locations.

  Every table starts out empty; instances are created without arguments.
  """
  # Whole traceback -> final location.
  traceback_cache: dict[xc.Traceback, ir.Location]
  # (code object, last instruction index) -> per-frame location.
  location_cache: dict[tuple[types.CodeType, int], ir.Location]
  # File name -> canonicalized file name.
  canonical_name_cache: dict[str, str]
  # File name -> whether it counts as a user file.
  is_user_file_cache: dict[str, bool]

  def __init__(self):
    self.traceback_cache = dict()
    self.location_cache = dict()
    self.canonical_name_cache = dict()
    self.is_user_file_cache = dict()

@dataclasses.dataclass
class ModuleContext:
  """Module-wide context information for MLIR lowering.

  The explicit keyword-only `__init__` below fills in defaults for the MLIR
  objects (context, module, insertion point, symbol table) and for the caches
  when they are not supplied. `replace` relies on dataclasses.replace, which
  passes every field back to `__init__` by keyword.
  """
  context: ir.Context
  module: ir.Module
  ip: ir.InsertionPoint
  symbol_table: ir.SymbolTable
  backend_or_name: str | xb.XlaBackend | None
  # The lowering platforms for the module. Can be more than one only when
  # exporting.
  platforms: Sequence[str]
  axis_context: AxisContext
  keepalives: list[Any]
  channel_iterator: Iterator[int]
  host_callbacks: list[Any]
  # Keep state for the lowering of shape polymorphism
  shape_poly_state: ShapePolyLoweringState

  # Cached primitive lowerings.
  cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]

  # Cached traceback information.
  traceback_caches: TracebackCaches

  lowering_parameters: LoweringParameters

  @property
  def axis_env(self) -> sharding_impls.AxisEnv:
    return self.axis_context.axis_env

  def __init__(
      self,
      *,
      backend_or_name: str | xb.XlaBackend | None,
      platforms: Sequence[str],
      axis_context: AxisContext,
      keepalives: list[Any],
      channel_iterator: Iterator[int],
      host_callbacks: list[Any],
      lowering_parameters: LoweringParameters,
      context: ir.Context | None = None,
      module: ir.Module | None = None,
      ip: ir.InsertionPoint | None = None,
      symbol_table: ir.SymbolTable | None = None,
      cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None,
      traceback_caches: None | TracebackCaches = None,
      shape_poly_state = None):

    # MLIR objects default to freshly created ones; each later default is
    # derived from the (possibly defaulted) earlier ones.
    self.context = context or make_ir_context()
    self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
    self.ip = ip or ir.InsertionPoint(self.module.body)
    self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
    self.backend_or_name = backend_or_name
    self.platforms = platforms
    self.axis_context = axis_context
    # The caches are compared against None (not truthiness), so an empty cache
    # passed in by the caller is kept and shared rather than replaced.
    self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
                                       else cached_primitive_lowerings)
    self.traceback_caches = (TracebackCaches() if traceback_caches is None
                             else traceback_caches)
    self.channel_iterator = channel_iterator
    self.keepalives = keepalives
    self.host_callbacks = host_callbacks
    self.shape_poly_state = (
      shape_poly_state or ShapePolyLoweringState((), tuple(platforms)))
    self.lowering_parameters = lowering_parameters

  @property
  def backend(self) -> xb.XlaBackend:
    """The backend for this module.

    `backend_or_name` is resolved through xb.get_backend when it is None or a
    string, and returned directly otherwise.

    Raises:
      NotImplementedError: if the module is lowered for more than one platform.
    """
    # TODO(necula): clean the use of backend and backend_or_name vs. platforms
    if len(self.platforms) > 1:
      raise NotImplementedError(
        "accessing .backend in multi-lowering setting. This can occur when "
        "lowering a primitive that has not been adapted to multi-platform "
        "lowering")
    if self.backend_or_name is None or isinstance(self.backend_or_name, str):
      return xb.get_backend(self.backend_or_name)
    return self.backend_or_name

  def new_channel(self) -> int:
    """Returns the next value from `channel_iterator`."""
    return next(self.channel_iterator)

  # Adds an IFRT host callback object to the context. A reference to these
  # callbacks will be provided to IFRT during compilation so it can do things
  # like serialize them and keep them alive.
  def add_host_callback(self, host_callback: Any) -> None:
    self.host_callbacks.append(host_callback)

  # Keeps a value alive as long as the Python executable is alive.
  # TODO(phawkins): this feature is problematic, because you almost certainly
  # want to keep alive values as long as the underlying runtime executable is
  # still alive/executing. The Python executable object may have a shorter
  # lifetime, so it's highly likely any caller of this method is buggy.
  def add_keepalive(self, keepalive: Any) -> None:
    self.keepalives.append(keepalive)

  # Returns a copy of this context with the given fields overridden.
  def replace(self, **kw): return dataclasses.replace(self, **kw)


@dataclasses.dataclass
class LoweringRuleContext:
  """Per-rule context information for MLIR lowering."""
  module_context: ModuleContext
  name_stack: source_info_util.NameStack
  primitive: core.Primitive | None
  avals_in: Sequence[core.AbstractValue]
  avals_out: Any  # Usually Sequence[core.AbstractValue], but sometimes None.
  tokens_in: TokenSet
  tokens_out: TokenSet | None  # Mutable store for output containers
  axis_size_env: dict[core.Var, ir.Value] | None = None  # Dynamic axis sizes
  # The values for the dimension variables in same order as
  # module_context.shape_poly_state.dim_vars
  dim_var_values: Sequence[ir.Value] = ()
  compute_type: str | None = None
  # Override module_context.platforms if not None. Used during multi-platform
  # lowering, when in a scope with a subset of the module_context.platforms.
  platforms: Sequence[str] | None = None

  def set_tokens_out(self, tokens_out: TokenSet):
    """Records the output tokens; only valid while `tokens_out` is None."""
    already_set = self.tokens_out is not None
    assert not already_set, 'Should only set `tokens_out` once.'
    self.tokens_out = tokens_out

  def replace(self, **kw):
    """Returns a copy of this context with the given fields overridden."""
    return dataclasses.replace(self, **kw)  # pytype: disable=wrong-arg-types  # dataclasses-replace-types


# At runtime LoweringRule is a Protocol describing the call signature of a
# lowering rule; when MYPY is true it is just `Any`.
if not MYPY:
  class LoweringRule(Protocol):
    def __call__(self, ctx: LoweringRuleContext,
                 *args: ir.Value | Sequence[ir.Value],
                 **kw) -> Sequence[ir.Value | Sequence[ir.Value]]:
      """Converts a JAX primitive invocation into MLIR."""
else:
  LoweringRule = Any

_lowerings: dict[core.Primitive, LoweringRule] = {}
_platform_specific_lowerings: dict[str, dict[core.Primitive, LoweringRule]]
_platform_specific_lowerings = collections.defaultdict(dict)

def register_lowering(prim: core.Primitive, rule: LoweringRule,
                      platform: str | None = None):
  if platform is None:
    _lowerings[prim] = rule
  else:
    # For backward compatibility reasons, we allow rules to be registered
    # under "gpu" even though the platforms are now called "cuda" and "rocm".
    # TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
    # this expansion.
    for p in xb.expand_platform_alias(platform):
      _platform_specific_lowerings[p][prim] = rule
  return rule


def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
def wrap_singleton_ir_values(x: ir.Value | Sequence[ir.Value]
                             ) -> Sequence[ir.Value]:
  """Normalizes `x` to a tuple of IR values.

  A single ir.Value is wrapped in a one-element tuple; anything else is
  converted with `tuple(x)`.
  """
  if isinstance(x, ir.Value):
    return (x,)
  return tuple(x)

def flatten_lowering_ir_args(
    xs: Sequence[ir.Value | Sequence[ir.Value]]
) -> Sequence[ir.Value]:
  """Flattens a mix of single IR values and sequences of them into one list."""
  wrapped = map(wrap_singleton_ir_values, xs)
  return util.flatten(wrapped)

# Matches any character that is not a word character, '.', or '-'.
# NOTE(review): presumably used to sanitize module names; the use site is not
# visible in this part of the file.
_module_name_regex = re.compile(r"[^\w.-]")

def sharded_aval(aval: core.AbstractValue,
                 sharding: JSharding | None) -> core.AbstractValue:
  """Returns `aval` with its shape replaced by the per-shard shape.

  `aval` is returned unchanged when `sharding` is None or when it is a token.

  Raises:
    NotImplementedError: for avals that are neither tokens nor (D)ShapedArray.
  """
  if sharding is None or isinstance(aval, core.AbstractToken):
    return aval
  if isinstance(aval, (core.ShapedArray, core.DShapedArray)):
    shard_shape = sharding.shard_shape(aval.shape)
    return aval.update(shard_shape)  # type: ignore
  raise NotImplementedError


def eval_dynamic_shape(ctx: LoweringRuleContext,
                       shape: core.Shape) -> tuple[int | Value, ...]:
  """Evaluates the dimensions of `shape` in the current lowering context.

  With `config.dynamic_shapes` enabled, each dimension is looked up in
  `ctx.axis_size_env`, falling back to the dimension itself. Otherwise
  core.evaluate_shape is lowered over the context's dimension variable values.
  Constant dimensions come back as Python ints and the others as the
  corresponding lowered values.
  """
  if config.dynamic_shapes.value:
    return tuple(ctx.axis_size_env.get(d, d) for d in shape)  # type: ignore

  dim_vars = ctx.module_context.shape_poly_state.dim_vars
  eval_ctx = ctx.replace(
      primitive="eval_dynamic_shape",
      avals_in=[core.dim_value_aval()] * len(dim_vars),
      tokens_out=None)
  lowered_eval = lower_fun(
      partial(core.evaluate_shape, shape, dim_vars),
      multiple_results=True)
  res = lowered_eval(eval_ctx, *eval_ctx.dim_var_values)
  return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
               for d, d_ir in zip(shape, util.flatten(res)))

# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
                               shape: core.Shape) -> tuple[Value, ...]:
  """Evaluates the dynamic shapes as int32 values.

  Python int dimensions become int32 constants; IR values are converted to
  int32 when their type differs.
  """
  def to_i32_value(dim: int | Value):
    if type(dim) is int:
      return ir_constant(np.array(dim, dtype=np.int32))
    i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
    if dim.type != i32_type:  # type: ignore
      return hlo.convert(i32_type, dim)
    return dim

  dims = eval_dynamic_shape(ctx, shape)
  return tuple(to_i32_value(d) for d in dims)


def eval_dynamic_shape_as_ivals(
    ctx: LoweringRuleContext, shape: core.Shape
    ) -> tuple[int | Value, ...]:
  """Evaluates the dynamic shapes as int or ir.int32 values.

  Python int dimensions are passed through; IR values are converted to int32
  when their type differs.
  """
  def to_int_or_i32(dim: int | Value) -> int | ir.Value:
    if type(dim) is int:
      return dim
    i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
    if dim.type != i32_type:  # type: ignore
      return hlo.convert(i32_type, dim)
    return dim

  dims = eval_dynamic_shape(ctx, shape)
  return tuple(to_int_or_i32(d) for d in dims)

def eval_dynamic_shape_as_tensor(ctx: LoweringRuleContext,
                                 shape: core.Shape) -> Value:
  """Evaluates the dynamic shapes as one 1d int32 tensor."""
  dims = eval_dynamic_shape(ctx, shape)
  return shape_tensor(dims)

class LoweringResult(NamedTuple):
  # Bundle of values produced by lowering a jaxpr to a module (it is the
  # declared return type of lower_jaxpr_to_module).
  module: ir.Module
  # NOTE(review): presumably the keepalive objects gathered during lowering,
  # or None; the construction site is not visible here.
  keepalive: Any | None
  host_callbacks: list[Any]
  shape_poly_state: ShapePolyLoweringState


# NOTE(review): presumably the platforms on which buffer donation is
# supported; the use site is not visible in this part of the file.
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]


def _to_physical_op_sharding(
    aval: core.AbstractValue, sharding: JSharding | None,
) -> xc.OpSharding | None:
  if sharding is None:
    return None
  assert isinstance(sharding, JSharding)
  if isinstance(aval, AbstractRef):
    return _to_physical_op_sharding(aval.inner_aval, sharding)
  assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
  if dtypes.issubdtype(aval.dtype, dtypes.extended):
    sharding = sharding_impls.physical_sharding(aval, sharding)
    aval = core.physical_aval(aval)
  return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore


def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None:
  if layout is None:
    return "default"
  if isinstance(layout, AutoLayout):
    return "auto"
  return layout._to_xla_layout()


def _get_mem_kind(s: JSharding | None) -> str | None:
  if s is None:
    return None
  assert isinstance(s, JSharding)
  return s.memory_kind


def lower_jaxpr_to_module(
    module_name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    ordered_effects: list[core.Effect],
    backend_or_name: str | xb.XlaBackend | None,
    platforms: Sequence[str],
    axis_context: AxisContext,
    name_stack: source_info_util.NameStack,
    donated_args: Sequence[bool],
    replicated_args: Sequence[bool] | None = None,
    arg_shardings: Sequence[JSharding | None] | None = None,
    result_shardings: Sequence[JSharding | None] | None = None,
    in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
    out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
    arg_names: Sequence[str | None] | None = None,
    result_names: Sequence[str] | None = None,
    num_replicas: int = 1,
    num_partitions: int = 1,
    all_default_mem_kind: bool = True,
    input_output_aliases: None | tuple[int | None, ...] = None,
    propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
    lowering_parameters: LoweringParameters,
) -> LoweringResult:
  """Lowers a top-level jaxpr to an MLIR module.

  Handles the quirks of the argument/return value passing conventions of the
  runtime.

  Args:
    module_name: name of the module; characters matched by
      `_module_name_regex` are replaced with "_" before it is set as the
      module's "sym_name".
    jaxpr: the closed jaxpr to lower; it becomes the public "main" function.
    ordered_effects: forwarded to `lower_jaxpr_to_fun` as the effects for
      which tokens are threaded through "main".
    backend_or_name: stored on the `ModuleContext`.
    platforms: lowering platforms; canonicalized with
      `xb.canonicalize_platform`.
    axis_context: stored on the `ModuleContext`.
    name_stack: forwarded to `lower_jaxpr_to_fun`.
    donated_args: one flag per jaxpr input, true for donated inputs.
    replicated_args, arg_shardings, result_shardings, arg_names, result_names:
      forwarded to `lower_jaxpr_to_fun`. The shardings are additionally used
      here to compute sharded avals and memory kinds.
    in_layouts, out_layouts: forwarded to `lower_jaxpr_to_fun` as
      `arg_layouts` and `result_layouts`.
    num_replicas, num_partitions: recorded as the "mhlo.num_replicas" and
      "mhlo.num_partitions" module attributes. `num_partitions` also affects
      how donation is expressed (see below).
    all_default_mem_kind: when true, no memory kinds are extracted from the
      shardings.
    input_output_aliases: optional pre-existing aliases, one entry per input.
    propagated_out_mem_kinds: forwarded to `lower_jaxpr_to_fun`.
    lowering_parameters: stored on the `ModuleContext`.

  Returns:
    A `LoweringResult` holding the module plus the keepalives, host callbacks
    and shape-polymorphism state of the module context.

  Raises:
    NotImplementedError: if only some of `platforms` support donation.
    ValueError: if the jaxpr has effects outside `lowerable_effects`, or if
      the lowered module fails MLIR verification.
  """
  platforms = tuple(map(xb.canonicalize_platform, platforms))

  # When shardings are given, the avals used for matching donations are the
  # ones produced by `sharded_aval`.
  in_avals = (jaxpr.in_avals if arg_shardings is None else
              map(sharded_aval, jaxpr.in_avals, arg_shardings))
  out_avals = (jaxpr.out_avals if result_shardings is None else
               map(sharded_aval, jaxpr.out_avals, result_shardings))
  # Memory kinds are only read off the shardings when `all_default_mem_kind`
  # is false.
  if all_default_mem_kind:
    arg_memory_kinds = None
    result_memory_kinds = None
  else:
    arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
                        if arg_shardings is not None else None)
    result_memory_kinds = (map(_get_mem_kind, result_shardings)
                          if result_shardings is not None else None)

  # Donation is expressed in one of two mutually exclusive ways: either as
  # `xla_donated_args` (per-argument donation flags) or as
  # `input_output_aliases` (explicit input -> output pairing).
  xla_donated_args = None
  platforms_with_donation = [p for p in platforms
                             if p in _platforms_with_donation]
  if platforms_with_donation:
    if len(platforms_with_donation) != len(platforms):
      raise NotImplementedError(
        "In multi-platform lowering either all or no lowering platforms "
        f"should support donation. Lowering for {platforms} of which "
        f"only {platforms_with_donation} support donation")
    # With several partitions and no result sharding specified, the donation
    # flags are passed through as-is instead of being paired with outputs here.
    if num_partitions > 1 and (
        result_shardings is None or all(s is None for s in result_shardings)):
      xla_donated_args = donated_args
    if xla_donated_args is None:
      # `donated_args` is rebound to the donations that could not be paired
      # with an output.
      input_output_aliases, donated_args = _set_up_aliases(
          input_output_aliases, in_avals, out_avals, donated_args,
          arg_memory_kinds, result_memory_kinds)
  unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
  if unlowerable_effects:
    raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
  # Anything still flagged as donated at this point was not usable; this only
  # warns, it never fails the lowering.
  if xla_donated_args is None and any(donated_args):
    unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
    msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
    if not platforms_with_donation:
      msg = f"Donation is not implemented for {platforms}.\n{msg}"
    if unused_donations:
      warnings.warn("Some donated buffers were not usable:"
                    f" {', '.join(unused_donations)}.\n{msg}")

  # Sanity check: the two donation mechanisms are never both active.
  if xla_donated_args is not None:
    assert input_output_aliases is None
  if input_output_aliases is not None:
    assert xla_donated_args is None

  # Delete donated_args by default here, since it's not needed beyond this point
  del donated_args

  # HLO channels need to start at 1
  channel_iter = itertools.count(1)
  # Create a keepalives list that will be mutated during the lowering.
  keepalives: list[Any] = []
  host_callbacks: list[Any] = []

  dim_vars: Sequence[str]
  if not config.dynamic_shapes.value:
    # Find the dimension variables: collect the variables of every non-constant
    # dimension of the input avals, as a sorted tuple.
    all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
                    for d in aval.shape if not core.is_constant_dim(d)]
    dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()),
                                             all_dim_poly, set())))
  else:
    dim_vars = ()

  ctx = ModuleContext(backend_or_name=backend_or_name,
                      platforms=platforms, axis_context=axis_context,
                      keepalives=keepalives,
                      channel_iterator=channel_iter,
                      host_callbacks=host_callbacks,
                      lowering_parameters=lowering_parameters,
                      shape_poly_state=ShapePolyLoweringState(dim_vars, platforms))
  with ctx.context, ir.Location.unknown(ctx.context):
    # Remove module name characters that XLA would alter. This ensures that
    # XLA computation preserves the module name.
    attrs = ctx.module.operation.attributes
    module_name = _module_name_regex.sub("_", module_name)
    attrs["sym_name"] = ir.StringAttr.get(module_name)
    attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
    attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
    # The jaxpr itself becomes the public "main" function of the module.
    lower_jaxpr_to_fun(
        ctx, "main", jaxpr, ordered_effects,
        name_stack=name_stack,
        public=True,
        replicated_args=replicated_args,
        arg_shardings=arg_shardings,
        result_shardings=result_shardings,
        input_output_aliases=input_output_aliases,
        xla_donated_args=xla_donated_args,
        arg_names=arg_names,
        result_names=result_names,
        arg_memory_kinds=arg_memory_kinds,
        result_memory_kinds=result_memory_kinds,
        arg_layouts=in_layouts,
        result_layouts=out_layouts,
        propagated_out_mem_kinds=propagated_out_mem_kinds)

  # Verify the module. Both failure modes (a false return value and an
  # `ir.MLIRError`) are reported to the caller as ValueError.
  try:
    if not ctx.module.operation.verify():
      raise ValueError(
          "Cannot lower jaxpr with verifier errors. " +
          dump_module_message(ctx.module, "verification"))
  except ir.MLIRError as e:
    msg_lines = ["Cannot lower jaxpr with verifier errors:"]
    # Recursively flatten each diagnostic and its notes into the message.
    def emit_diagnostic_info(d):
      msg_lines.append(f"\t{d.message}")
      msg_lines.append(f"\t\tat {d.location}")
      for n in d.notes:
        emit_diagnostic_info(n)
    for d in e.error_diagnostics:
      emit_diagnostic_info(d)
    raise ValueError("\n".join(msg_lines) + "\n" +
                     dump_module_message(ctx.module, "verification")) from e

  return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
                        ctx.shape_poly_state)

def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,
                    arg_memory_kinds, result_memory_kinds):
  if input_output_aliases is None:
    input_output_aliases = [None] * len(avals_in)
  else:
    input_output_aliases = list(input_output_aliases)
  # To match-up in-avals to out-avals we only care about the number of
  # bytes, so we strip off unrelated aval metadata (eg. the named shape)
  strip_metadata = lambda a: a.strip_named_shape().strip_weak_type()
  avals_in = map(strip_metadata, avals_in)
  avals_out = map(strip_metadata, avals_out)

  # Both arg and result memory kinds need to be specified to donate based on
  # the memory kind. For jit's where out_shardings is not specified, we don't
  # know the memory kind so don't condition the logic based on the memory kind.
  # TODO(yashkatariya): Note that this logic should be in C++ where we make
  # donation decisions are made after SPMD propagation passes and memory
  # placement passes so that we have all the information.
  if (arg_memory_kinds is None or result_memory_kinds is None or
      any(a is None for a in arg_memory_kinds) or
      any(r is None for r in result_memory_kinds)):
    arg_memory_kinds = [None] * len(avals_in)
    result_memory_kinds = [None] * len(avals_out)

  donations = collections.defaultdict(collections.deque)
  for i, (aval, am, donated, aliased) in enumerate(
      zip(avals_in, arg_memory_kinds, donated_args, input_output_aliases)):
    if donated and aliased is None:
      donations[(aval, am)].append(i)

  out_donated_args = list(donated_args)
  for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)):
    # Only donate if memory kinds match. Relax this when the compiler can
    # donate across memories.
    key = (aval, rm)
    if donations.get(key, ()):
      input_id = donations[key].popleft()
      input_output_aliases[input_id] = i
      out_donated_args[input_id] = False

  return input_output_aliases, out_donated_args

# A token is handled as a sequence of IR values (see `token_type` and
# `create_token`).
Token = Sequence[ir.Value]

def token_type() -> Sequence[ir.Type]:
  """Returns the IR types of a token: a one-element list with the HLO token type."""
  hlo_token_type = hlo.TokenType.get()
  return [hlo_token_type]

def create_token() -> Token:
  """Emits `hlo.create_token()` and returns it via `wrap_singleton_ir_values`."""
  fresh_token = hlo.create_token()
  return wrap_singleton_ir_values(fresh_token)

class TokenSet:
  """An immutable container of tokens to be used to lower effectful jaxprs. When lowering
  effectful jaxprs, we need to thread HLO tokens to sequence them. Each effect
  will need its own token that will be threaded in and out of the effectful
  primitives. A `TokenSet` encapsulates a set of HLO tokens that will be
  used by the lowering rules.

  Tokens are kept in an ordered mapping from effect to token; `items()` and
  the sets derived by `subset` and `update_tokens` follow that order.
  """
  _tokens: collections.OrderedDict[core.Effect, Token]

  def __init__(self, *args, **kwargs):
    # Accepts whatever `collections.OrderedDict` accepts, e.g. an iterable of
    # `(effect, token)` pairs.
    self._tokens = collections.OrderedDict(*args, **kwargs)

  def __len__(self):
    return len(self._tokens)

  def get(self, effect: core.Effect) -> Token:
    """Returns the token for `effect`; raises KeyError if it is not present."""
    return self._tokens[effect]

  @classmethod
  def create(cls, effects: Sequence[core.Effect]) -> TokenSet:
    """Creates a `TokenSet` corresponding to a list of `core.Effect`s."""
    tokens = [create_token() for _ in effects]
    return TokenSet(zip(effects, tokens))

  def items(self) -> Sequence[tuple[core.Effect, Token]]:
    """Returns the `(effect, token)` pairs as a tuple, in container order."""
    return tuple(self._tokens.items())

  def effects(self) -> set[core.Effect]:
    """Returns the (unordered) set of effects that have a token."""
    return set(self._tokens.keys())

  def subset(self, effects: Sequence[core.Effect]) -> TokenSet:
    """Return a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
    return TokenSet((eff, self._tokens[eff]) for eff in effects)

  def update_tokens(self, tokens: TokenSet) -> TokenSet:
    """Returns a new `TokenSet` with tokens replaced with ones from the input `TokenSet`.

    The result has exactly the effects of `self`, in the order of `self`;
    effects that only appear in `tokens` are ignored.
    """
    replacements = tokens._tokens
    # Iterate the ordered mapping directly rather than `self.effects()`: going
    # through a `set` would make the order of the result arbitrary.
    return TokenSet(
        (eff, replacements[eff] if eff in replacements else tok)
        for eff, tok in self._tokens.items())

def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    effects: Sequence[core.Effect],
    name_stack: source_info_util.NameStack,
    *,
    public: bool = False,
    replicated_args: Sequence[bool] | None = None,
    arg_shardings: Sequence[JSharding | None] | None = None,
    result_shardings: Sequence[JSharding | None] | None = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Sequence[int | None] | None = None,
    xla_donated_args: Sequence[bool] | None = None,
    api_name: str = "jit",
    arg_names: Sequence[str | None] | None = None,
    result_names: Sequence[str] | None = None,
    arg_memory_kinds: Sequence[str | None] | None = None,
    result_memory_kinds: Sequence[str | None] | None = None,
    arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
    result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
    propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
) -> func_dialect.FuncOp:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  The emitted function takes, in this order, one scalar argument per dimension
  variable, one token per effect, and then the jaxpr inputs. It returns one
  token per effect followed by the jaxpr outputs. All per-argument and
  per-result sequences below are given for the jaxpr inputs/outputs only; this
  function pads them for the extra leading arguments/results.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
      that will be created in or used by the lowered function.
    name_stack: the name stack of the caller; it is extended with `name`
      wrapped in `api_name` for the function body (when `name` is not None).
    public: if true, the function's visibility is set to "public".
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    use_sharding_annotations: if True, use "mhlo.sharding" annotations on
      parameters and return values to express sharding. If False, use
      hlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when "mhlo.sharding" annotations are
      propagated on non-entry functions during MLIR->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.
    xla_donated_args: optional sequence of args to set donation annotations.
    api_name: The name of the higher level primitive which should show up in the
      name stack.
    arg_names: optional names of the jaxpr inputs; a non-empty name becomes the
      named location of the corresponding entry-block argument.
    result_names: optional names of the results, recorded as "jax.result_info"
      attributes. They are silently ignored when their number differs from the
      number of flat non-token results.
    arg_memory_kinds: optional memory kind per argument, recorded as
      "mhlo.memory_kind" argument attributes.
    result_memory_kinds: optional memory kind per result, recorded as
      "mhlo.memory_kind" result attributes. For the function named "main",
      results with a memory kind are also wrapped with `wrap_with_memory_kind`.
    arg_layouts: optional layout per argument, recorded as "mhlo.layout_mode".
    result_layouts: optional layout per result, recorded as
      "mhlo.layout_mode".
    propagated_out_mem_kinds: optional memory kind per result, used for the
      result attribute only where `result_memory_kinds` has None.
  Returns:
    MLIR func op
  """

  # The first dimension variable may be the platform index
  num_dim_vars = len(ctx.shape_poly_state.dim_vars)
  dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
  dim_var_types = map(aval_to_ir_types, dim_var_avals)

  # Function inputs: *dim_var_values, *tokens, *actual_inputs
  input_types = map(aval_to_ir_types, jaxpr.in_avals)
  output_types = map(aval_to_ir_types, jaxpr.out_avals)
  num_tokens = len(effects)

  token_types = [token_type() for _ in effects]
  token_avals = [core.abstract_token] * num_tokens
  # Order of arguments: dim vars, tokens, array inputs
  input_avals = dim_var_avals + token_avals + jaxpr.in_avals
  input_types = [*dim_var_types, *token_types, *input_types]
  output_avals = [core.abstract_token] * num_tokens + jaxpr.out_avals
  output_types = [*token_types, *output_types]

  # From here on every per-argument sequence gets `num_dim_vars + num_tokens`
  # leading placeholder entries and every per-result sequence gets `num_tokens`
  # of them, so that they line up with `input_types` / `output_types`.
  if input_output_aliases is not None:
    token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
    input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
    # Update the existing aliases to account for the new output values
    input_output_aliases = [None if a is None
                            else a + num_tokens
                            for a in input_output_aliases]

  if arg_shardings is not None:
    token_shardings = [None] * (num_dim_vars + num_tokens)
    arg_shardings = [*token_shardings, *arg_shardings]
  if result_shardings is not None:
    token_shardings = [None] * num_tokens
    result_shardings = [*token_shardings, *result_shardings]
  if replicated_args is not None:
    token_replicated_args = [False] * (num_dim_vars + num_tokens)
    replicated_args = [*token_replicated_args, *replicated_args]
  if arg_memory_kinds is not None:
    token_memory_kinds = [None] * (num_dim_vars + num_tokens)
    arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds]
  if result_memory_kinds is not None:
    token_memory_kinds = [None] * num_tokens
    result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
  if arg_layouts is not None:
    token_layouts = [None] * (num_dim_vars + num_tokens)
    arg_layouts = [*token_layouts, *arg_layouts]
  if result_layouts is not None:
    token_layouts = [None] * num_tokens
    result_layouts = [*token_layouts, *result_layouts]
  if xla_donated_args is not None:
    xla_donated_args = [*([False] * (num_dim_vars + num_tokens)), *xla_donated_args]

  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  ctx.symbol_table.insert(func_op)

  # The `ir_*` lists below are flat: each per-argument (per-result) value is
  # repeated once for every IR type of that argument (result).
  ir_arg_shardings = None
  if arg_shardings is not None:
    in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
    ir_arg_shardings = util.flatten(
        [[_to_physical_op_sharding(a, s)] * len(types)
         for a, s, types in zip(in_avals, arg_shardings, input_types)])
    del in_avals

  ir_arg_memory_kinds = None
  if arg_memory_kinds is not None:
    ir_arg_memory_kinds = util.flatten(
        [[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)])

  ir_arg_layouts = None
  if arg_layouts is not None:
    ir_arg_layouts = util.flatten(
        [[_to_xla_layout(l)] * len(types)
         for l, types in zip(arg_layouts, input_types)])

  ir_donated_args = None
  if xla_donated_args is not None:
    ir_donated_args = util.flatten(
        [[is_donated] * len(types) for is_donated, types in zip(xla_donated_args, input_types)])

  ir_result_shardings = None
  if result_shardings is not None:
    ir_result_shardings = util.flatten(
        [[_to_physical_op_sharding(a, s)] * len(types)
         for a, s, types in zip(output_avals, result_shardings, output_types)])

  ir_result_memory_kinds = None
  custom_call_ir_result_memory_kinds = None
  if result_memory_kinds is not None:
    if propagated_out_mem_kinds is None:
      propagated_out_mem_kinds = (None,) * len(result_memory_kinds)
    # NOTE(review): `result_memory_kinds` and `output_types` carry `num_tokens`
    # leading token entries at this point, but a caller-provided
    # `propagated_out_mem_kinds` is not padded in this function. Confirm that
    # callers pass it token-aligned, or that it is never combined with effects.
    res, custom_call_res = [], []
    for pom, mk, types in zip(propagated_out_mem_kinds, result_memory_kinds,
                              output_types):
      if pom is not None and mk is None:
        res.append([pom] * len(types))
      else:
        res.append([mk] * len(types))  # type: ignore
      # To add the custom call on the output to signal a transfer, only do it
      # if memory kind comes from out_shardings on `jit` and result_memory_kinds
      # comes from out_shardings on `jit`.
      custom_call_res.append([mk] * len(types))
    ir_result_memory_kinds = util.flatten(res)
    custom_call_ir_result_memory_kinds = util.flatten(custom_call_res)

  ir_result_layouts = None
  if result_layouts is not None:
    ir_result_layouts = util.flatten(
        [[_to_xla_layout(l)] * len(types)
         for l, types in zip(result_layouts, output_types)])

  # Argument attributes are only attached when at least one source of
  # attributes is present.
  if (
      replicated_args is not None
      or ir_arg_shardings is not None
      or ir_arg_memory_kinds is not None
      or ir_arg_layouts is not None
      or input_output_aliases is not None
      or ir_donated_args is not None
      or arg_names is not None
      or num_tokens > 0
      or num_dim_vars > 0
  ):
    # One attribute dict per flat argument.
    arg_attrs: list[dict[str, ir.Attribute]] = [
        {} for _ in range(len(flat_input_types))]

    if replicated_args is not None:
      replicated_ir_args = [[replicated] * len(types) for replicated, types
                            in zip(replicated_args, input_types)]
      for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
        if replicated:
          attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)

    if use_sharding_annotations and ir_arg_shardings is not None:
      for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
        if sharding is not None:
          attrs["mhlo.sharding"] = get_sharding_attr(sharding)

    if ir_arg_memory_kinds is not None:
      for attrs, memory_kind in zip(arg_attrs, ir_arg_memory_kinds):
        if memory_kind is not None:
          attrs["mhlo.memory_kind"] = ir.StringAttr.get(memory_kind)

    if ir_arg_layouts is not None:
      for attrs, layout in zip(arg_attrs, ir_arg_layouts):
        if layout is not None:
          attrs["mhlo.layout_mode"] = ir.StringAttr.get(layout)

    if ir_donated_args is not None:
      for attrs, is_donated in zip(arg_attrs, ir_donated_args):
        if is_donated:
          attrs["jax.buffer_donor"] = ir.BoolAttr.get(True)

    if input_output_aliases is not None:
      # Translate per-argument aliases (indices into the unflattened outputs)
      # into per-flat-argument aliases (indices into the flat outputs).
      output_ids = util.unflatten(list(range(len(flat_output_types))),
                                  map(len, output_types))
      aliases: list[int | None] = []
      for itypes, alias in zip(input_types, input_output_aliases):
        if alias is None:
          aliases.extend([None] * len(itypes))
        else:
          aliases.extend(output_ids[alias])

      for attrs, alias in zip(arg_attrs, aliases):
        if alias is not None:
          attrs["tf.aliasing_output"] = i32_attr(alias)

    if num_dim_vars > 0:
      # The leading arguments are the dimension variables; tag each with its
      # variable name.
      for var_name, attrs in zip(ctx.shape_poly_state.dim_vars,
                                 arg_attrs[:num_dim_vars]):
        attrs["jax.global_constant"] = ir.StringAttr.get(var_name)
    elif ctx.lowering_parameters.global_constant_computation:
      for attrs in arg_attrs:
        attrs["jax.global_constant"] = ir.StringAttr.get("")

    if num_tokens > 0:
      token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
      for attrs in token_arg_attrs:
        attrs["jax.token"] = ir.BoolAttr.get(True)

    func_op.arg_attrs = ir.ArrayAttr.get(
        [ir.DictAttr.get(attrs) for attrs in arg_attrs])

  # Result attributes are always set, one dict per flat result.
  result_attrs: list[dict[str, ir.Attribute]] = [
      {} for _ in range(len(flat_output_types))]

  if num_tokens > 0:
    token_result_attrs = result_attrs[:num_tokens]
    for attrs in token_result_attrs:
      attrs["jax.token"] = ir.BoolAttr.get(True)

  if result_names:
    named_result_attrs = result_attrs[num_tokens:]
    # Names are only attached when there is exactly one per flat result.
    if len(named_result_attrs) == len(result_names):
      for attrs, name_ in zip(named_result_attrs, result_names):
        attrs['jax.result_info'] = ir.StringAttr.get(name_)

  if use_sharding_annotations and ir_result_shardings is not None:
    for attrs, sharding in zip(result_attrs, ir_result_shardings):
      if sharding is not None:
        attrs['mhlo.sharding'] = get_sharding_attr(sharding)

  if ir_result_memory_kinds is not None:
    for attrs, mem_kind in zip(result_attrs, ir_result_memory_kinds):
      if mem_kind is not None:
        attrs['mhlo.memory_kind'] = ir.StringAttr.get(mem_kind)

  if ir_result_layouts is not None:
    for attrs, layout in zip(result_attrs, ir_result_layouts):
      if layout is not None:
        attrs['mhlo.layout_mode'] = ir.StringAttr.get(layout)

  func_op.result_attrs = ir.ArrayAttr.get(
      [ir.DictAttr.get(attrs) for attrs in result_attrs])

  if arg_names:
    # Dimension variables and tokens get unknown locations; named jaxpr inputs
    # get a name location.
    arg_locs = [ir.Location.unknown()] * (num_dim_vars + num_tokens)
    for n in arg_names:
      arg_locs.append(ir.Location.name(n) if n else ir.Location.unknown())
    entry_block = func_op.add_entry_block(arg_locs)
  else:
    entry_block = func_op.add_entry_block()

  with ir.InsertionPoint(entry_block):
    flat_args = entry_block.arguments
    # We separate out the dimension variable inputs, the token inputs and
    # the regular inputs. The dimension variables and token inputs
    # will be passed to `jaxpr_subcomp` separately from the `args`.
    dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
    # A lowering context just for function body entry/exit code.
    entry_lowering_ctx = LoweringRuleContext(
        module_context=ctx, name_stack=name_stack, primitive=None,
        avals_in=[], avals_out=None,
        tokens_in=TokenSet.create([]), tokens_out=None,
        axis_size_env=None, dim_var_values=dim_var_values)
    # Without sharding annotations, shardings are expressed by wrapping the
    # sharded arguments in sharding ops instead.
    if not use_sharding_annotations and ir_arg_shardings is not None:
      flat_args = [
          a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
          for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]

    # In "main", unsharded arguments with an extended dtype get their trailing
    # dimensions replicated.
    if ir_arg_shardings is not None and name == "main":
      flat_args = [
          replicate_trailing_dims(entry_lowering_ctx, o, a)
          if (a is not core.abstract_token and
              dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # pytype: disable=attribute-error
          for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
      ]

    _, token_args, unflattened_args = util.split_list(
        util.unflatten(flat_args, map(len, input_types)),
        [num_dim_vars, num_tokens])
    tokens_in = TokenSet(zip(effects, token_args))
    args: list[list[ir.Value]] = unflattened_args
    if name is not None:
      callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
    else:
      callee_name_stack = name_stack
    consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
    out_vals, tokens_out = jaxpr_subcomp(
        ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
        consts, *args, dim_var_values=dim_var_values)
    # Results: the output tokens, in the order of `effects`, then the values.
    outs = []
    for eff in effects:
      outs.append(wrap_singleton_ir_values(tokens_out.get(eff)))
    outs.extend(out_vals)

    flat_outputs = util.flatten(outs)

    if not use_sharding_annotations and ir_result_shardings is not None:
      flat_outputs = [
          o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
          for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

    # Insert a custom call if output is on host because XLA needs that to do the
    # transfer.
    if custom_call_ir_result_memory_kinds is not None and name == "main":
      flat_outputs = [
          o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
          for o, mk, o_aval in zip(
              flat_outputs, custom_call_ir_result_memory_kinds, output_avals)]

    # Same treatment as for the arguments: in "main", unsharded results with an
    # extended dtype get their trailing dimensions replicated.
    if ir_result_shardings is not None and name == "main":
      flat_outputs = [
          replicate_trailing_dims(entry_lowering_ctx, o, a)
          if (a is not core.abstract_token and
              dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # pytype: disable=attribute-error
          for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
      ]

    func_dialect.return_(flat_outputs)

  return func_op


def wrap_with_memory_kind(
    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue) -> ir.Value:
  """Wraps `x` in an "annotate_device_placement" custom call.

  The custom call takes `x` as its only operand and carries `memory_kind` in
  its "mhlo.frontend_attributes" under the "_xla_buffer_placement" key. Its
  result type is the IR type of `aval_out`, or the type of `x` when `aval_out`
  is None.

  Returns:
    The result of the custom call.
  """
  result_type = x.type if aval_out is None else aval_to_ir_type(aval_out)
  placement_op = custom_call(
      "annotate_device_placement",
      result_types=[result_type],
      operands=[x],
      has_side_effect=True,
      api_version=1)
  frontend_attrs = {"_xla_buffer_placement": ir.StringAttr.get(memory_kind)}
  placement_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      frontend_attrs)
  return placement_op.result


def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
  """Wraps `val` so that only its trailing (physical) dims are replicated.

  For extended dtypes the sharding over `aval.shape` is left UNCONSTRAINED
  (i.e. XLA will choose), while the trailing dims, i.e. the dimensions of
  key_shape on the base array, are always REPLICATED. For example, if
  key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), the resulting
  sharding is P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
  """
  # A replicated sharding whose leading `aval.ndim` dims are marked as
  # unspecified yields exactly the sharding described above.
  replicated = xc.HloSharding.replicate().to_proto()
  leading_dims = set(range(aval.ndim))
  return wrap_with_sharding_op(
      ctx, val, aval, replicated, unspecified_dims=leading_dims)


def _emit_lowering_rule_as_fun(lowering_rule,
                               ctx: LoweringRuleContext) -> func_dialect.FuncOp:
  """Emits the contents of a lowering rule as a private function.

  The function is named after `ctx.primitive`. Its arguments are the dimension
  variables, then one token per effect of `ctx.tokens_in`, then the inputs
  described by `ctx.avals_in`. Its results are the tokens followed by the
  outputs described by `ctx.avals_out`.
  """
  module_ctx = ctx.module_context
  num_dim_vars = len(module_ctx.shape_poly_state.dim_vars)
  # TODO(necula) maybe only pass the dim_vars if they are needed?
  dim_var_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))
  dim_var_types = [aval_to_ir_types(dim_var_aval) for _ in range(num_dim_vars)]

  arg_types = [aval_to_ir_types(a) for a in ctx.avals_in]
  res_types = [aval_to_ir_types(a) for a in ctx.avals_out]
  effs = list(ctx.tokens_in.effects())
  token_types = [token_type() for _ in effs]
  input_types = dim_var_types + token_types + arg_types
  output_types = token_types + res_types

  ftype = ir.FunctionType.get(util.flatten(input_types),
                              util.flatten(output_types))
  assert ctx.primitive is not None
  func_op = func_dialect.FuncOp(ctx.primitive.name, ftype,
                                ip=ctx.module_context.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get("private")
  ctx.module_context.symbol_table.insert(func_op)
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    # Regroup the flat block arguments and split them into their three parts.
    grouped_args = util.unflatten(entry_block.arguments,
                                  [len(types) for types in input_types])
    dim_var_values, token_args, rule_args = util.split_list(
        grouped_args, [num_dim_vars, len(ctx.tokens_in)])
    sub_ctx = ctx.replace(tokens_in=TokenSet(zip(effs, token_args)),
                          dim_var_values=dim_var_values)
    outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(rule_args))
    if sub_ctx.tokens_out:
      # Output tokens go first, in the same effect order as the token inputs.
      out_tokens = [sub_ctx.tokens_out.get(eff) for eff in effs]
      outs = [*out_tokens, outs]
    wrapped_outs = [wrap_singleton_ir_values(o) for o in outs]
    func_dialect.return_(util.flatten(wrapped_outs))
  return func_op

def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                  name_stack: source_info_util.NameStack,
                  tokens: TokenSet,
                  consts: Sequence[Sequence[ir.Value]],
                  *args: Sequence[ir.Value],
                  dim_var_values: Sequence[ir.Value]
                  ) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
  """Lowers a jaxpr into MLIR, inlined into an existing function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the module lowering context.
    jaxpr: the jaxpr whose equations are lowered one by one, in order.
    name_stack: prepended to the name stack of every equation.
    tokens: the tokens available to the equations with ordered effects.
    consts: IR values for `jaxpr.constvars`, one sequence per variable.
    *args: IR values for `jaxpr.invars`, one sequence per variable.
    dim_var_values: the list of dimension variables values in the current
      IR function, in the order of ctx.shape_poly_state.dim_vars.

  Returns:
    A pair of the IR values of `jaxpr.outvars` and the `TokenSet` holding the
    tokens as updated by the effectful equations.

  Raises:
    ValueError: if a rule for an equation with ordered effects does not set
      `tokens_out`, or sets tokens for a different set of effects, or if the
      output of a rule is not iterable.
  """
  # NOTE(review): presumably platform names are expected to be canonical here
  # (e.g. "cuda"/"rocm" rather than the "gpu" alias) — confirm.
  assert "gpu" not in ctx.platforms
  def read(v: core.Atom) -> Sequence[ir.Value]:
    # Literals are materialized as constants; variables come from `env`.
    if type(v) is core.Literal:
      return ir_constants(xla.canonicalize_dtype(v.val))
    else:
      assert isinstance(v, core.Var)
      return env[v]

  def aval(v: core.Atom) -> core.AbstractValue:
    if type(v) is core.Literal:
      return xla.abstractify(v.val)
    else:
      return v.aval

  def write(v: core.Var, node: Sequence[ir.Value]):
    assert node is not None
    env[v] = tuple(node)

  def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None:
    # Returns the first override registered for exactly this primitive object.
    if ctx.lowering_parameters.override_lowering_rules is None:
      return None
    for p, rule in ctx.lowering_parameters.override_lowering_rules:
      if primitive is p:
        return rule
    return None

  # Maps each live jaxpr variable to its IR values.
  env: dict[core.Var, tuple[ir.Value, ...]] = {}

  assert isinstance(name_stack, source_info_util.NameStack), type(name_stack)
  assert len(args) == len(jaxpr.invars), (jaxpr, args)
  assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
  assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
  assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  last_used = core.last_used(jaxpr)
  for eqn in jaxpr.eqns:
    in_nodes = map(read, eqn.invars)
    source_info = eqn.source_info.replace(
        name_stack=name_stack + eqn.source_info.name_stack)
    loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
    with (source_info_util.user_context(eqn.source_info.traceback), loc,
          eqn.ctx.manager):
      override_rule = get_override_lowering_rule(eqn.primitive)
      platform_rules: dict[str, LoweringRule] = {}
      default_rule: LoweringRule | None = None
      # See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule`
      if override_rule is not None:
        # An override replaces both the platform-specific and default rules.
        default_rule = override_rule
      else:
        # First the platform-specific rules
        for p in ctx.platforms:
          if eqn.primitive in _platform_specific_lowerings[p]:
            platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive]
          elif eqn.primitive in xla._backend_specific_translations[p]:
            platform_rules[p] = xla_fallback_lowering(eqn.primitive)
        # Now the default rule
        if eqn.primitive in _lowerings:
          default_rule = _lowerings[eqn.primitive]
        elif eqn.primitive in xla._translations:
          default_rule = xla_fallback_lowering(eqn.primitive)

      # Only the tokens of this equation's ordered effects are passed to its
      # rule.
      effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
      tokens_in = tokens.subset(effects)
      avals_in = map(aval, eqn.invars)
      compute_type = eqn.ctx.compute_type if eqn.ctx is not None else None
      rule_ctx = LoweringRuleContext(
          module_context=ctx, primitive=eqn.primitive,
          name_stack=source_info.name_stack,
          avals_in=avals_in,
          avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
          tokens_out=None, dim_var_values=dim_var_values,
          compute_type=compute_type)
      if config.dynamic_shapes.value:
        # Map every variable that appears as a dimension of a dynamically
        # shaped input to its (first) IR value.
        axis_size_env = {d: read(d)[0]
                         for a in avals_in if type(a) is core.DShapedArray
                         for d in a.shape if type(d) is core.Var}
        rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)

      rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
      ans = lower_per_platform(rule_ctx, str(eqn.primitive),
                               platform_rules, default_rule,
                               eqn.effects,
                               *rule_inputs, **eqn.params)

      if effects:
        # If there were ordered effects in the primitive, there should be output
        # tokens we need for subsequent ordered effects.
        tokens_out = rule_ctx.tokens_out
        if tokens_out is None:
          raise ValueError(
              f'Lowering rule for `{eqn.primitive}` needs to set `tokens_out` '
              f'because it has effects: {eqn.effects}.')
        if tokens_out.effects() != tokens_in.effects():
          raise ValueError(
              f'Lowering rule for `{eqn.primitive}` '
              'returns incorrect set of output tokens. '
              f'Expected: {tuple(tokens_in.effects())} vs. Actual: {tuple(tokens_out.effects())}')
        tokens = tokens.update_tokens(tokens_out)

    # Normalize every output of the rule to a tuple of IR values.
    try:
      out_nodes = tuple(map(wrap_singleton_ir_values, ans))
    except TypeError as e:
      raise ValueError("Output of translation rule must be iterable: "
                       f"{eqn}, got output {ans}") from e

    assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
    assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
      ans, "lowering function returned a bad output", eqn)
    assert len(ans) == len(eqn.outvars), (ans, eqn)
    map(write, eqn.outvars, out_nodes)
    # Drop the variables whose last use was this equation.
    core.clean_up_dead_vars(eqn, env, last_used)
  return map(read, jaxpr.outvars), tokens


def lower_per_platform(ctx: LoweringRuleContext,
                       description: str,
                       platform_rules: dict[str, LoweringRule],
                       default_rule: LoweringRule | None,
                       effects: effects_lib.Effects,
                       *rule_args: ir.Value,
                       **rule_kwargs) -> Sequence[ir.Value]:
  """Emits code for a primitive for the current lowering platform(s).

  For example, given
      platform_rules = dict(tpu=rule0, cpu=rule0)
      default_rule = rule1

  and
      ctx.module_context.lowering_parameters.platforms = ("cpu",)

  emits:
      rule0(ctx, *rule_args, **rule_kwargs)

  In case of multi-platform lowering, e.g., if
      ctx.module_context.lowering_parameters.platforms = ("cpu", "cuda", "tpu")

  emits:
    rule_idx = case current_platform_idx:
                   0: return 0  # cpu rule index
                   1: return 1  # cuda rule index
                   2: return 0  # tpu rule index
    output = case rule_idx
               0: return rule0(*rule_args, **rule_kwargs)
               1: return rule1(*rule_args, **rule_kwargs)

  Args:
   ctx: lowering context.
   description: a string to include in error messages.
   platform_rules: map platform names, e.g., "cpu", "cuda", to
     `LoweringRule`s, for the platforms that have non-default lowering.
   default_rule: an optional rule to use for platforms not in `platform_rules`.
   effects: the set of effects for the current primitive.
   rule_args: the args of the lowering rules.
   rule_kwargs: the kwargs of the lowering rules.

  Returns:
   When a single rule applies, the output of that rule, unchanged. Otherwise
   the results of the emitted `case` op; when `effects` contains ordered
   effects the leading token results are removed (and handed to
   `ctx.set_tokens_out`) and the remaining results are regrouped per output.

  Raises:
   NotImplementedError: if some platform has no entry in `platform_rules` and
     `default_rule` is None.
   ValueError: in the multi-rule path, if a rule returns a non-iterable output.
  """
  # A platform restriction carried by `ctx` takes precedence over the
  # platforms of the module context.
  platforms: Sequence[str] = ctx.platforms or ctx.module_context.platforms
  # Special case the common case (single-platform lowering)
  # This block only checks that some rule exists so that the error message can
  # name the single platform; `rule` is not used afterwards and the rule is
  # selected again by the general code below.
  if len(platforms) == 1:
    rule = platform_rules.get(platforms[0], default_rule)
    if rule is None:
      raise NotImplementedError(
        f"MLIR translation rule for primitive '{description}' not "
        f"found for platform {platforms[0]}")

  # Multi-platform lowering
  kept_rules: list[LoweringRule] = []  # Only the rules for the platforms of interest
  platform_to_kept_rules_idx: dict[str, int] = {}
  # NOTE(review): every kept platform gets its own slot in `kept_rules`, even
  # when two platforms map to the same rule object; the docstring example shows
  # a shared index for such platforms -- confirm which is intended.
  for p, prule in platform_rules.items():
    if p not in platforms:
      continue
    platform_to_kept_rules_idx[p] = len(kept_rules)
    kept_rules.append(prule)

  platforms_without_specific_rule = [p for p in platforms
                                     if p not in platform_to_kept_rules_idx]
  if platforms_without_specific_rule:
    if default_rule is None:
      raise NotImplementedError(
        f"MLIR translation rule for primitive '{description}' not "
        f"found for platforms {platforms_without_specific_rule}")
    # All platforms without a specific rule share the single slot that holds
    # the default rule (appended once, after the loop).
    for p in platforms_without_specific_rule:
      platform_to_kept_rules_idx[p] = len(kept_rules)
    kept_rules.append(default_rule)

  assert kept_rules
  # If there is a single rule left just apply the rule, without conditionals.
  if len(kept_rules) == 1:
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
    # Tag the ops that produced the outputs with the compute type (if any);
    # the rule's own output object is what gets returned.
    wrapped_out = map(wrap_singleton_ir_values, output)
    map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
        util.flatten(wrapped_out))
    return output

  assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
  assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"

  # The first dim_var_values is the platform index
  current_platform_idx = ctx.dim_var_values[0]
  # Compute the rule index based on the current platform
  i32_type = aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
  if current_platform_idx.type != i32_type:
    current_platform_idx = hlo.convert(i32_type, current_platform_idx)
  # First `case`: one branch per platform, each returning the index of the
  # rule to use for that platform.
  rule_idx_op = hlo.CaseOp([i32_type],
                           index=current_platform_idx,
                           num_branches=len(platforms))
  for i, p in enumerate(platforms):
    branch = rule_idx_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      hlo.return_(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
  # The second `case` returns one token per ordered effect followed by the
  # primitive's outputs.
  ordered_effects = effects_lib.ordered_effects.filter_in(effects)
  rule_out_avals = [core.abstract_token] * len(ordered_effects) + ctx.avals_out
  output_types = map(aval_to_ir_types, rule_out_avals)
  case_op = hlo.CaseOp(util.flatten(output_types),
                      index=rule_idx_op,
                      num_branches=len(kept_rules))
  for i, rule in enumerate(kept_rules):
    # Each branch is lowered with a context restricted to the platforms that
    # selected this rule.
    platforms_for_this_rule = [p
                               for p, rule_idx in platform_to_kept_rules_idx.items()
                               if rule_idx == i]
    inner_ctx = ctx.replace(platforms=platforms_for_this_rule)
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      output = rule(inner_ctx, *rule_args, **rule_kwargs)
      try:
        out_nodes = map(wrap_singleton_ir_values, output)
      except TypeError as e:
        raise ValueError("Output of translation rule must be iterable: "
                        f"{description}, got output {output}") from e
      map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
          util.flatten(out_nodes))
      # If the rule set output tokens, return them ahead of the regular
      # outputs, in the order of `ordered_effects`.
      if inner_ctx.tokens_out is not None:
        assert len(ordered_effects) == len(inner_ctx.tokens_out)
        out_nodes = [inner_ctx.tokens_out.get(eff)
                     for eff in ordered_effects] + out_nodes
      hlo.return_(util.flatten(map(wrap_singleton_ir_values, out_nodes)))

  results = case_op.results
  if ordered_effects:
    # Split the leading token results from the outputs and publish the tokens
    # on the outer context.
    tokens, results = util.split_list(
      util.unflatten(results, map(len, output_types)),
      [len(ordered_effects)])
    tokens_out = ctx.tokens_in.update_tokens(TokenSet(zip(ordered_effects,
                                                          tokens)))
    ctx.set_tokens_out(tokens_out)
  return results

def _ir_consts(consts):
  """Lowers `consts` to IR constants, lowering each distinct object once.

  Constants are deduplicated by object identity (`id`), so a constant that
  appears several times in `consts` shares one lowered value. The result has
  one entry per element of `consts`, in the same order.
  """
  lowered_by_id = {}
  for const in consts:
    const_id = id(const)
    if const_id not in lowered_by_id:
      lowered_by_id[const_id] = ir_constants(xla.canonicalize_dtype(const))
  return [lowered_by_id[id(const)] for const in consts]


def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
  """Builds a lowering rule out of the traceable JAX function `fun`.

  The returned rule traces `fun` to a jaxpr at the input avals of the lowering
  context and lowers that jaxpr in place with `jaxpr_subcomp`. It never reads
  `avals_out`, so callers may pass any value as `avals_out`.

  Args:
    fun: the function to trace.
    multiple_results: if False, the single result of `fun` is wrapped in a
      1-tuple before tracing.

  Returns:
    A function `f_lowered(ctx, *args, **params)` usable as a lowering rule.
  """
  def f_lowered(ctx: LoweringRuleContext, *args, **params):
    if multiple_results:
      traced = fun
    else:
      traced = lambda *a, **kw: (fun(*a, **kw),)
    wrapped_fun = lu.wrap_init(traced, params)

    if not config.dynamic_shapes.value:
      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
      # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
    else:
      # The arguments may have dynamic shapes, i.e. ctx.avals_in may contain
      # Vars in their shape tuples. The axis sizes are therefore turned into
      # leading (implicit) binders of the traced jaxpr, via an InputType and
      # trace_to_jaxpr_dynamic2, and are passed explicitly to jaxpr_subcomp.
      assert ctx.axis_size_env is not None
      args = (*ctx.axis_size_env.values(), *args)
      idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
      i32_aval = core.ShapedArray((), np.dtype('int32'))
      implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
      explicit_args = []
      for aval in ctx.avals_in:
        if type(aval) is core.DShapedArray:
          # Replace axis-size Vars with indices into the implicit binders.
          aval = aval.update(  # type: ignore
              shape=tuple(idx.get(d, d) for d in aval.shape))
        explicit_args.append((aval, True))
      wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)

    # Propagate a per-rule platform restriction to the nested lowering.
    sub_context = ctx.module_context
    if ctx.platforms is not None:
      sub_context = sub_context.replace(platforms=ctx.platforms)
    out, tokens = jaxpr_subcomp(
        sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
        _ir_consts(consts), *map(wrap_singleton_ir_values, args),
        dim_var_values=ctx.dim_var_values)
    ctx.set_tokens_out(tokens)
    return out

  return f_lowered


def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack,
                               arg_names=None, result_names=None):
  """Lowers `call_jaxpr` to a function, reusing a previous lowering if possible.

  The lowering is cached in `ctx.cached_primitive_lowerings` only when the
  closed jaxpr has no consts and neither `arg_names` nor `result_names` is
  given; the cache key is the function name, the jaxpr and the effects.
  """
  def emit():
    return lower_jaxpr_to_fun(
        ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
        result_names=result_names)

  cacheable = (not call_jaxpr.consts and
               arg_names is None and result_names is None)
  if not cacheable:
    return emit()

  key = (fn_name, call_jaxpr.jaxpr, tuple(effects))
  try:
    return ctx.cached_primitive_lowerings[key]
  except KeyError:
    func_op = emit()
    ctx.cached_primitive_lowerings[key] = func_op
    return func_op


def check_backend_matches(inner_backend: str | None,
                          lowering_platforms: Sequence[str]):
  """Checks an inner call's explicit backend against the lowering platforms.

  For nested calls, the outermost call sets the backend for all inner calls;
  it is an error if the inner call has a conflicting explicit backend spec.

  Raises:
    NotImplementedError: if `inner_backend` is given and there is more than
      one lowering platform.
    ValueError: if `inner_backend` differs from the lowering platform and the
      platform is not among the expansions of `inner_backend` as an alias.
  """
  if inner_backend is None:
    return
  outer_backend, *other_platforms = lowering_platforms
  if other_platforms:
    raise NotImplementedError(
        "Multi-platform lowering when a backend= parameter is specified")
  if inner_backend == outer_backend:
    return
  if outer_backend in xb.expand_platform_alias(inner_backend):
    return
  raise ValueError(
      f"Outer-jit backend specification {outer_backend} must match explicit "
      f"inner-jit backend specification {inner_backend}.")


def call_lowering(fn_name, name_stack, call_jaxpr, backend,
                  ctx: ModuleContext, avals_in,
                  avals_out, tokens_in, *args,
                  dim_var_values: Sequence[ir.Value],
                  arg_names=None, result_names=None):
  """Lowers a call to `call_jaxpr` as an out-of-line function plus a call op.

  The callee takes the dimension variables, then one token per effect in
  `tokens_in`, then `args`; it returns one token per effect followed by the
  outputs.

  Returns:
    A pair `(out_nodes, tokens_out)`: the call results grouped per output, and
    `tokens_in` updated with the tokens returned by the call.
  """
  del avals_in
  if isinstance(call_jaxpr, core.Jaxpr):
    call_jaxpr = pe.close_jaxpr(call_jaxpr)
  check_backend_matches(backend, ctx.platforms)

  effects = list(tokens_in.effects())
  num_tokens = len(effects)
  value_types = map(aval_to_ir_types, avals_out)
  output_types = [token_type()] * num_tokens + value_types

  callee = _lower_jaxpr_to_fun_cached(
      ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
      result_names=result_names)
  symbol_name = callee.name.value
  in_tokens = [tokens_in.get(eff) for eff in effects]
  call = func_dialect.CallOp(
      util.flatten(output_types),
      ir.FlatSymbolRefAttr.get(symbol_name),
      flatten_lowering_ir_args((*dim_var_values, *in_tokens, *args)))

  grouped_results = util.unflatten(call.results, map(len, output_types))
  out_tokens, out_nodes = util.split_list(grouped_results, [num_tokens])
  tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, out_tokens)))
  return out_nodes, tokens_out

def core_call_lowering(ctx: LoweringRuleContext,
                       *args, name, backend=None, call_jaxpr):
  """Lowering rule for call primitives; delegates to `call_lowering`.

  Unpacks the rule context into the arguments of `call_lowering`, records the
  returned tokens on `ctx` and returns the output values of the call.
  """
  results, tokens_out = call_lowering(
      name, ctx.name_stack, call_jaxpr, backend, ctx.module_context,
      ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
      dim_var_values=ctx.dim_var_values)
  ctx.set_tokens_out(tokens_out)
  return results

# Both call primitives are lowered by `core_call_lowering`. The `name` bound
# here is forwarded as the function name of the out-of-line callee, so it must
# be a string; an equation param named `name`, if present, overrides it.
register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
                  partial(core_call_lowering, name="core_closed_call"))

def map_compute_type(c_type):
  """Maps a compute type name to the string used for `_xla_compute_type`.

  `'device_host'` maps to `'host'` and `'device'` maps to `'dense'`.

  Raises:
    ValueError: for any other value.
  """
  for known_type, xla_value in (('device_host', 'host'), ('device', 'dense')):
    if c_type == known_type:
      return xla_value
  raise ValueError('Invalid compute type received. Current supported values '
                   'are `device_host` and `device`')

def wrap_compute_type_in_place(ctx, op):
  """Tags `op` with the compute type of `ctx`, if one is set.

  When `ctx.compute_type` is not None, the `mhlo.frontend_attributes`
  attribute of `op` is set (in place) to a dictionary holding the mapped
  compute type under `_xla_compute_type`. Otherwise `op` is left untouched.
  """
  if ctx.compute_type is None:
    return
  compute_type_attr = ir.StringAttr.get(map_compute_type(ctx.compute_type))
  op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      {"_xla_compute_type": compute_type_attr})


def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
                     broadcast_dimensions) -> ir.Value:
  """Lowers a possibly-dynamic broadcast_in_dim of `op` to `aval_out`.

  `broadcast_dimensions[i]` is the axis of the result where the axis i of `op`
  is broadcast. For extended dtypes the trailing physical element dimensions
  are appended and the broadcast is redone on the physical aval.
  """
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):  # type: ignore
    elt_shape = aval_out.dtype._rules.physical_element_aval(  # type: ignore
        aval_out.dtype).shape                                 # type: ignore
    # The element dimensions map one-to-one onto the trailing result axes.
    first_trailing = aval_out.ndim  # type: ignore
    trailing_dims = range(first_trailing, first_trailing + len(elt_shape))
    return broadcast_in_dim(
        ctx, op, core.physical_aval(aval_out),
        broadcast_dimensions=[*broadcast_dimensions, *trailing_dims])

  if core.is_constant_shape(aval_out.shape):  # type: ignore
    assert all(d != ir.ShapedType.get_dynamic_size()
               for d in aval_out.shape), aval_out  # type: ignore
    out = hlo.broadcast_in_dim(
        aval_to_ir_type(aval_out), op,
        dense_int_array(broadcast_dimensions))
  else:
    shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape)  # type: ignore
    out = hlo.dynamic_broadcast_in_dim(
        aval_to_ir_type(aval_out), op,
        shape,
        dense_int_array(broadcast_dimensions),
    )
  wrap_compute_type_in_place(ctx, out.owner)
  return out

def multi_broadcast_in_dim(ctx: LoweringRuleContext,
                           ops: Sequence[ir.Value],
                           ops_avals: Sequence[core.AbstractValue],
                           out_shape: core.Shape) -> Sequence[ir.Value]:
  """Broadcasts each of `ops` to `out_shape`.

  Operands whose shape is already definitely equal to `out_shape` are passed
  through; the others are broadcast with their axes aligned to the trailing
  axes of `out_shape`.
  """
  def broadcast_one(op, op_aval):
    in_shape = op_aval.shape  # type: ignore
    if core.definitely_equal_shape(in_shape, out_shape):
      return op
    assert len(in_shape) <= len(out_shape), (in_shape, out_shape)
    out_rank = len(out_shape)
    trailing_axes = list(range(out_rank - len(in_shape), out_rank))
    return broadcast_in_dim(
        ctx, op,
        core.ShapedArray(out_shape, op_aval.dtype),  # type: ignore
        broadcast_dimensions=trailing_axes)

  return [broadcast_one(op, op_aval) for op, op_aval in zip(ops, ops_avals)]

def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Value:
  """Reshapes `op` to the physical shape of `aval_out`.

  Emits `hlo.reshape` when the target shape is constant and
  `hlo.dynamic_reshape` with a computed shape tensor otherwise.
  """
  physical_out = core.physical_aval(aval_out)
  if core.is_constant_shape(physical_out.shape):  # type: ignore
    return hlo.reshape(aval_to_ir_type(physical_out), op)
  shape = eval_dynamic_shape_as_tensor(ctx, physical_out.shape)  # type: ignore
  return hlo.dynamic_reshape(aval_to_ir_type(physical_out), op, shape)

def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
             start_indices, limit_indices, strides) -> ir.Value:
  """Lowers a strided slice of `x`, with static or dynamic bounds.

  For extended dtypes the slice is redone on the physical aval, taking the
  trailing element dimensions in full (start 0, limit = element size,
  stride 1).
  """
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    elt_shape = aval_out.dtype._rules.physical_element_aval(
        aval_out.dtype).shape
    elt_rank = len(elt_shape)
    return slice_op(ctx, x, core.physical_aval(aval_out),
                    start_indices=(*start_indices, *([0] * elt_rank)),
                    limit_indices=(*limit_indices, *elt_shape),
                    strides=(*strides, *([1] * elt_rank)))

  bounds = (start_indices, limit_indices, strides)
  if all(core.is_constant_shape(s) for s in bounds):
    return hlo.slice(x,
                     dense_int_array(start_indices),
                     dense_int_array(limit_indices),
                     dense_int_array(strides))

  # At least one bound is not a constant: pass all of them as tensors.
  start_tensor = eval_dynamic_shape_as_tensor(ctx, start_indices)
  limit_tensor = eval_dynamic_shape_as_tensor(ctx, limit_indices)
  strides_tensor = eval_dynamic_shape_as_tensor(ctx, strides)
  return hlo.real_dynamic_slice(
      aval_to_ir_type(aval_out),
      x, start_tensor, limit_tensor, strides_tensor)

def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
                  start_indices) -> ir.Value:
  """Lowers a slice of `x` with run-time start indices and size `aval_out.shape`.

  For extended dtypes, zero start indices are appended for the trailing
  element dimensions and the physical avals are used. When the slice sizes
  are not constants the start indices are clamped explicitly.
  """
  x_aval = ctx.avals_in[0]
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    elt_shape = aval_out.dtype._rules.physical_element_aval(
        aval_out.dtype).shape
    index_avals = ctx.avals_in[1:]
    index_dtype = dtypes.canonicalize_dtype(
        index_avals[0].dtype if index_avals else 'int64')  # type: ignore
    # A single zero constant is shared by all trailing element dimensions.
    zero_index = ir_constant(np.array(0, index_dtype))
    start_indices = (*start_indices, *([zero_index] * len(elt_shape)))
    aval_out = core.physical_aval(aval_out)
    x_aval = core.physical_aval(x_aval)

  slice_sizes = aval_out.shape
  if core.is_constant_shape(slice_sizes):
    return hlo.dynamic_slice(x, start_indices, dense_int_array(slice_sizes))

  # lax.dynamic_slice clamps the start indices, but we are going to
  # lower to RealDynamicSliceOp, which is a version of SliceOp, and does
  # not have the clamping behavior. We clamp start ourselves.
  rank = len(start_indices)
  sizes_tensor = eval_dynamic_shape_as_tensor(ctx, slice_sizes)
  lowest_start = shape_tensor([0] * rank)
  requested_start = shape_tensor(start_indices)
  operand_dims = eval_dynamic_shape_as_tensor(ctx, x_aval.shape)  # type: ignore
  highest_start = hlo.subtract(operand_dims, sizes_tensor)
  clamped_start = hlo.clamp(lowest_start, requested_start, highest_start)
  result_type = aval_to_ir_type(aval_out)
  limit = hlo.add(clamped_start, sizes_tensor)
  unit_strides = shape_tensor([1] * rank)
  return hlo.real_dynamic_slice(
      result_type, x, clamped_start, limit, unit_strides)

def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
                         start_indices) -> ir.Value:
  """Lowers writing `update` into `x` at run-time `start_indices`.

  For extended dtypes, zero start indices are appended for the trailing
  element dimensions and the update is redone on the physical aval.
  """
  if not dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    # TODO(necula): handle dynamic shapes
    return hlo.dynamic_update_slice(x, update, start_indices)

  elt_shape = aval_out.dtype._rules.physical_element_aval(
      aval_out.dtype).shape
  index_avals = ctx.avals_in[2:]
  index_dtype = dtypes.canonicalize_dtype(
      index_avals[0].dtype if index_avals else 'int64')  # type: ignore
  # A single zero constant is shared by all trailing element dimensions.
  zero_index = ir_constant(np.array(0, dtype=index_dtype))
  extended_start = (*start_indices, *([zero_index] * len(elt_shape)))
  return dynamic_update_slice(ctx, core.physical_aval(aval_out), x, update,
                              start_indices=extended_start)

def pad(ctx: LoweringRuleContext, aval_out,
        x, padding_value,
        padding_low, padding_high, padding_interior) -> ir.Value:
  """Lowers padding of `x` with `padding_value`.

  Emits `hlo.pad` when all three padding specifications are constant and
  `hlo.dynamic_pad` with computed padding tensors otherwise.
  """
  paddings = (padding_low, padding_high, padding_interior)
  if not all(core.is_constant_shape(s) for s in paddings):
    low, high, interior = [eval_dynamic_shape_as_tensor(ctx, p)
                           for p in paddings]
    return hlo.dynamic_pad(
        aval_to_ir_type(aval_out),
        x, padding_value, low, high, interior)

  low, high, interior = [dense_int_array(p) for p in paddings]
  return hlo.pad(x, padding_value, low, high, interior)

def iota(ctx: LoweringRuleContext, aval_out, *, dimension: int):
  """Lowers an iota of type `aval_out` along `dimension`.

  Emits `hlo.iota` for constant shapes and `hlo.dynamic_iota` with a computed
  shape tensor otherwise.
  """
  if core.is_constant_shape(aval_out.shape):
    return hlo.iota(aval_to_ir_type(aval_out), i64_attr(dimension))
  shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
  return hlo.dynamic_iota(
      aval_to_ir_type(aval_out),
      shape,
      i64_attr(dimension),
  )

def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
  """Returns an IR value shaped like `aval` with every element set to `value`.

  A scalar constant of the canonicalized dtype of `aval` is created and then
  broadcast to the shape of `aval`.
  """
  scalar_dtype = dtypes.canonicalize_dtype(aval.dtype)
  scalar = ir_constant(np.array(value, scalar_dtype))
  return broadcast_in_dim(ctx, scalar, aval, broadcast_dimensions=())

def add_jaxvals_lowering(ctx, x, y):
  """Lowering rule for adding two values.

  Values of an extended dtype are added through the `add` rule of that dtype
  (lowered via `lower_fun`); all other values use `hlo.add`.
  """
  aval = ctx.avals_in[0]
  is_extended = (isinstance(aval, core.ShapedArray) and
                 dtypes.issubdtype(aval.dtype, dtypes.extended))
  if not is_extended:
    return [hlo.add(x, y)]
  extended_add = lambda x, y: [aval.dtype._rules.add(aval.dtype, x, y)]
  return lower_fun(extended_add)(ctx, x, y)
# Register the lowering rule for the add_jaxvals primitive.
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)

# stop_gradient lowers to its operand unchanged: the rule emits no op.
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])


def compare_hlo(x, y, direction: str, comparison_type: str | None = None):
  """Creates a CompareOp of `x` and `y` for the given `direction`.

  When `comparison_type` is None it is derived from the element type of `x`:
  "UNSIGNED" or "SIGNED" for integer element types, "FLOAT" otherwise.
  """
  if comparison_type is None:
    elem_type = ir.RankedTensorType(x.type).element_type
    if not isinstance(elem_type, ir.IntegerType):
      comparison_type = "FLOAT"
    elif elem_type.is_unsigned:
      comparison_type = "UNSIGNED"
    else:
      comparison_type = "SIGNED"

  direction_attr = hlo.ComparisonDirectionAttr.get(direction)
  type_attr = hlo.ComparisonTypeAttr.get(comparison_type)
  return hlo.compare(x, y, direction_attr, compare_type=type_attr)

def _minmax_hlo(op, cmp, x, y):
  """Min/max that compares complex values lexicographically as pairs.

  Non-complex operands are passed straight to `op`. Complex operands are
  compared on the real parts with direction `cmp`, falling back to the
  imaginary parts where the real parts are equal.
  """
  elem_type = ir.RankedTensorType(x.type).element_type
  if not ir.ComplexType.isinstance(elem_type):
    return op(x, y)

  real_x = hlo.real(x)
  real_y = hlo.real(y)
  reals_equal = compare_hlo(real_x, real_y, "EQ", "FLOAT")
  by_real = compare_hlo(real_x, real_y, cmp, "FLOAT")
  by_imag = compare_hlo(hlo.imag(x), hlo.imag(y), cmp, "FLOAT")
  # Where the real parts tie, the imaginary parts decide.
  pick_x = hlo.select(reals_equal, by_imag, by_real)
  return hlo.select(pick_x, x, y)

# Binary min/max builders taking (x, y). Complex operands are ordered
# lexicographically by `_minmax_hlo` using the given comparison direction;
# other element types use `hlo.minimum` / `hlo.maximum` directly.
min_hlo = partial(_minmax_hlo, hlo.minimum, "LT")
max_hlo = partial(_minmax_hlo, hlo.maximum, "GT")


def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
  """Variant of convert that has HLO semantics.

  In particular, treat casts to boolean as x != 0, rather than truncating
  integer values (b/209440332)."""
  to_bool = (not dtypes.issubdtype(aval_out.dtype, dtypes.extended) and
             aval_out.dtype == np.dtype(np.bool_))
  if to_bool:
    in_dtype = aval_in.dtype
    if dtypes.issubdtype(in_dtype, np.inexact):
      compare_type = "FLOAT"
    elif dtypes.issubdtype(in_dtype, np.signedinteger):
      compare_type = "SIGNED"
    else:
      compare_type = "UNSIGNED"
    zeros = full_like_aval(ctx, 0, aval_in)
    x = compare_hlo(x, zeros, "NE", compare_type)
    # Fall through to the convert below, to adjust the shape if needed.
  return hlo.convert(aval_to_ir_type(aval_out), x)

def _wrap_with_spmd_op(name: str,
                       ctx: LoweringRuleContext,
                       x: ir.Value,
                       aval_out: core.AbstractValue,
                       sharding_proto: xc.OpSharding,
                       unspecified_dims: set[int] | None = None,
                       has_side_effect: bool = False):
  """Wraps `x` in a custom call named `name` annotated with a sharding.

  `unspecified_dims` indicates dimensions whose shardings are not specified,
  so that XLA sharding propagation can change them; they are passed to the
  custom call through its backend config, in sorted order.

  Returns:
    The result value of the custom call.
  """
  if unspecified_dims:
    dims = ",".join(str(d) for d in sorted(unspecified_dims))
    backend_config = f"unspecified_dims=[{dims}]"
  else:
    backend_config = ""

  result_type = aval_to_ir_type(aval_out)
  out_shape = core.physical_aval(aval_out).shape  # type: ignore
  # A result shape operand is only needed when the shape is not constant.
  result_shapes = (None if core.is_constant_shape(out_shape)
                   else [eval_dynamic_shape_as_tensor(ctx, out_shape)])

  op = custom_call(name, result_types=[result_type], operands=[x],
                   backend_config=backend_config,
                   api_version=1,
                   result_shapes=result_shapes,
                   has_side_effect=has_side_effect)
  set_sharding(op, sharding_proto)
  return op.result


# Specializations of `_wrap_with_spmd_op`; each one fixes the name of the
# custom call that is emitted.
wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")

def set_sharding(op, sharding_proto: xc.OpSharding):
  """Sets the `mhlo.sharding` attribute of `op` from `sharding_proto`."""
  sharding_attr = get_sharding_attr(sharding_proto)
  op.attributes["mhlo.sharding"] = sharding_attr


def get_sharding_attr(sharding_proto: xc.OpSharding):
  """Returns a string attribute encoding `sharding_proto`.

  If there are very large numbers of devices (more than 100 tile assignment
  devices), the serialized proto is used; otherwise the `repr` of the
  corresponding `HloSharding` is used. The MHLO to HLO conversion supports
  both, and the proto representation is more compact.
  """
  if len(sharding_proto.tile_assignment_devices) > 100:
    encoded = sharding_proto.SerializeToString()  # type: ignore
  else:
    encoded = repr(xc.HloSharding.from_proto(sharding_proto))
  return ir.StringAttr.get(encoded)


# MLIR lowerings for lax primitives

def cache_lowering(f):
  """Decorator that causes the contents of a lowering rule to be reused.

  The lowering will be emitted out-of-line in a separate function, together with
  a call to that function. If the same primitive is called with the same shapes
  and parameters, a new call to the original function will be added, without
  emitting a new function. We allow for different lowering for the same
  primitive for different platforms in the same module.
  """
  @functools.wraps(f)
  def cached_lowering(ctx, *args, **params):
    assert ctx.primitive is not None
    cache_key = (f, ctx.primitive,
                 tuple(ctx.avals_in), tuple(ctx.avals_out),
                 tuple(params.items()))
    try:
      callee = ctx.module_context.cached_primitive_lowerings.get(cache_key)
    except TypeError:
      # If the parameters aren't hashable, give up on caching.
      # TODO(phawkins): switch to requiring hashability, when XLA fallback
      # computations have been ported to MLIR.
      return f(ctx, *args, **params)
    if callee is None:
      callee = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
      ctx.module_context.cached_primitive_lowerings[cache_key] = callee

    # Call the out-of-line function; it takes the dimension variables first.
    output_types = map(aval_to_ir_types, ctx.avals_out)
    call_args = tuple(ctx.dim_var_values) + args
    call = func_dialect.CallOp(util.flatten(output_types),
                               ir.FlatSymbolRefAttr.get(callee.name.value),
                               flatten_lowering_ir_args(call_args))
    return util.unflatten(call.results, map(len, output_types))
  return cached_lowering


def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
                                  ) -> ir.Module:
  """Converts `xla_computation` to MLIR text and parses it into a module."""
  return ir.Module.parse(
      xc._xla.mlir.xla_computation_to_mlir_module(xla_computation))

def merge_mlir_modules(dst_module: ir.Module,
                       sym_name: str,
                       src_module: ir.Module,
                       dst_symtab: ir.SymbolTable | None = None) -> str:
  """Moves the contents of `src_module` into `dst_module`.

  Args:
    dst_module: the module into which the contents of src_module should be
      moved. Nothing in dst_module will be renamed.
    sym_name: the desired name for the "main" function of src_module after
      merging. This is a hint: the true name may be different because of symbol
      uniquification, and the true name is returned by this function.
    src_module: the module whose contents are to be alpha-renamed, set to
      private visibility, and merged into dst_module. src_module must contain
      exactly one symbol named "main".
    dst_symtab: the symbol table of `dst_module`. If not given, one is built
      from `dst_module`.

      Functions in src_module will be renamed such that they do not collide with
      functions in dst_module.

      This function mutates `src_module`. On return, `src_module` is left in an
      undefined state.

  Returns:
    the name of src_module's main() function, after renaming.
  """
  assert dst_module.context == src_module.context

  src_symtab = ir.SymbolTable(src_module.operation)
  dst_symtab = dst_symtab or ir.SymbolTable(dst_module.operation)
  used_names = set()

  # Rename all symbols in src_module that clash with names in dst_module, or
  # are the "main" symbol.
  renamings = {}
  for op in src_module.body.operations:
    name = op.name.value
    should_rename = name in dst_symtab or name == "main"
    if should_rename:
      # "main" is renamed after the requested `sym_name`; other clashing
      # symbols keep their own name as the base.
      base_name = sym_name if name == "main" else name
      new_name = base_name
      i = 0
      # Replacements are chosen such that the new names are present in neither
      # src_module, dst_module, or the set of fresh names we've already used.
      # Since we rename names one at a time, if new names were in src_module,
      # they might themselves collide with a later renaming.
      while (new_name in src_symtab or new_name in dst_symtab or
             new_name in used_names):
        new_name = f"{base_name}_{i}"
        i += 1
      renamings[name] = new_name
      used_names.add(new_name)

  # Apply the symbol renamings to symbol definitions.
  # Every op of src_module is made private, whether or not it was renamed.
  private = ir.StringAttr.get("private")
  for op in src_module.body.operations:
    if op.name.value in renamings:
      src_symtab.set_symbol_name(op, renamings[op.name.value])
    op.attributes["sym_visibility"] = private

  # Apply the symbol renamings to symbol uses.
  for old_name, new_name in renamings.items():
    for op in src_module.body.operations:
      src_symtab.replace_all_symbol_uses(old_name, new_name, op)

  # Move the ops into dst_module and register them in its symbol table.
  for op in src_module.body.operations:
    dst_module.body.append(op)
    dst_symtab.insert(op)

  # Raises KeyError if src_module had no "main" symbol.
  return renamings["main"]


def xla_fallback_lowering(prim: core.Primitive):
  """Builds a (cached) lowering rule for `prim` from its XLA translation.

  The rule builds an XLA computation for the primitive with
  `xla.primitive_subcomputation`, converts it to an MLIR module, merges that
  module into the module being lowered and emits a call to the merged
  function. Non-constant shapes and multi-platform lowering are rejected with
  NotImplementedError.
  """
  @cache_lowering
  def fallback(ctx: LoweringRuleContext, *args, **params):
    module_ctx = ctx.module_context
    axis_ctx = module_ctx.axis_context
    axis_env = (axis_ctx.unsafe_axis_env
                if isinstance(axis_ctx, sharding_impls.SPMDAxisContext)
                else module_ctx.axis_env)

    for aval in (ctx.avals_in + ctx.avals_out):
      if hasattr(aval, "shape") and not core.is_constant_shape(aval.shape):
        raise NotImplementedError(
            f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")

    if len(module_ctx.platforms) > 1:
      raise NotImplementedError(
        "fallback lowering not implemented for multi-platform lowering")

    xla_module = xla_computation_to_mlir_module(
        xla.primitive_subcomputation(
            module_ctx.platforms[0], axis_env, prim, ctx.avals_in,
            ctx.avals_out, **params))
    callee_name = merge_mlir_modules(
        module_ctx.module, f"xla_fallback_{prim.name}", xla_module,
        dst_symtab=module_ctx.symbol_table)

    output_types = map(aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    # A primitive with multiple results is called as returning one tuple.
    if prim.multiple_results:
      call_result_type = ir.TupleType.get_tuple(flat_output_types)
    else:
      call_result_type = flat_output_types[0]

    call = func_dialect.CallOp([call_result_type],
                               ir.FlatSymbolRefAttr.get(callee_name),
                               flatten_lowering_ir_args(args)).result
    if not prim.multiple_results:
      return [call]

    # Unpack the tuple element by element, then regroup per output.
    flat_results = [hlo.get_tuple_element(call, i32_attr(i))
                    for i in range(len(flat_output_types))]
    return util.unflatten(flat_results, map(len, output_types))
  return fallback


# Channel type codes. `send_to_host` and `receive_from_host` pass
# SEND_TO_HOST_TYPE and RECV_FROM_HOST_TYPE, respectively, as the second
# argument of `hlo.ChannelHandle.get`.
DEVICE_TO_DEVICE_TYPE = 1
SEND_TO_HOST_TYPE = 2
RECV_FROM_HOST_TYPE = 3


def is_empty_shape(s: core.Shape) -> bool:
  """Returns True if any dimension of the shape `s` equals 0."""
  for dim in s:
    if dim == 0:
      return True
  return False

def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
                 aval: core.ShapedArray, name: str, *,
                 sharding: xc.OpSharding | None = None) -> ir.Value:
  """Emits a host-transfer send of `operand` on `channel`.

  The op is tagged with `name` as both the host transfer handler name and the
  rendezvous name, and with `sharding` when given. `aval` is not used.

  Returns:
    The result of the send op.
  """
  channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
  send_op = hlo.SendOp([operand], token, channel_handle,
                        is_host_transfer=ir.BoolAttr.get(True))
  frontend_attributes = {
      "_xla_host_transfer_handler_name": ir.StringAttr.get(str(name)),
      "_xla_host_transfer_rendezvous": ir.StringAttr.get(str(name)),
  }
  send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      frontend_attributes)
  if sharding is not None:
    set_sharding(send_op, sharding)
  return send_op.result


def receive_from_host(channel: int, token: hlo.TokenType,
                      out_aval: core.ShapedArray, name: str, *,
                      sharding: xc.OpSharding | None = None,
) -> tuple[ir.Value, ir.Value]:
  """Emits a host-transfer receive of a value of type `out_aval` on `channel`.

  The op is tagged with `name` as both the host transfer handler name and the
  rendezvous name, and with `sharding` when given.

  Returns:
    A pair `(token, result)`: the token produced by the receive op and the
    received value.
  """
  channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
  result_types = [aval_to_ir_type(out_aval), hlo.TokenType.get()]
  recv_op = hlo.RecvOp(result_types, token, channel_handle,
                       is_host_transfer=ir.BoolAttr.get(True))
  frontend_attributes = {
      "_xla_host_transfer_handler_name": ir.StringAttr.get(str(name)),
      "_xla_host_transfer_rendezvous": ir.StringAttr.get(str(name)),
  }
  recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      frontend_attributes)
  if sharding is not None:
    set_sharding(recv_op, sharding)
  # The op returns the value first and the token last; callers get the token
  # first.
  received, out_token = recv_op.results
  return out_token, received


def _emit_tpu_python_callback(
    backend: xb.XlaBackend,
    ctx: LoweringRuleContext,
    callback,
    token: Any | None,
    operands: Sequence[ir.Value],
    operand_avals: Sequence[core.ShapedArray],
    operand_shapes: Sequence[xc.Shape],
    result_avals: Sequence[core.ShapedArray],
    result_shapes: Sequence[xc.Shape],
    *,
    sharding: xc.OpSharding | None = None
) -> tuple[Sequence[ir.Value], Any]:
  """Lowers a Python callback as a chain of host send and receive ops.

  Every operand is sent to the host on its own freshly allocated channel, and
  every result is received on its own channel. A single token is threaded
  through all of these ops in emission order. The host-side callback object
  built by the backend is registered on the module context.

  Args:
    backend: backend used to build the host callback object.
    ctx: lowering rule context; its module context allocates channels and
      records the host callback.
    callback: the Python callable to run on the host. Its `__name__` is used
      to name the transfers.
    token: token to start the chain from; a new one is created if it is falsy.
    operands: IR values to send to the host.
    operand_avals: abstract values of `operands`.
    operand_shapes: XLA shapes of `operands`.
    result_avals: abstract values of the results to receive.
    result_shapes: XLA shapes of the results.
    sharding: optional sharding forwarded to every send/recv op.

  Returns:
    A pair `(outputs, token)` with the received IR values and the final token.
  """
  if not token:
    token = hlo.create_token()

  send_channels = []
  if operand_avals:
    _wrapped_callback = callback
    for value, value_aval in zip(operands, operand_avals):
      send_channel = ctx.module_context.new_channel()
      token = send_to_host(send_channel, token, value, value_aval,
                           callback.__name__, sharding=sharding)
      send_channels.append(send_channel)
  else:
    # Without any operand there would be no send op, and the callback would
    # never be triggered. Send a dummy one-element array instead, and have the
    # host-side wrapper drop it before invoking the real callback.
    # TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
    # MLIR builder.
    def _wrapped_callback(*args):
      del args
      return callback()
    send_channel = ctx.module_context.new_channel()
    placeholder_aval = core.ShapedArray((1,), np.float32)
    placeholder = ir_constant(np.zeros(1, np.float32))
    operand_shapes = [*operand_shapes,
                      xla.aval_to_xla_shapes(placeholder_aval)[0]]
    token = send_to_host(send_channel, token, placeholder, placeholder_aval,
                         callback.__name__, sharding=sharding)
    send_channels.append(send_channel)

  outputs = []
  recv_channels = []
  for result_aval in result_avals:
    recv_channel = ctx.module_context.new_channel()
    assert isinstance(result_aval, core.ShapedArray)
    token, received = receive_from_host(recv_channel, token, result_aval,
                                        callback.__name__, sharding=sharding)
    outputs.append(received)
    recv_channels.append(recv_channel)

  ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
      _wrapped_callback, operand_shapes, result_shapes, send_channels,
      recv_channels, pickle_util.dumps)
  ctx.module_context.add_host_callback(ifrt_callback)
  return outputs, token

def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
  """Converts a minor-to-major layout into an index-typed dense elements attr.

  `None` (needed for token layouts) is encoded as an empty array.
  """
  layout = (np.zeros((0,), dtype="int64") if minor_to_major is None
            else np.array(minor_to_major, dtype="int64"))
  return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())

def _aval_to_default_layouts(aval):
  """Returns a one-element list holding the default layout for `aval`.

  The layout is computed from the physical aval and lists the dimensions from
  last to first, i.e. the row-major order that NumPy uses by default.
  """
  physical = core.physical_aval(aval)
  return [list(reversed(range(physical.ndim)))]


def emit_python_callback(
    ctx: LoweringRuleContext,
    callback,
    token: Any | None,
    operands: Sequence[ir.Value],
    operand_avals: Sequence[core.ShapedArray],
    result_avals: Sequence[core.ShapedArray],
    *,
    has_side_effect: bool,
    sharding: xc.OpSharding | None = None,
    operand_layouts: Sequence[Sequence[int] | None] | None = None,
    result_layouts: Sequence[Sequence[int] | None] | None = None,
) -> tuple[Sequence[ir.Value], Any, Any]:
  """Emits MLIR that calls back to a provided Python function.

  On "tpu" the callback is lowered through `_emit_tpu_python_callback` (host
  send/recv ops). On "cpu", "cuda" and "rocm" it is lowered to a single
  `hlo.CustomCallOp` whose first operand is a callback descriptor constant.

  The user callback is wrapped so that, when it runs, the number of outputs
  and each output's shape and dtype are checked against `result_avals`; a
  mismatch raises `RuntimeError` from inside the callback.

  Args:
    ctx: lowering rule context; its module context supplies the platform list,
      the backend, and the place where the host callback object is recorded.
    callback: Python callable invoked with the operand values. It must return
      a sequence with one entry per element of `result_avals`.
    token: optional token. If truthy, it is threaded through the call and the
      updated token is returned.
    operands: IR values passed to the callback.
    operand_avals: abstract values of `operands`.
    result_avals: abstract values of the expected callback outputs.
    has_side_effect: value of the custom call's `has_side_effect` attribute.
      Not read on the TPU path.
    sharding: if not None, applied to the emitted op(s).
    operand_layouts: minor-to-major layout per operand; when None the default
      row-major layouts from `_aval_to_default_layouts` are used. Not read on
      the TPU path.
    result_layouts: same as `operand_layouts`, for the results.

  Returns:
    A triple `(results, token, callback_object)`. On TPU the third element is
    None; otherwise it is the host callback object returned by the backend
    together with the descriptor.

  Raises:
    NotImplementedError: if the module is being lowered for several platforms.
    ValueError: if the platform is not one of "cpu", "cuda", "rocm", "tpu".
  """
  if len(ctx.module_context.platforms) > 1:
    raise NotImplementedError("multi-platform lowering for python_callback")
  platform = ctx.module_context.platforms[0]
  if platform not in {"cpu", "cuda", "rocm", "tpu"}:
    raise ValueError(
        f"`EmitPythonCallback` not supported on {platform} backend.")
  backend = ctx.module_context.backend
  result_shapes = util.flatten(
      [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
  operand_shapes = util.flatten(
      [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
  # Handling layouts
  # Note: `map` in this module is `util.safe_map` (see the top of the file).
  if operand_layouts is None:
    operand_layouts = util.concatenate(
        map(_aval_to_default_layouts, operand_avals))
  operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
  if result_layouts is None:
    result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
  result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)

  # First we apply checks to ensure output shapes and dtypes match the expected
  # ones.
  def _wrapped_callback(*args):
    out_vals = callback(*args)
    if len(out_vals) != len(result_avals):
      raise RuntimeError(
          "Mismatched number of outputs from callback. "
          "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
    # Handle Python literals, and custom arrays, e.g., tf.Tensor.
    out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals)
    for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
      if out_val.shape != out_aval.shape:
        raise RuntimeError(
            f"Incorrect output shape for return value #{i}: "
            f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
      if out_val.dtype != out_aval.dtype:
        raise RuntimeError(
            f"Incorrect output dtype for return value #{i}: "
            f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")

    if platform == "tpu":
      # On TPU we cannot receive empty arrays. So, we return from the wrapped
      # callback only the non-empty results, and we will create empty constants
      # in the receiving computation.
      # TODO(b/238239458): fix TPU Recv to work with empty arrays.
      non_empty_out_vals = tuple(
          out_val
          for out_val, result_aval in zip(out_vals, result_avals)
          if not is_empty_shape(result_aval.shape))
      return non_empty_out_vals
    else:
      return out_vals

  if platform == "tpu":
    # Only the non-empty results are actually received from the host; this
    # matches the filtering done inside `_wrapped_callback` above.
    non_empty_result_avals, non_empty_result_shapes = util.unzip2([
        (aval, shape)
        for aval, shape in zip(result_avals, result_shapes)
        if not is_empty_shape(aval.shape)])
    non_empty_outputs, token = _emit_tpu_python_callback(
        backend, ctx, _wrapped_callback,  token,
        operands, operand_avals, operand_shapes,
        non_empty_result_avals, non_empty_result_shapes,
        sharding=sharding)
    # Rebuild the full result list in the original order: empty results become
    # zero-sized constants, the others are taken from the received values.
    non_empty_outputs_iter = iter(non_empty_outputs)
    outputs = [
        ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
        if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
        for result_aval in result_avals]
    return outputs, token, None

  # CPU / GPU path: a single custom call.
  result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals])
  if token:

    # With a token, the token becomes the first operand and the first result
    # of the call; the host-side wrapper passes it through unchanged.
    callback_without_token = _wrapped_callback
    def _wrapped_callback(token, *args):  # type: ignore  # pylint: disable=function-redefined
      return (token, *callback_without_token(*args))

    operand_shapes = [
        xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
    ]
    result_shapes = [
        xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
    ]
    operands = [token, *operands]
    result_types = [token_type()[0], *result_types]
    # Tokens use the empty layout produced for `None`.
    operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
    result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
  callback_descriptor, ifrt_callback = (
      backend.get_emit_python_callback_descriptor(_wrapped_callback,
                                                  operand_shapes,
                                                  result_shapes))
  ctx.module_context.add_host_callback(ifrt_callback)
  # The descriptor is passed as the first operand of the custom call, so it
  # also needs a layout entry (an empty one) ahead of the other operands.
  descriptor_operand = ir_constant(callback_descriptor)
  callback_operands = [descriptor_operand, *operands]
  if operand_mlir_layouts is not None:
    operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
  # All results are returned as one tuple and unpacked below.
  result_type = ir.TupleType.get_tuple(result_types)
  call_target_name = ("xla_python_gpu_callback"
                     if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
  result = hlo.CustomCallOp(
      [result_type],
      callback_operands,
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      api_version=i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(str(callback_descriptor)),
      operand_layouts=(
        None if operand_mlir_layouts is None
        else ir.ArrayAttr.get(operand_mlir_layouts)),
      result_layouts=(
        None if result_mlir_layouts is None
        else ir.ArrayAttr.get(result_mlir_layouts)))
  if sharding is not None:
    set_sharding(result, sharding)
  results = [
      hlo.get_tuple_element(result, i32_attr(i))
      for i in range(len(result_types))
  ]
  if token:
    # The first tuple element is the output token (it was prepended above).
    token, *results = results
  return results, token, ifrt_callback

def build_mlir_module_helper(
    closed_jaxpr: core.ClosedJaxpr, *, name: str,
    platforms: Sequence[str],
    backend_or_name: str, axis_context: AxisContext) -> ir.Module:
  """Helper to generate pmap-style XLA computations for custom partitioners.

  Lowers `closed_jaxpr` with `lower_jaxpr_to_module`, using no ordered
  effects, an empty name stack, no donated arguments and default lowering
  parameters, and returns the resulting module.

  Raises:
    ValueError: if the jaxpr has effects that `lowerable_effects` filters out
      as not lowerable.
  """
  if lowerable_effects.filter_not_in(closed_jaxpr.effects):
    raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}')
  num_inputs = len(closed_jaxpr.jaxpr.invars)
  return lower_jaxpr_to_module(
      name,
      closed_jaxpr,
      backend_or_name=backend_or_name,
      ordered_effects=[],
      name_stack=source_info_util.NameStack(),
      donated_args=[False] * num_inputs,
      axis_context=axis_context,
      platforms=platforms,
      lowering_parameters=LoweringParameters(),
  ).module

def custom_call(
    call_target_name: str,
    *,
    result_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    backend_config: str | bytes | dict[str, ir.Attribute] = "",
    has_side_effect: bool = False,
    result_shapes: Sequence[ir.Value] | None = None,
    called_computations: Sequence[str] = (),
    api_version: int = 2,
    operand_output_aliases: dict[int, int] | None = None,
    operand_layouts: Sequence[Sequence[int]] | None = None,
    result_layouts: Sequence[Sequence[int]] | None = None,
    extra_attributes: dict[str, ir.Attribute] | None = None,
) -> ir.Operation:
  """Helper function for building an hlo.CustomCall.

  Args:
    call_target_name: the name of the custom call target
    result_types: the MLIR types of the results of the custom call
    operands: the MLIR IR values that are arguments to the custom call
    backend_config: an opaque string passed to the custom call kernel, or a
      dict of attributes. A dict is attached as the `mhlo.backend_config`
      attribute and forces `api_version` to 1. None is treated like "".
    has_side_effect: if True, marks the custom call as effectful
    result_shapes: tensors that represent the result shapes, to be used when
      the results have dynamic shapes. If not-None, its length must match the
      number of the results.
    called_computations: the list of function names called by the custom call.
    api_version: the ABI contract version of the custom call
    operand_output_aliases: a dict mapping operand numbers to outputs they alias
    operand_layouts: a sequence of layouts (dimension orders) for each operand
    result_layouts: a sequence of layouts (dimension orders) for each result
    extra_attributes: additional IR attributes to apply to the custom_call.

  Returns:
    The custom call op, built with `hlo.CustomCallOp.build_generic`.

  Raises:
    ValueError: if `backend_config` is not None, a str, bytes or a dict.
  """
  def layouts_attr(layouts):
    # One index-typed dense attribute (at least rank 1) per layout.
    return ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(layout, dtype=np.int64)),
            type=ir.IndexType.get())
        for layout in layouts
    ])

  operands = list(operands)

  config_is_dict = isinstance(backend_config, dict)
  if backend_config is None or config_is_dict:
    # TODO(necula): it seems that the CustomCallOp constructor requires that
    # backend_config be a string attribute, even though in some cases we need
    # it to be a DictAttr, e.g., for ApproxTopK on TPU:
    # "Verification failed: 'stablehlo.custom_call' op attribute 'backend_config' failed to satisfy constraint: string attribute"
    # As a workaround, a dict config is replaced here by the empty string and
    # is attached after the op is built, as the unregistered attribute
    # mhlo.backend_config. api_version=1 is needed for mhlo.backend_config to
    # be handled properly.
    backend_config_attr = ir.StringAttr.get("")
    if config_is_dict:
      api_version = 1
  elif isinstance(backend_config, (str, bytes)):
    backend_config_attr = ir.StringAttr.get(backend_config)  # type: ignore
  else:
    raise ValueError("custom_call backend_config unexpected type: " + str(backend_config))

  attributes = {
      "call_target_name": ir.StringAttr.get(call_target_name),
      "has_side_effect": ir.BoolAttr.get(has_side_effect),
      "backend_config": backend_config_attr,
      "api_version": i32_attr(api_version),
      "called_computations": ir.ArrayAttr.get(
          [ir.FlatSymbolRefAttr.get(name) for name in called_computations]
      ),
  }

  if operand_output_aliases is not None:
    aliases = []
    for input_idx, output_idx in operand_output_aliases.items():
      aliases.append(hlo.OutputOperandAlias.get(
          # With a single result the alias refers implicitly to that result,
          # so no output index is given.
          output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
          operand_index=input_idx,
          operand_tuple_indices=[],
      ))
    attributes["output_operand_aliases"] = ir.ArrayAttr.get(aliases)

  if extra_attributes is not None:
    attributes.update(extra_attributes)

  if result_shapes is not None:
    # The shape tensors are appended after the regular operands, and their
    # positions are recorded in the indices_of_shape_operands attribute. That
    # attribute is not accepted by the CustomCallOp constructor, which is why
    # the op is created with build_generic below.
    first_shape_operand = len(operands)
    attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
        np.arange(first_shape_operand,
                  first_shape_operand + len(result_shapes),
                  dtype=np.int64))
    if operand_layouts is not None:
      assert len(operand_layouts) == len(operands), (operand_layouts, operands)
      # Each appended shape operand gets the layout (0,).
      operand_layouts = [*operand_layouts, *([(0,)] * len(result_shapes))]
    operands = [*operands, *result_shapes]

  if operand_layouts is not None:
    attributes["operand_layouts"] = layouts_attr(operand_layouts)
  if result_layouts is not None:
    assert len(result_layouts) == len(result_types), (
        result_layouts, result_types)
    attributes["result_layouts"] = layouts_attr(result_layouts)

  op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
                                      attributes=attributes)
  if config_is_dict:
    op.operation.attributes["mhlo.backend_config"] = ir.DictAttr.get(
        backend_config)
  return op


def reduce_window(
    ctx: LoweringRuleContext,
    *,
    reducer_name: str,
    reducer_body: Callable[[ir.Block], Sequence[ir.Value]],
    operands: Sequence[ir.Value],
    init_values: Sequence[ir.Value],
    init_values_avals: Sequence[core.AbstractValue],
    out_avals: Sequence[core.AbstractValue],
    window_dimensions, window_strides, padding, base_dilation, window_dilation):
  """Builds a ReduceWindowOp, with support for dynamic shapes.

  When every window parameter is a constant shape, an `hlo.ReduceWindowOp` is
  emitted with the reducer as its region. Otherwise the reducer is emitted as
  a separate function in the module and the reduction is expressed as a
  "stablehlo.dynamic_reduce_window" custom call that references it.

  Args:
    ctx: the lowering rule context.
    reducer_name: base name to be used for the reducer function (dynamic case).
    reducer_body: given the reducer's entry block, emits the reducer body and
      returns the values to return from it.
    operands: the arrays to reduce.
    init_values: the initial values of the reduction.
    init_values_avals: abstract values of `init_values`; they determine the
      reducer's argument and result types.
    out_avals: abstract values of the results.
    window_dimensions: the window size per dimension.
    window_strides: the window stride per dimension.
    padding: a sequence of (pad_lo, pad_hi) pairs.
    base_dilation: the base dilation per dimension.
    window_dilation: the window dilation per dimension.

  Returns:
    The results of the emitted op.
  """
  scalar_types = [aval_to_ir_type(aval) for aval in init_values_avals]
  # The reducer takes the scalar types twice and returns them once.
  reducer_arg_types = scalar_types + scalar_types

  shape_params = [window_dimensions, window_dilation, window_strides,
                  base_dilation, *padding]
  if all(core.is_constant_shape(s) for s in shape_params):
    # Static shapes
    rw = hlo.ReduceWindowOp(
        [aval_to_ir_type(aval) for aval in out_avals],
        operands, init_values,
        dense_int_array(window_dimensions),
        window_strides=dense_int_array(window_strides),
        base_dilations=dense_int_array(base_dilation),
        window_dilations=dense_int_array(window_dilation),
        padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                            shape=[len(padding), 2]))
    body_block = rw.regions[0].blocks.append(*reducer_arg_types)
    with ir.InsertionPoint(body_block):
      hlo.return_(reducer_body(body_block))
    return rw.results

  # Dynamic shapes. d_padding will be an array i32[N, 2] with pad_lo and
  # pad_hi for each spatial dimension.
  pad_row_type = aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
  pad_rows = []
  for pad_lo_hi in padding:
    pad_pair = eval_dynamic_shape_as_tensor(ctx, pad_lo_hi)  # i32[2]
    pad_rows.append(hlo.reshape(pad_row_type, pad_pair))
  d_padding = hlo.concatenate(pad_rows, i64_attr(0))

  # Build the reducer as a function at the beginning of the module body.
  reducer_type = ir.FunctionType.get(reducer_arg_types, scalar_types)
  with ir.InsertionPoint.at_block_begin(ctx.module_context.module.body):
    reducer = func_dialect.FuncOp(reducer_name, reducer_type)
  ctx.module_context.symbol_table.insert(reducer)
  entry_block = reducer.add_entry_block()
  with ir.InsertionPoint(entry_block):
    hlo.return_(reducer_body(entry_block))

  out_types = [aval_to_ir_type(aval) for aval in out_avals]
  d_window_dimensions = eval_dynamic_shape_as_tensor(ctx, window_dimensions)
  d_window_strides = eval_dynamic_shape_as_tensor(ctx, window_strides)
  d_base_dilation = eval_dynamic_shape_as_tensor(ctx, base_dilation)
  d_window_dilation = eval_dynamic_shape_as_tensor(ctx, window_dilation)
  rw = custom_call(
      "stablehlo.dynamic_reduce_window",
      result_types=out_types,
      operands=[
          *operands, *init_values,
          d_window_dimensions,
          d_window_strides,
          d_base_dilation,
          d_window_dilation,
          d_padding],
      called_computations=[reducer.name.value],
  )
  return rw.results


def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
  """Refines the polymorphic shapes inside a module.

  The input is a module whose input shapes are static but which still uses
  dynamic shapes because of shape polymorphism. The module is serialized and
  run through shape refinement, with shape assertions enabled and with
  validation that the shapes are static afterwards. The refined module is
  parsed in a newly created IR context.

  Raises:
    ValueError: if refinement fails; the original exception is chained and
      the message includes a dump-module message for the input module.
  """
  try:
    refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
        module_to_bytecode(module),
        enable_shape_assertions=True,
        validate_static_shapes=True)
  except Exception as err:
    dump_msg = dump_module_message(module, "before_refine_polymorphic_shapes")
    raise ValueError("Error refining shapes. " + dump_msg) from err

  with make_ir_context():
    return ir.Module.parse(refined_module_str)
