# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import tree_util
from jax._src import sharding_impls
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.util import safe_zip, safe_map

# Within this module, `map`/`zip` refer to the safe_* variants from
# jax._src.util; the builtins remain reachable as unsafe_map / unsafe_zip.
unsafe_map = map
map = safe_map
unsafe_zip = zip
zip = safe_zip

# EArray is an Array that can contain extended dtypes.
class EArray(basearray.Array):
  """An Array whose elements may have an extended dtype.

  Logical metadata (``shape``, ``dtype`` and the quantities derived from
  them) is read from ``aval``. Buffer and device state is delegated to
  ``_data``, the wrapped physical array, and ``sharding`` maps the physical
  array's sharding back to a logical one.
  """
  __slots__ = ['aval', '_data']
  __hash__ = None  # type: ignore[assignment]
  __array_priority__ = 100

  def __init__(self, aval, data):
    # ``aval`` is the logical abstract value; ``data`` is the physical array.
    self.aval, self._data = aval, data

  def block_until_ready(self):
    """Wait on the wrapped data, then return this EArray (not the data)."""
    self._data.block_until_ready()
    return self

  def copy_to_host_async(self):
    """Forward ``copy_to_host_async`` to the wrapped data; returns None."""
    self._data.copy_to_host_async()

  def copy(self):
    """Return a new EArray sharing ``aval`` and wrapping a copy of the data."""
    copied_data = self._data.copy()
    return EArray(self.aval, copied_data)

  def __repr__(self):
    # The wrapped data's repr, prefixed with 'E'.
    return f'E{self._data!r}'

  def __iter__(self):
    # Only the 0-d error path exists; iterating an n-d EArray is not
    # supported yet.
    if self.ndim != 0:
      raise NotImplementedError
    raise TypeError('iteration over a 0-d array')

  # forward to aval
  shape = property(lambda self: self.aval.shape)  # type: ignore[assignment]
  dtype = property(lambda self: self.aval.dtype)  # type: ignore[assignment]

  # computed from shape and dtype
  ndim = property(lambda self: len(self.aval.shape))  # type: ignore[assignment]
  size = property(lambda self: math.prod(self.aval.shape))  # type: ignore[assignment]
  itemsize = property(lambda self: self.aval.dtype.itemsize)  # type: ignore[assignment]

  def __len__(self):
    # Length is the leading dimension; a 0-d EArray has no length.
    if self.ndim != 0:
      return self.shape[0]
    raise TypeError('len() of unsized object')

  # forward to self._data
  devices = property(lambda self: self._data.devices)  # type: ignore[assignment]
  _committed = property(lambda self: self._data._committed)
  is_fully_addressable = property(lambda self: self._data.is_fully_addressable)  # type: ignore[assignment]
  is_fully_replicated = property(lambda self: self._data.is_fully_replicated)  # type: ignore[assignment]
  delete = property(lambda self: self._data.delete)  # type: ignore[assignment]
  is_deleted = property(lambda self: self._data.is_deleted)  # type: ignore[assignment]
  on_device_size_in_bytes = property(lambda self: self._data.on_device_size_in_bytes)  # type: ignore[assignment]
  unsafe_buffer_pointer = property(lambda self: self._data.unsafe_buffer_pointer)  # type: ignore[assignment]

  @property
  def sharding(self):
    """Logical sharding, derived from the physical data's sharding.

    The translation defers to the extended dtype's rules via
    ``sharding_impls.logical_sharding``.
    """
    return sharding_impls.logical_sharding(self.aval, self._data.sharding)

  # TODO(mattjj): not implemented below here, need more methods from ArrayImpl

  def addressable_data(self, index: int) -> EArray:
    raise NotImplementedError

  @property
  def addressable_shards(self):
    raise NotImplementedError

  @property
  def global_shards(self):
    raise NotImplementedError

# TODO(mattjj): _set_array_base_attributes

def _earray_shard_arg_handler(xs, shardings):
  """Shard EArray arguments by sharding their underlying physical arrays.

  Each logical sharding is converted with ``sharding_impls.physical_sharding``
  using the corresponding EArray's aval; the physical arrays and physical
  shardings are then passed to ``pxla.shard_args``.
  """
  physical_arrays = [earray._data for earray in xs]
  physical_shardings = [
      sharding_impls.physical_sharding(earray.aval, logical)
      for earray, logical in zip(xs, shardings)
  ]
  return pxla.shard_args(physical_shardings, physical_arrays)
# Sharding: EArrays are sharded through their physical data.
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

# Abstractification: an EArray's abstract value is the aval it carries.
api_util._shaped_abstractify_handlers[EArray] = lambda earray: earray.aval
core.pytype_aval_mappings[EArray] = lambda earray: earray.aval

# dtype canonicalization leaves an EArray unchanged.
xla.canonicalize_dtype_handlers[EArray] = lambda earray: earray

# Register with tree_util.dispatch_registry: the physical data is the single
# child and the aval is the auxiliary data used to rebuild the EArray.
tree_util.dispatch_registry.register_node(
    EArray,
    lambda earray: ((earray._data,), earray.aval),
    lambda aval, children: EArray(aval, children[0]),
)
