# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A small library of helpers for use in jaxlib to build MLIR operations."""

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
from typing import Callable, Union

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np


_dtype_to_ir_type_factory: dict[np.dtype, Callable[[], ir.Type]] = {
  np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
  np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
  np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
  np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
  np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
  np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
  np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
  np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
  np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
  np.dtype(np.float16): ir.F16Type.get,
  np.dtype(np.float32): ir.F32Type.get,
  np.dtype(np.float64): ir.F64Type.get,
  np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
  np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}


def dtype_to_ir_type(dtype) -> ir.Type:
  return _dtype_to_ir_type_factory[np.dtype(dtype)]()


def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
  return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype))
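# E.g. (assuming an active MLIR context): shape_dtype_to_ir_type((2, 3),
# np.float32) returns the type tensor<2x3xf32>.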


# When we generate custom calls with dynamic shapes we have to pass both
# result_types, with ir.ShapedType.get_dynamic_size() in place of the dynamic
# dimensions, and result_shapes, which are ir.Values representing 1-D int32
# tensors. If all the shapes are static we can use result_shapes=None. We
# first construct, for each result, a pair of the shape and the element type,
# where the shape contains either Python ints or ir.Values for the dynamic
# dimensions.
DimensionSize = Union[int, ir.Value]  # an ir.Value for a non-static dimension
ShapeTypePair = tuple[Sequence[DimensionSize], ir.Type]
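# For example (a minimal sketch; assumes an active MLIR context and insertion
# point, with `batch` a scalar integer ir.Value holding a dynamic dimension):
#
#   pairs: list[ShapeTypePair] = [((batch, 4), ir.F32Type.get())]
#   result_types, result_shapes = mk_result_types_and_shapes(pairs)
#   # result_types[0] is tensor<?x4xf32>; result_shapes is a one-element list
#   # holding the 1-D int32 tensor built by shape_tensor((batch, 4)).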

def mk_result_types_and_shapes(
    shape_type_pairs: Sequence[ShapeTypePair]
) -> tuple[list[ir.Type], list[ir.Value] | None]:
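  """Creates the result_types and result_shapes arguments for a custom call.

  For each (shape, element_type) pair, produces an ir.RankedTensorType with
  ir.ShapedType.get_dynamic_size() in place of dynamic dimensions. If any
  dimension is dynamic, also produces one 1-D int32 shape tensor per result.
  Returns (result_types, result_shapes), with result_shapes None when all
  shapes are static.
  """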
  result_types: list[ir.Type] = []
  result_shapes: list[ir.Value] = []
  has_dynamic_shapes = any(
      any(not isinstance(d, int) for d in rshape)
      for rshape, _ in shape_type_pairs)
  for (rshape, rtype) in shape_type_pairs:
    if has_dynamic_shapes:
      result_shapes.append(shape_tensor(rshape))
    result_types.append(
        ir.RankedTensorType.get(
            [d if isinstance(d, int) else ir.ShapedType.get_dynamic_size()
             for d in rshape],
            rtype))
  return (result_types,
          result_shapes if has_dynamic_shapes else None)

# TODO(necula): share this with mlir.shape_tensor
def shape_tensor(sizes: Sequence[int | ir.Value]) -> ir.Value:
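  """Builds a 1-D int32 shape tensor from a mix of Python ints and ir.Values.

  Static sizes become int32 constants; dynamic sizes (ir.Values) are converted
  to i32 if needed and reshaped to 1-D, then everything is concatenated. For
  example (a sketch, with d a scalar integer ir.Value), shape_tensor((2, d))
  produces the tensor [2, d] of type tensor<2xi32>.
  """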
  int1d = shape_dtype_to_ir_type((1,), np.int32)
  i32_type = shape_dtype_to_ir_type((), np.int32)
  def dim_to_i32x1(d):
    if type(d) is int:
      return hlo_const(np.array([d], dtype=np.int32))
    else:
      if d.type != i32_type:
        d = hlo.convert(i32_type, d)
      return hlo.reshape(int1d, d)
  ds = [dim_to_i32x1(sz) for sz in sizes]
  if not ds:
    return hlo_const(np.array([], np.int32))
  elif len(ds) == 1:
    return ds[0]
  else:
    return hlo.concatenate(
        ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0))

def hlo_const(x: np.ndarray) -> ir.Value:
  assert isinstance(x, np.ndarray)
  return hlo.constant(
      ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)))

def hlo_u8(x: int) -> ir.Value:
  return hlo_const(np.array(x, dtype=np.uint8))

def hlo_s32(x: int) -> ir.Value:
  return hlo_const(np.array(x, dtype=np.int32))

def ensure_hlo_s32(x: DimensionSize) -> ir.Value:
  return hlo_s32(x) if isinstance(x, int) else x

def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
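
# hlo_min and hlo_add do arithmetic on DimensionSizes: when both operands are
# static Python ints the result is folded to a plain int; otherwise the int
# operand is lifted with hlo_s32 and a stablehlo op is emitted. E.g.
# hlo_min(3, 4) == 3, while hlo_min(3, d) builds hlo.minimum(hlo_s32(3), d)
# for a scalar i32 ir.Value d.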

def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return min(x, y)
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.minimum(x, y)


def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return x + y
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.add(x, y)


# TODO(necula): this is identical to mlir.custom_call, but is meant for use
# in jaxlib. Find a way to share these implementations.
def custom_call(
    call_target_name: str,
    *,
    result_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    backend_config: str | bytes | dict[str, ir.Attribute] = "",
    has_side_effect: bool = False,
    result_shapes: Sequence[ir.Value] | None = None,
    called_computations: Sequence[str] = (),
    api_version: int = 2,
    operand_output_aliases: dict[int, int] | None = None,
    operand_layouts: Sequence[Sequence[int]] | None = None,
    result_layouts: Sequence[Sequence[int]] | None = None,
    extra_attributes: dict[str, ir.Attribute] | None = None,
) -> ir.Operation:
  """Helper function for building an hlo.CustomCall.

  Args:
    call_target_name: the name of the custom call target
    result_types: the MLIR types of the results of the custom call
    operands: the MLIR IR values that are arguments to the custom call
    backend_config: an opaque string or bytes passed to the custom call
      kernel, or a dict of attributes that is attached to the op as the
      unregistered attribute mhlo.backend_config
    has_side_effect: if True, marks the custom call as effectful
    result_shapes: tensors that represent the result shapes, to be used when
      the results have dynamic shapes. If not None, its length must match the
      number of results.
    called_computations: the list of function names called by the custom call.
    api_version: the ABI contract version of the custom call
    operand_output_aliases: a dict mapping operand numbers to outputs they alias
    operand_layouts: a sequence of layouts (dimension orders) for each operand
    result_layouts: a sequence of layouts (dimension orders) for each result
    extra_attributes: additional IR attributes to apply to the custom_call.
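
  Example: a minimal sketch, assuming an active MLIR context and insertion
  point, an ir.Value `x` of type tensor<2x3xf32>, and a hypothetical custom
  call target name "my_kernel":

    out_type = shape_dtype_to_ir_type((2, 3), np.float32)
    op = custom_call("my_kernel", result_types=[out_type], operands=[x])
    result = op.results[0]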
  """
  operands = list(operands)

  if backend_config is None:
    backend_config_attr = ir.StringAttr.get("")
  elif isinstance(backend_config, (str, bytes)):
    backend_config_attr = ir.StringAttr.get(backend_config)
  elif isinstance(backend_config, dict):
    # TODO(necula): it seems that the CustomCallOp constructor requires that
    # backend_config_attr be a string attribute, even though in some cases we
    # need it to be a DictAttr, e.g., for ApproxTopK on TPU.
    # "Verification failed: 'stablehlo.custom_call' op attribute 'backend_config' failed to satisfy constraint: string attribute"
    # To work around this limitation we first set it to the empty string and
    # use the unregistered attribute mhlo.backend_config to hold the DictAttr
    # (set at the end of this function). We must also use api_version=1 to
    # ensure that mhlo.backend_config is handled properly.
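    # As an illustration (hypothetical attribute name), a backend_config of
    #   {"top_k": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 8)}
    # ends up attached below as the unregistered mhlo.backend_config DictAttr.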
    backend_config_attr = ir.StringAttr.get("")
    api_version = 1
  else:
    raise ValueError(
        f"custom_call backend_config unexpected type: {backend_config}")
  attributes = dict(
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=backend_config_attr,
      api_version=ir.IntegerAttr.get(
          ir.IntegerType.get_signless(32), api_version),
      called_computations=ir.ArrayAttr.get(
          [ir.FlatSymbolRefAttr.get(name) for name in called_computations]
      ),
  )
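  # For example (an illustrative value), operand_output_aliases={0: 0}
  # declares that operand 0 shares its buffer with result 0; for a single
  # result, output_tuple_indices is left empty.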
  if operand_output_aliases is not None:
    attributes["output_operand_aliases"] = ir.ArrayAttr.get([
      hlo.OutputOperandAlias.get(
          # if len(result_types) == 1 then the aliasing refers implicitly to
          # the only output.
          output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
          operand_index=input_idx,
          operand_tuple_indices=[],
      )
      for input_idx, output_idx in operand_output_aliases.items()
    ])

  if extra_attributes is not None:
    attributes.update(extra_attributes)

  if result_shapes is not None:
    # We add the result_shapes at the end of the operands, and must pass
    # the indices_of_shape_operands attribute. This attribute is not yet
    # accepted by the CustomCall constructor, so we use build_generic.
    attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
        np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
                   dtype=np.int64))
    if operand_layouts is not None:
      assert len(operand_layouts) == len(operands), (operand_layouts, operands)
      operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
    operands = list(operands) + list(result_shapes)

  if operand_layouts is not None:
    attributes["operand_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in operand_layouts
    ])
  if result_layouts is not None:
    assert len(result_layouts) == len(result_types), (
        result_layouts, result_types)
    attributes["result_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in result_layouts
    ])

  op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
                                      attributes=attributes)
  if isinstance(backend_config, dict):
    backend_config_attr = ir.DictAttr.get(backend_config)
    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
  return op
