# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.lax.lax import (
  DotDimensionNumbers as DotDimensionNumbers,
  Precision as Precision,
  PrecisionLike as PrecisionLike,
  RandomAlgorithm as RandomAlgorithm,
  RoundingMethod as RoundingMethod,
  abs as abs,
  abs_p as abs_p,
  acos as acos,
  acos_p as acos_p,
  acosh as acosh,
  acosh_p as acosh_p,
  add as add,
  add_p as add_p,
  after_all as after_all,
  after_all_p as after_all_p,
  and_p as and_p,
  argmax as argmax,
  argmax_p as argmax_p,
  argmin as argmin,
  argmin_p as argmin_p,
  asin as asin,
  asin_p as asin_p,
  asinh as asinh,
  asinh_p as asinh_p,
  atan as atan,
  atan_p as atan_p,
  atan2 as atan2,
  atan2_p as atan2_p,
  atanh as atanh,
  atanh_p as atanh_p,
  batch_matmul as batch_matmul,
  bitcast_convert_type as bitcast_convert_type,
  bitcast_convert_type_p as bitcast_convert_type_p,
  bitwise_and as bitwise_and,
  bitwise_not as bitwise_not,
  bitwise_or as bitwise_or,
  bitwise_xor as bitwise_xor,
  broadcast as broadcast,
  broadcast_in_dim as broadcast_in_dim,
  broadcast_in_dim_p as broadcast_in_dim_p,
  broadcast_shapes as broadcast_shapes,
  broadcast_to_rank as broadcast_to_rank,
  broadcasted_iota as broadcasted_iota,
  cbrt as cbrt,
  cbrt_p as cbrt_p,
  ceil as ceil,
  ceil_p as ceil_p,
  clamp as clamp,
  clamp_p as clamp_p,
  clz as clz,
  clz_p as clz_p,
  collapse as collapse,
  complex as complex,
  complex_p as complex_p,
  concatenate as concatenate,
  concatenate_p as concatenate_p,
  conj as conj,
  conj_p as conj_p,
  convert_element_type as convert_element_type,
  convert_element_type_p as convert_element_type_p,
  copy_p as copy_p,
  cos as cos,
  cos_p as cos_p,
  cosh as cosh,
  cosh_p as cosh_p,
  create_token as create_token,
  create_token_p as create_token_p,
  div as div,
  div_p as div_p,
  dot as dot,
  dot_general as dot_general,
  dot_general_p as dot_general_p,
  dtype as dtype,
  eq as eq,
  eq_p as eq_p,
  eq_to_p as eq_to_p,
  exp as exp,
  exp_p as exp_p,
  exp2 as exp2,
  exp2_p as exp2_p,
  expand_dims as expand_dims,
  expm1 as expm1,
  expm1_p as expm1_p,
  floor as floor,
  floor_p as floor_p,
  full as full,
  full_like as full_like,
  ge as ge,
  ge_p as ge_p,
  gt as gt,
  gt_p as gt_p,
  imag as imag,
  imag_p as imag_p,
  infeed as infeed,
  infeed_p as infeed_p,
  integer_pow as integer_pow,
  integer_pow_p as integer_pow_p,
  iota as iota,
  iota_p as iota_p,
  is_finite as is_finite,
  is_finite_p as is_finite_p,
  le as le,
  le_p as le_p,
  le_to_p as le_to_p,
  log as log,
  log1p as log1p,
  log1p_p as log1p_p,
  log_p as log_p,
  logistic as logistic,
  logistic_p as logistic_p,
  lt as lt,
  lt_p as lt_p,
  lt_to_p as lt_to_p,
  max as max,
  max_p as max_p,
  min as min,
  min_p as min_p,
  mul as mul,
  mul_p as mul_p,
  ne as ne,
  ne_p as ne_p,
  neg as neg,
  neg_p as neg_p,
  nextafter as nextafter,
  nextafter_p as nextafter_p,
  not_p as not_p,
  or_p as or_p,
  outfeed as outfeed,
  outfeed_p as outfeed_p,
  pad as pad,
  pad_p as pad_p,
  padtype_to_pads as padtype_to_pads,
  population_count as population_count,
  population_count_p as population_count_p,
  pow as pow,
  pow_p as pow_p,
  ragged_dot as ragged_dot,
  real as real,
  real_p as real_p,
  reciprocal as reciprocal,
  reduce as reduce,
  reduce_and_p as reduce_and_p,
  reduce_max_p as reduce_max_p,
  reduce_min_p as reduce_min_p,
  reduce_or_p as reduce_or_p,
  reduce_p as reduce_p,
  reduce_precision as reduce_precision,
  reduce_precision_p as reduce_precision_p,
  reduce_prod_p as reduce_prod_p,
  reduce_sum_p as reduce_sum_p,
  reduce_xor_p as reduce_xor_p,
  rem as rem,
  rem_p as rem_p,
  reshape as reshape,
  reshape_p as reshape_p,
  rev as rev,
  rev_p as rev_p,
  rng_bit_generator as rng_bit_generator,
  rng_bit_generator_p as rng_bit_generator_p,
  rng_uniform as rng_uniform,
  rng_uniform_p as rng_uniform_p,
  round as round,
  round_p as round_p,
  rsqrt as rsqrt,
  rsqrt_p as rsqrt_p,
  select as select,
  select_n as select_n,
  select_n_p as select_n_p,
  shift_left as shift_left,
  shift_left_p as shift_left_p,
  shift_right_arithmetic as shift_right_arithmetic,
  shift_right_arithmetic_p as shift_right_arithmetic_p,
  shift_right_logical as shift_right_logical,
  shift_right_logical_p as shift_right_logical_p,
  sign as sign,
  sign_p as sign_p,
  sin as sin,
  sin_p as sin_p,
  sinh as sinh,
  sinh_p as sinh_p,
  sort as sort,
  sort_key_val as sort_key_val,
  sort_p as sort_p,
  sqrt as sqrt,
  sqrt_p as sqrt_p,
  square as square,
  squeeze as squeeze,
  squeeze_p as squeeze_p,
  stop_gradient as stop_gradient,
  sub as sub,
  sub_p as sub_p,
  tan as tan,
  tan_p as tan_p,
  tanh as tanh,
  tanh_p as tanh_p,
  top_k as top_k,
  top_k_p as top_k_p,
  transpose as transpose,
  transpose_p as transpose_p,
  xor_p as xor_p,
  zeros_like_array as zeros_like_array,
)
from jax._src.lax.special import (
  bessel_i0e as bessel_i0e,
  bessel_i0e_p as bessel_i0e_p,
  bessel_i1e as bessel_i1e,
  bessel_i1e_p as bessel_i1e_p,
  betainc as betainc,
  digamma as digamma,
  digamma_p as digamma_p,
  erf as erf,
  erfc as erfc,
  erfc_p as erfc_p,
  erf_inv as erf_inv,
  erf_inv_p as erf_inv_p,
  erf_p as erf_p,
  igamma as igamma,
  igammac as igammac,
  igammac_p as igammac_p,
  igamma_grad_a as igamma_grad_a,
  igamma_grad_a_p as igamma_grad_a_p,
  igamma_p as igamma_p,
  lgamma as lgamma,
  lgamma_p as lgamma_p,
  polygamma as polygamma,
  polygamma_p as polygamma_p,
  random_gamma_grad as random_gamma_grad,
  random_gamma_grad_p as random_gamma_grad_p,
  regularized_incomplete_beta_p as regularized_incomplete_beta_p,
  zeta as zeta,
  zeta_p as zeta_p,
)
from jax._src.lax.slicing import (
  GatherDimensionNumbers as GatherDimensionNumbers,
  GatherScatterMode as GatherScatterMode,
  ScatterDimensionNumbers as ScatterDimensionNumbers,
  dynamic_index_in_dim as dynamic_index_in_dim,
  dynamic_slice as dynamic_slice,
  dynamic_slice_in_dim as dynamic_slice_in_dim,
  dynamic_slice_p as dynamic_slice_p,
  dynamic_update_index_in_dim as dynamic_update_index_in_dim,
  dynamic_update_slice as dynamic_update_slice,
  dynamic_update_slice_in_dim as dynamic_update_slice_in_dim,
  dynamic_update_slice_p as dynamic_update_slice_p,
  gather as gather,
  gather_p as gather_p,
  index_in_dim as index_in_dim,
  index_take as index_take,
  scatter as scatter,
  scatter_apply as scatter_apply,
  scatter_add as scatter_add,
  scatter_add_p as scatter_add_p,
  scatter_max as scatter_max,
  scatter_max_p as scatter_max_p,
  scatter_min as scatter_min,
  scatter_min_p as scatter_min_p,
  scatter_mul as scatter_mul,
  scatter_mul_p as scatter_mul_p,
  scatter_p as scatter_p,
  slice as slice,
  slice_in_dim as slice_in_dim,
  slice_p as slice_p,
)
from jax._src.lax.convolution import (
  ConvDimensionNumbers as ConvDimensionNumbers,
  ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers,
  conv as conv,
  conv_dimension_numbers as conv_dimension_numbers,
  conv_general_dilated as conv_general_dilated,
  conv_general_dilated_p as conv_general_dilated_p,
  conv_general_permutations as conv_general_permutations,
  conv_general_shape_tuple as conv_general_shape_tuple,
  conv_shape_tuple as conv_shape_tuple,
  conv_transpose as conv_transpose,
  conv_transpose_shape_tuple as conv_transpose_shape_tuple,
  conv_with_general_padding as conv_with_general_padding,
)
from jax._src.lax.windowed_reductions import (
  reduce_window as reduce_window,
  reduce_window_max_p as reduce_window_max_p,
  reduce_window_min_p as reduce_window_min_p,
  reduce_window_p as reduce_window_p,
  reduce_window_shape_tuple as reduce_window_shape_tuple,
  reduce_window_sum_p as reduce_window_sum_p,
  select_and_gather_add_p as select_and_gather_add_p,
  select_and_scatter_p as select_and_scatter_p,
  select_and_scatter_add_p as select_and_scatter_add_p,
)
from jax._src.lax.control_flow import (
  associative_scan as associative_scan,
  cond as cond,
  cond_p as cond_p,
  cumlogsumexp as cumlogsumexp,
  cumlogsumexp_p as cumlogsumexp_p,
  cummax as cummax,
  cummax_p as cummax_p,
  cummin as cummin,
  cummin_p as cummin_p,
  cumprod as cumprod,
  cumprod_p as cumprod_p,
  cumsum as cumsum,
  cumsum_p as cumsum_p,
  custom_linear_solve as custom_linear_solve,
  custom_root as custom_root,
  fori_loop as fori_loop,
  linear_solve_p as linear_solve_p,
  map as map,
  scan as scan,
  scan_bind as scan_bind,
  scan_p as scan_p,
  switch as switch,
  while_loop as while_loop,
  while_p as while_p,
  platform_dependent as platform_dependent,
)
from jax._src.lax.fft import (
  fft as fft,
  fft_p as fft_p,
)
from jax._src.lax.parallel import (
  all_gather as all_gather,
  all_gather_p as all_gather_p,
  all_to_all as all_to_all,
  all_to_all_p as all_to_all_p,
  axis_index as axis_index,
  axis_index_p as axis_index_p,
  pbroadcast as pbroadcast,
  pmax as pmax,
  pmax_p as pmax_p,
  pmean as pmean,
  pmin as pmin,
  pmin_p as pmin_p,
  ppermute as ppermute,
  ppermute_p as ppermute_p,
  pshuffle as pshuffle,
  psum as psum,
  psum_p as psum_p,
  psum_scatter as psum_scatter,
  pswapaxes as pswapaxes,
  pdot as pdot,
  xeinsum as xeinsum,
)
from jax._src.lax.other import (
  conv_general_dilated_local as conv_general_dilated_local,
  conv_general_dilated_patches as conv_general_dilated_patches
)
from jax._src.lax.ann import (
  approx_max_k as approx_max_k,
  approx_min_k as approx_min_k,
  approx_top_k_p as approx_top_k_p
)
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
from jax.lax import linalg as linalg

from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
from jax._src.dispatch import device_put_p as device_put_p


_deprecations = {
  # Finalized 2024-05-13; remove after 2024-08-13
  "tie_in": (
    "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. "
    "Replace z = tie_in(x, y) with z = y.", None,
  ),
}

from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
