# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Jet is an experimental module for higher-order automatic differentiation
  that does not rely on repeated first-order automatic differentiation.

  How? Through the propagation of truncated Taylor polynomials.
  Consider a function :math:`f = g \circ h`, some point :math:`x`
  and some offset :math:`v`.
  First-order automatic differentiation (such as :func:`jax.jvp`)
  computes the pair :math:`(f(x), \partial f(x)[v])` from the pair
  :math:`(h(x), \partial h(x)[v])`.

  :func:`jet` implements the higher-order analogue:
  Given the tuple

  .. math::
    (h_0, ..., h_K) :=
    (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], ..., \partial^K h(x)[v,...,v]),

  which represents a :math:`K`-th order Taylor approximation
  of :math:`h` at :math:`x`, :func:`jet` returns a :math:`K`-th order
  Taylor approximation of :math:`f` at :math:`x`,

  .. math::
    (f_0, ..., f_K) :=
    (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], ..., \partial^K f(x)[v,...,v]).

  More specifically, :func:`jet` computes

  .. math::
    f_0, (f_1, ..., f_K) = \texttt{jet} (f, h_0, (h_1, ..., h_K))

  and can thus be used for high-order
  automatic differentiation of :math:`f`.
  Details are explained in
  `these notes <https://github.com/google/jax/files/6717197/jet.pdf>`__.

  Note:
    Help improve :func:`jet` by contributing
    `outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
"""

from typing import Any, Callable

from functools import partial

import numpy as np

from jax import lax
import jax.numpy as jnp
from jax.experimental import pjit
from jax.tree_util import (register_pytree_node, tree_structure,
                           treedef_is_leaf, tree_flatten, tree_unflatten,)

from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax._src.api_util import shaped_abstractify
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache


def jet(fun, primals, series):
  r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays or
      scalars (pytree arguments are not supported yet). It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun``
      should be evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    series: Higher order Taylor-series-coefficients.
      Together, `primals` and `series` make up a truncated Taylor polynomial.
      Should be either a tuple or a list of tuples or lists, with one entry per
      element of ``primals``. The (common) length of the entries dictates the
      degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
    and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    Both values have the Python tree structure of the output of ``fun``; in
    ``series_out`` every leaf holds the sequence of Taylor coefficients of the
    corresponding output.

  Raises:
    ValueError: if ``primals`` and ``series`` have different lengths, if
      ``series`` is empty, if the entries of ``series`` have inconsistent
      lengths, or if a primal or a series term is not an array (pytree leaf).

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
  >>> print(f0,  f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
  # Without this check the zips below would silently drop the unmatched
  # primals/series and the failure would only surface deep inside tracing.
  if len(primals) != len(series):
    raise ValueError(
        "jet requires exactly one series per primal, but got "
        f"{len(primals)} primals and {len(series)} series")
  if len(series) == 0:
    # The order is inferred from the series, so there must be at least one.
    raise ValueError(
        "jet requires at least one primal/series pair to determine the order")

  try:
    order, = set(map(len, series))
  except ValueError:
    msg = "jet terms have inconsistent lengths for different arguments"
    raise ValueError(msg) from None

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    treedef = tree_structure(x)
    if not treedef_is_leaf(treedef):
      raise ValueError(f"primal value at position {i} is not an array")
    for j, t in enumerate(terms):
      treedef = tree_structure(t)
      if not treedef_is_leaf(treedef):
        raise ValueError(f"term {j} for argument {i} is not an array")

  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    # Flatten the (possibly pytree) output of ``fun`` and expose its treedef
    # as auxiliary output so the results can be unflattened below.
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)

@lu.transformation
def jet_fun(order, primals, series):
  """Run the wrapped function under a fresh ``JetTrace`` main of given order.

  The wrapped function receives ``(main, primals, series)`` and must produce
  ``(out_primals, out_terms)``. Outputs whose terms are the symbolic
  ``zero_series`` are materialized as ``order`` zero arrays shaped like the
  corresponding primal output.
  """
  with core.new_main(JetTrace) as main:
    # The order is stashed on the main so that JetTrace.process_primitive can
    # read it back when it has to instantiate symbolic zero series.
    main.order = order
    out_primals, out_terms = yield (main, primals, series), {}
    del main
  # Replace symbolic zero series by explicit zero coefficients.
  out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s
               for p, s in zip(out_primals, out_terms)]
  yield out_primals, out_terms

@lu.transformation
def jet_subtrace(main, primals, series):
  """Wrap inputs in ``JetTracer``s, call the function, and unpack the outputs.

  Yields the tracers as positional arguments of the wrapped function and
  finally yields the pair ``(out_primals, out_terms)`` read off the output
  tracers.
  """
  trace = JetTrace(main, core.cur_sublevel())
  # Pair every primal with its series in a tracer that belongs to `trace`.
  in_tracers = map(partial(JetTracer, trace), primals, series)
  ans = yield in_tracers, {}
  # Raise every output to this trace so that .primal and .terms are available.
  out_tracers = map(trace.full_raise, ans)
  out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers)
  yield out_primals, out_terms

@lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series):
  """Adapt a ``(primals, series) -> (primals, series)`` function to flat args.

  The flat positional arguments are unflattened with ``in_tree_def`` before the
  call; the result pair is flattened again and its treedef is returned as the
  auxiliary output.
  """
  unflattened_inputs = tree_unflatten(in_tree_def, primals_and_series)
  primals_in, series_in = unflattened_inputs
  outputs = yield (primals_in, series_in), {}
  primals_out, series_out = outputs
  # tree_flatten returns (leaves, treedef), which is exactly (result, aux).
  yield tree_flatten((primals_out, series_out))


class JetTracer(core.Tracer):
  """Tracer carrying a primal value together with its Taylor coefficients.

  ``terms`` is either the symbolic ``zero_series`` or a list/tuple of
  coefficients, each of which may be the symbolic ``zero_term``.
  """
  __slots__ = ["primal", "terms"]

  def __init__(self, trace, primal, terms):
    assert type(terms) in (ZeroSeries, list, tuple)
    self._trace = trace
    self.primal = primal
    self.terms = terms

  @property
  def aval(self):
    # The abstract value of the tracer is that of its primal.
    return core.get_aval(self.primal)

  def full_lower(self):
    # A tracer only needs to stay a JetTracer while it carries at least one
    # coefficient that is not symbolically zero.
    has_nonzero_terms = (
        self.terms is not zero_series
        and not all(t is zero_term for t in self.terms))
    if has_nonzero_terms:
      return self
    return core.full_lower(self.primal)

class JetTrace(core.Trace):
  """Trace that propagates ``JetTracer``s through primitives via ``jet_rules``."""

  def pure(self, val):
    # A plain value carries no Taylor coefficients: use the symbolic zero series.
    return JetTracer(self, val, zero_series)

  def lift(self, val):
    # Same treatment as `pure`: the lifted value gets a symbolic zero series.
    return JetTracer(self, val, zero_series)

  def sublift(self, val):
    # Re-wrap the primal and terms of `val` in a tracer owned by this trace.
    return JetTracer(self, val.primal, val.terms)

  def process_primitive(self, primitive, tracers, params):
    """Apply the jet rule registered for ``primitive`` to the input tracers.

    Symbolic zeros (``zero_series`` / ``zero_term``) are instantiated as arrays
    of zeros before the rule is called. Raises ``KeyError`` if no rule is
    registered for ``primitive`` in ``jet_rules``.
    """
    # `order` is attached to the main by jet_fun.
    order = self.main.order              # pytype: disable=attribute-error
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    # Expand a symbolic zero series into `order` symbolic zero terms ...
    series_in = [[zero_term] * order if s is zero_series else s
                 for s in series_in]
    # ... and then replace every symbolic zero term by an explicit zero array
    # with the shape and result dtype of the matching primal.
    # TODO(mattjj): avoid always instantiating zeros
    series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x))
                  if t is zero_term else t for t in series]
                 for x, series in zip(primals_in, series_in)]
    rule = jet_rules[primitive]
    primal_out, terms_out = rule(primals_in, series_in, **params)
    if not primitive.multiple_results:
      return JetTracer(self, primal_out, terms_out)
    else:
      return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)]

  def process_call(self, call_primitive, f, tracers, params):
    """Bind ``call_primitive`` to a jet-transformed version of ``f``.

    Primals and series are flattened into one argument list, and the outputs
    are unflattened back into ``(primals_out, series_out)``.
    """
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
    f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
    # A call primitive may register a hook to adjust its params to the new
    # number of flat arguments.
    update_params = call_param_updaters.get(call_primitive)
    new_params = (update_params(params, len(primals_and_series))
                  if update_params else params)
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
    primals_out, series_out = tree_unflatten(out_tree_def(), result)
    return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Return the flat outputs and a callback that re-wraps them in tracers."""
    primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
    out, treedef = tree_flatten((primals, series))
    del primals, series
    main = self.main
    def todo(x):
      # Rebuild JetTracers from the flat values at the current sublevel.
      primals, series = tree_unflatten(treedef, x)
      trace = JetTrace(main, core.cur_sublevel())
      return map(partial(JetTracer, trace), primals, series)
    return out, todo

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
                              symbolic_zeros):
    # The custom JVP rule is ignored; the jet is computed by tracing `fun`.
    # TODO(mattjj): don't just ignore custom jvp rules?
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    # The custom VJP rules are ignored; the jet is computed by tracing `fun`.
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)


class ZeroTerm:
  """Type of the ``zero_term`` singleton, a symbolically-zero coefficient."""


zero_term = ZeroTerm()
# Registered as a pytree node without children, so flattening yields no leaves
# and unflattening always gives back the singleton.
register_pytree_node(
    ZeroTerm,
    lambda z: ((), None),
    lambda _, xs: zero_term,
)

class ZeroSeries:
  """Type of the ``zero_series`` singleton, a symbolically-zero series."""


zero_series = ZeroSeries()
# Registered as a pytree node without children, so flattening yields no leaves
# and unflattening always gives back the singleton.
register_pytree_node(
    ZeroSeries,
    lambda z: ((), None),
    lambda _, xs: zero_series,
)


# Optional per-call-primitive hooks. JetTrace.process_call looks the primitive
# up here and, if a hook is present, calls it as
# ``hook(params, num_flat_args)`` to obtain the params used for binding.
call_param_updaters: dict[core.Primitive, Callable[..., Any]] = {}


### rule definitions

# Maps a primitive to its jet rule. JetTrace.process_primitive calls a rule as
# ``rule(primals_in, series_in, **params)`` and expects
# ``(primal_out, series_out)`` back.
jet_rules: dict[core.Primitive, Callable[..., Any]] = {}

def defzero(prim):
  """Register ``prim`` as a primitive whose output series is all zeros."""
  zero_rule = partial(zero_prop, prim)
  jet_rules[prim] = zero_rule

def zero_prop(prim, primals_in, series_in, **params):
  """Jet rule that evaluates ``prim`` on the primals and drops the series.

  The input series are ignored; the symbolic ``zero_series`` is returned in
  place of output coefficients.
  """
  return prim.bind(*primals_in, **params), zero_series

# Primitives registered with the zero rule: comparisons, logical and bitwise
# operations, rounding-like operations, sign, stop_gradient, is_finite and
# bitcast. Their output series is the symbolic zero series.
defzero(lax.le_p)
defzero(lax.lt_p)
defzero(lax.gt_p)
defzero(lax.ge_p)
defzero(lax.eq_p)
defzero(lax.ne_p)
defzero(lax.not_p)
defzero(lax.and_p)
defzero(lax.or_p)
defzero(lax.xor_p)
defzero(lax.floor_p)
defzero(lax.ceil_p)
defzero(lax.round_p)
defzero(lax.sign_p)
defzero(ad_util.stop_gradient_p)
defzero(lax.is_finite_p)
defzero(lax.shift_left_p)
defzero(lax.shift_right_arithmetic_p)
defzero(lax.shift_right_logical_p)
defzero(lax.bitcast_convert_type_p)


def deflinear(prim):
  """Register ``prim`` as linear: it is applied to every coefficient as is."""
  linear_rule = partial(linear_prop, prim)
  jet_rules[prim] = linear_rule

def linear_prop(prim, primals_in, series_in, **params):
  """Jet rule for a linear primitive.

  ``prim`` is bound once to the primals and once per order to the coefficients
  of that order taken from every input series.
  """
  apply_prim = partial(prim.bind, **params)
  primal_out = apply_prim(*primals_in)
  # zip(*series_in) regroups the per-argument series into per-order tuples.
  series_out = [apply_prim(*same_order_terms)
                for same_order_terms in zip(*series_in)]
  return primal_out, series_out

# Primitives registered with the linear rule: the primitive itself is applied
# to the coefficients of every order.
deflinear(lax.neg_p)
deflinear(lax.real_p)
deflinear(lax.complex_p)
deflinear(lax.conj_p)
deflinear(lax.imag_p)
deflinear(lax.add_p)
deflinear(ad_util.add_jaxvals_p)
deflinear(lax.sub_p)
deflinear(lax.convert_element_type_p)
deflinear(lax.broadcast_in_dim_p)
deflinear(lax.concatenate_p)
deflinear(lax.pad_p)
deflinear(lax.reshape_p)
deflinear(lax.squeeze_p)
deflinear(lax.rev_p)
deflinear(lax.transpose_p)
deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax.fft_p)
deflinear(dispatch.device_put_p)

def _dynamic_slice_jet_rule(primals_in, series_in, **params):
  operand, *start_indices = primals_in
  primal_out = lax.dynamic_slice_p.bind(operand, *start_indices, **params)
  series_out = [lax.dynamic_slice_p.bind(terms_in[0], *start_indices, **params)
                for terms_in in zip(*series_in)]
  return primal_out, series_out

jet_rules[lax.dynamic_slice_p] = _dynamic_slice_jet_rule

def _dynamic_update_slice_jet_rule(primals_in, series_in, **params):
  operand, update, *start_indices = primals_in
  primal_out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
  series_out = [lax.dynamic_update_slice_p.bind(*terms_in[:2], *start_indices, **params)
                for terms_in in zip(*series_in)]
  return primal_out, series_out

jet_rules[lax.dynamic_update_slice_p] = _dynamic_update_slice_jet_rule

def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  """Jet rule for cumulative reductions defined by ``combine_fn``.

  The jet is taken through ``lax.associative_scan`` rather than through the
  cumulative primitive itself.
  """
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  prefix_scan = partial(lax.associative_scan, combine_fn, axis=axis,
                        reverse=reverse)
  return jet(prefix_scan, primals_in, series_in)

# cumsum is linear; the other cumulative reductions go through the scan rule.
deflinear(lax.cumsum_p)
jet_rules[lax.cumprod_p] = partial(_cumulative_jet_rule, combine_fn=lax.mul)
jet_rules[lax.cummax_p] = partial(_cumulative_jet_rule, combine_fn=lax.max)
jet_rules[lax.cummin_p] = partial(_cumulative_jet_rule, combine_fn=lax.min)


def def_deriv(prim, deriv):
  """Register a jet rule for ``prim`` built from its first derivative.

  ``deriv`` is a function computing the first derivative of the unary
  primitive ``prim``.
  """
  derivative_rule = partial(deriv_prop, prim, deriv)
  jet_rules[prim] = derivative_rule


def deriv_prop(prim, deriv, primals_in, series_in):
  """Jet rule for a unary primitive given its first derivative ``deriv``.

  The coefficients of ``deriv`` along the input are obtained with ``jet`` and
  combined with the input coefficients to build the output coefficients.
  """
  x, = primals_in
  series, = series_in
  primal_out = prim.bind(x)
  deriv_primal, deriv_series = jet(deriv, primals_in, series_in)
  deriv_coeffs = [deriv_primal] + deriv_series
  in_coeffs = [x] + series
  out_coeffs = [primal_out] + [None] * len(series)
  for order in range(1, len(out_coeffs)):
    out_coeffs[order] = fact(order-1) * sum(
        _scale(order, j) * deriv_coeffs[order-j] * in_coeffs[j]
        for j in range(1, order + 1))
  return out_coeffs[0], out_coeffs[1:]


# erf is registered through its first derivative,
# d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2).
def_deriv(lax.erf_p,
          lambda x: lax.mul(
              lax_internal._const(x, 2. / np.sqrt(np.pi)),
              lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
  """Register a jet rule for ``prim`` as the jet of the function ``comp``.

  ``comp`` expresses ``prim`` as a composition of simpler primitives that
  already have jet rules.
  """
  composed_rule = partial(jet, comp)
  jet_rules[prim] = composed_rule


def _rem_via_truncated_division(x, y):
  """Express ``lax.rem(x, y)`` as ``x - y * trunc(x / y)``.

  ``lax.rem`` takes the sign of the dividend, so the quotient has to be
  rounded towards zero. Rounding with ``floor`` instead would compute the
  floored modulo, which differs whenever ``x`` and ``y`` have opposite signs
  (e.g. ``rem(-5, 3) == -2`` but ``-5 - 3 * floor(-5 / 3) == 1``).
  """
  quotient = x / y
  # trunc(q) = sign(q) * floor(|q|); all three primitives have jet rules.
  truncated = lax.sign(quotient) * lax.floor(lax.abs(quotient))
  return x - y * truncated


# Primitives whose jet is obtained by taking the jet of an equivalent
# composition of primitives that already have rules.
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.rem_p, _rem_via_truncated_division)
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
  """Jet rule for ``erf_inv``.

  The derivative of ``erf_inv`` is written as a function of the output,
  ``sqrt(pi) / 2 * exp(y^2)``. Its coefficients ``c`` depend on the output
  coefficients ``v``, which in turn depend on ``c``, so both sequences are
  built up order by order.
  """
  x, = primals_in
  series, = series_in

  u = [x] + series  # input coefficients
  primal_out = lax.erf_inv(x)
  v = [primal_out] + [None] * len(series)  # output coefficients

  # derivative on co-domain for caching purposes
  deriv_const = np.sqrt(np.pi) / 2.
  deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

  # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

  # c: coefficients of the derivative; tmp_sq / tmp_exp: coefficients of the
  # intermediate values v^2 and exp(v^2).
  c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
  tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
  tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
  for k in range(1, len(series)):
    # we know c[:k], we compute c[k]

    # propagate c to get v
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

    # propagate v to get next c

    # square
    tmp_sq[k] = fact(k) * sum(_scale2(k, j) * v[k-j] * v[j] for j in range(k + 1))

    # exp
    tmp_exp[k] = fact(k-1) * sum(_scale(k, j) * tmp_exp[k-j] * tmp_sq[j] for j in range(1, k + 1))

    # const
    c[k] = deriv_const * tmp_exp[k]

  # we can't, and don't need, to compute c[k+1], just need to get the last v[k]
  k = len(series)
  v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.erf_inv_p] = _erf_inv_rule

### More complicated rules

def fact(n):
  """Return the factorial of ``n``, computed as ``exp(lgamma(n + 1))``."""
  log_factorial = lax.lgamma(n+1.)
  return lax.exp(log_factorial)

def _scale(k, j):
  """Return ``1 / ((k - j)! * (j - 1)!)``."""
  denominator = fact(k - j) * fact(j - 1)
  return 1. / denominator

def _scale2(k, j):
  """Return ``1 / ((k - j)! * j!)``."""
  denominator = fact(k - j) * fact(j)
  return 1. / denominator

def _exp_taylor(primals_in, series_in):
  """Jet rule for ``exp``.

  Each output coefficient is built from the lower-order output coefficients
  and the input coefficients.
  """
  x, = primals_in
  series, = series_in
  in_coeffs = [x] + series
  out_coeffs = [lax.exp(x)] + [None] * len(series)
  for order in range(1, len(out_coeffs)):
    out_coeffs[order] = fact(order-1) * sum(
        _scale(order, j) * out_coeffs[order-j] * in_coeffs[j]
        for j in range(1, order+1))
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.exp_p] = _exp_taylor

def _pow_taylor(primals_in, series_in):
  """Jet rule for ``pow``.

  The jet of ``exponent * log(base)`` is computed first and then pushed
  through the same recurrence as the ``exp`` rule, starting from
  ``base ** exponent``.
  """
  base, exponent = primals_in

  log_primal, log_series = jet(lambda x, y: lax.mul(y, lax.log(x)),
                               primals_in, series_in)

  in_coeffs = [log_primal] + log_series
  out_coeffs = [base ** exponent] + [None] * len(log_series)
  for order in range(1, len(out_coeffs)):
    out_coeffs[order] = fact(order-1) * sum(
        _scale(order, j) * out_coeffs[order-j] * in_coeffs[j]
        for j in range(1, order+1))

  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.pow_p] = _pow_taylor

def _integer_pow_taylor(primals_in, series_in, *, y):
  """Jet rule for ``integer_pow`` with static exponent ``y``.

  Exponents 0, 1 and 2 are delegated to the jet of an equivalent simple
  function. For other exponents the coefficients follow a recurrence that
  divides by ``x``; where ``x == 0`` the coefficients are set to 0.
  """
  if y == 0:
    return jet(jnp.ones_like, primals_in, series_in)
  if y == 1:
    return jet(lambda x: x, primals_in, series_in)
  if y == 2:
    return jet(lambda x: x * x, primals_in, series_in)

  x, = primals_in
  series, = series_in
  in_coeffs = [x] + series
  out_coeffs = [lax.integer_pow(x, y)] + [None] * len(series)
  for order in range(1, len(out_coeffs)):
    out_times_in = sum(
        _scale(order, j) * out_coeffs[order-j] * in_coeffs[j]
        for j in range(1, order + 1))
    in_times_out = sum(
        _scale(order, j) * in_coeffs[order-j] * out_coeffs[j]
        for j in range(1, order))
    out_coeffs[order] = jnp.where(
        x == 0, 0, fact(order-1) * (y * out_times_in - in_times_out) / x)

  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.integer_pow_p] = _integer_pow_taylor


def _logistic_taylor(primals_in, series_in):
  """Jet rule for ``logistic``.

  Alongside the output coefficients, the coefficients of the derivative
  ``sigmoid' = sigmoid * (1 - sigmoid)`` are tracked, since each sequence is
  needed to extend the other.
  """
  x, = primals_in
  series, = series_in
  in_coeffs = [x] + series
  out_coeffs = [lax.logistic(x)] + [None] * len(series)
  # terms for sigmoid' = sigmoid * (1 - sigmoid)
  deriv_coeffs = [out_coeffs[0] * (1 - out_coeffs[0])] + [None] * len(series)
  for order in range(1, len(out_coeffs)):
    out_coeffs[order] = fact(order-1) * sum(
        _scale(order, j) * deriv_coeffs[order-j] * in_coeffs[j]
        for j in range(1, order+1))
    deriv_coeffs[order] = (
        (1 - out_coeffs[0]) * out_coeffs[order]
        - fact(order) * sum(
            _scale2(order, j) * out_coeffs[j] * out_coeffs[order-j]
            for j in range(1, order+1)))

  return out_coeffs[0], out_coeffs[1:]

jet_rules[lax.logistic_p] = _logistic_taylor


def _tanh_taylor(primals_in, series_in):
  """Jet rule for ``tanh`` using ``tanh(x) = 2 * logistic(2 * x) - 1``."""
  x, = primals_in
  series, = series_in
  # Scale the whole input polynomial by 2 before applying the logistic rule.
  doubled_primal = 2*x
  doubled_series = [2 * term for term in series]
  logistic_primal, logistic_series = _logistic_taylor(
      (doubled_primal, ), (doubled_series, ))
  series_out = [2 * term for term in logistic_series]
  return 2 * logistic_primal - 1, series_out
jet_rules[lax.tanh_p] = _tanh_taylor

def _log_taylor(primals_in, series_in):
  """Jet rule for ``log``.

  Each output coefficient is the input coefficient of the same order,
  corrected by a convolution of lower-order terms and divided by the primal.
  """
  x, = primals_in
  series, = series_in
  in_coeffs = [x] + series
  out_coeffs = [lax.log(x)] + [None] * len(series)
  for order in range(1, len(out_coeffs)):
    lower_order = sum(
        _scale(order, j) * out_coeffs[j] * in_coeffs[order-j]
        for j in range(1, order))
    out_coeffs[order] = (
        (in_coeffs[order] - fact(order - 1) * lower_order) / in_coeffs[0])
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.log_p] = _log_taylor

def _atan2_taylor(primals_in, series_in):
  """Jet rule for ``atan2``.

  The jet of the quotient of the two arguments is computed first, then the
  jet of ``1 / (1 + z^2)`` along that quotient; both are combined into the
  output coefficients. The primal output comes from ``lax.atan2`` directly.
  """
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  ratio, ratio_series = jet(lax.div, primals_in, series_in)
  one = lax_internal._const(ratio, 1)
  deriv_primal, deriv_series = jet(
      lambda x: lax.div(one, 1 + lax.square(x)), (ratio, ), (ratio_series, ))
  deriv_coeffs = [deriv_primal] + deriv_series
  in_coeffs = [ratio] + ratio_series
  out_coeffs = [primal_out] + [None] * len(ratio_series)
  for order in range(1, len(out_coeffs)):
    out_coeffs[order] = fact(order-1) * sum(
        _scale(order, j) * deriv_coeffs[order-j] * in_coeffs[j]
        for j in range(1, order + 1))
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.atan2_p] = _atan2_taylor

def _div_taylor_rule(primals_in, series_in):
  """Jet rule for ``div``.

  Every quotient coefficient (including order 0, the primal) is the numerator
  coefficient minus a convolution of lower-order quotient coefficients with
  the denominator coefficients, divided by the denominator primal.
  """
  x, y = primals_in
  x_terms, y_terms = series_in
  numer = [x] + x_terms
  denom = [y] + y_terms
  quotient = [None] * len(numer)

  def inverse_factorials(k, j):
    return 1. / (fact(k - j) * fact(j))

  for order in range(0, len(quotient)):
    lower_order = sum(
        inverse_factorials(order, j) * quotient[j] * denom[order-j]
        for j in range(0, order))
    quotient[order] = (numer[order] - fact(order) * lower_order) / denom[0]
  return quotient[0], quotient[1:]
jet_rules[lax.div_p] = _div_taylor_rule

def _sinusoidal_rule(sign, prims, primals_in, series_in):
  """Joint jet rule for a (sine-like, cosine-like) pair of functions.

  ``prims`` is the pair of functions. The two coefficient sequences are built
  together because each one feeds the other; the cosine-like coefficients are
  multiplied by ``sign``.

  Returns:
    ``((s_primal, s_series), (c_primal, c_series))``.
  """
  x, = primals_in
  series, = series_in
  in_coeffs = [x] + series
  sin_like, cos_like = prims
  s_coeffs = [sin_like(x)] + [None] * len(series)
  c_coeffs = [cos_like(x)] + [None] * len(series)
  for order in range(1, len(s_coeffs)):
    s_coeffs[order] = fact(order-1) * sum(
        _scale(order, j) * in_coeffs[j] * c_coeffs[order-j]
        for j in range(1, order + 1))
    c_coeffs[order] = fact(order-1) * sum(
        _scale(order, j) * in_coeffs[j] * s_coeffs[order-j]
        for j in range(1, order + 1)) * sign
  return (s_coeffs[0], s_coeffs[1:]), (c_coeffs[0], c_coeffs[1:])

def _get_ind(f, ind):
  return lambda *args: f(*args)[ind]

# _sinusoidal_rule returns the results for both functions of a pair; index 0
# selects the sine-like one and index 1 the cosine-like one. The sign is -1
# for (sin, cos) and +1 for (sinh, cosh).
jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
jet_rules[lax.sinh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 0)
jet_rules[lax.cosh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 1)

def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
  """Jet rule for a bilinear primitive ``prim`` of two arguments.

  The coefficient of each order (including order 0, the primal) is a weighted
  sum of ``prim`` applied to all pairs of input coefficients whose orders add
  up to that order.
  """
  x, y = primals_in
  x_terms, y_terms = series_in
  lhs = [x] + x_terms
  rhs = [y] + y_terms
  out_coeffs = [None] * len(lhs)
  apply_prim = partial(prim.bind, **params)

  def inverse_factorials(k, j):
    return 1. / (fact(k - j) * fact(j))

  for order in range(0, len(out_coeffs)):
    out_coeffs[order] = fact(order) * sum(
        inverse_factorials(order, j) * apply_prim(lhs[j], rhs[order-j])
        for j in range(0, order+1))
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.dot_general_p] = partial(_bilinear_taylor_rule, lax.dot_general_p)
jet_rules[lax.mul_p] = partial(_bilinear_taylor_rule, lax.mul_p)
jet_rules[lax.conv_general_dilated_p] = partial(_bilinear_taylor_rule, lax.conv_general_dilated_p)

def _gather_taylor_rule(primals_in, series_in, **params):
  """Jet rule for ``gather``.

  The operand and each of its coefficients are gathered at the primal
  indices; the series of the indices are not used.
  """
  operand, start_indices = primals_in
  operand_terms, _ = series_in

  def gather_from(array):
    return lax.gather_p.bind(array, start_indices, **params)

  primal_out = gather_from(operand)
  series_out = [gather_from(term) for term in operand_terms]
  return primal_out, series_out
jet_rules[lax.gather_p] = _gather_taylor_rule

def _gen_reduce_choose_taylor_rule(chooser_fun):
  """Build the jet rule for a choosing reduction (``reduce_max``/``reduce_min``).

  Every output coefficient is the average of the input coefficient over the
  positions at which the operand equals the chosen value.
  """
  def chooser_taylor_rule(primals_in, series_in, **params):
    operand, = primals_in
    gs, = series_in
    primal_out = chooser_fun(operand, **params)
    axes = params.pop("axes", None)
    primal_dtype = gs[0].dtype
    # Shape of the result with the reduced axes kept as size-1 dimensions.
    keepdims_shape = [1 if axis in axes else size
                      for axis, size in enumerate(operand.shape)]
    # 1 where the operand attains the chosen value, 0 elsewhere.
    chosen_mask = lax_internal._eq_meet(
        operand, lax.reshape(primal_out, keepdims_shape))
    location_indicators = lax.convert_element_type(chosen_mask, primal_dtype)
    counts = lax_internal._reduce_sum(location_indicators, axes)

    def average_over_chosen(g):
      selected_total = lax_internal._reduce_sum(
          lax.mul(g, location_indicators), axes)
      return lax.div(selected_total, counts)

    series_out = [average_over_chosen(g) for g in gs]
    return primal_out, series_out
  return chooser_taylor_rule
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_min)

def _abs_taylor_rule(x, series_in, **params):
  x, = x
  zero = lax.full_like(x, 0, shape=())
  primal_out = lax.abs_p.bind(x, **params)
  negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
  fix_sign = lambda y: negs * y
  series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
  return primal_out, series_out
jet_rules[lax.abs_p] = _abs_taylor_rule

def _select_n_taylor_rule(primal_in, series_in, **params):
  b, *cases = primal_in
  primal_out = lax.select_n(b, *cases)
  sel = lambda _, *xs: lax.select_n(b, *xs)
  series_out = [sel(*terms_in) for terms_in in zip(*series_in)]
  return primal_out, series_out
jet_rules[lax.select_n_p] = _select_n_taylor_rule

def _lax_max_taylor_rule(primal_in, series_in):
    x, y = jnp.broadcast_arrays(*primal_in)

    xgy = x > y   # greater than mask
    xey = x == y  # equal to mask
    primal_out = lax.select(xgy, x, y)

    def select_max_and_avg_eq(x_i, y_i):
        """Select x where x>y or average when x==y"""
        max_i = lax.select(xgy, x_i, y_i)
        max_i = lax.select(xey, (x_i + y_i)/2, max_i)
        return max_i

    series_out = [select_max_and_avg_eq(*jnp.broadcast_arrays(*terms_in)) for terms_in in zip(*series_in)]
    return primal_out, series_out
jet_rules[lax.max_p] = _lax_max_taylor_rule

def _lax_min_taylor_rule(primal_in, series_in):
    x, y = primal_in
    xgy = x < y   # less than mask
    xey = x == y  # equal to mask
    primal_out = lax.select(xgy, x, y)

    def select_min_and_avg_eq(x_i, y_i):
        """Select x where x>y or average when x==y"""
        min_i = lax.select(xgy, x_i, y_i)
        min_i = lax.select(xey, (x_i + y_i)/2, min_i)
        return min_i

    series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
    return primal_out, series_out
jet_rules[lax.min_p] = _lax_min_taylor_rule

def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts,
                      dimension_numbers, indices_are_sorted, unique_indices,
                      mode):
  """Jet rule for ``scatter_add``.

  The same scatter is applied to the primals and, for every order, to the
  coefficients of the operand and of the updates, always at the primal
  indices. The series of the indices are not used.
  """
  scatter_add = partial(
      lax.scatter_add_p.bind,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices,
      mode=mode)
  operand, scatter_indices, updates = primals_in
  primal_out = scatter_add(operand, scatter_indices, updates)
  series_out = [scatter_add(operand_term, scatter_indices, updates_term)
                for operand_term, _, updates_term in zip(*series_in)]
  return primal_out, series_out
jet_rules[lax.scatter_add_p] = _scatter_add_rule


@weakref_lru_cache
def _jet_jaxpr(
    jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
) -> tuple[core.ClosedJaxpr, Any]:
  """Trace the jet of ``jaxpr`` into a new closed jaxpr.

  Returns the jet-transformed closed jaxpr, which takes the flattened primals
  and series as inputs, together with the thunk giving the tree structure of
  its outputs. Results are cached by ``weakref_lru_cache``.
  """
  wrapped = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  jetted = jet_fun(jet_subtrace(wrapped), order)
  flat_jetted, out_tree_def = traceable(jetted, in_tree_def)
  traced = pe.trace_to_jaxpr_dynamic(flat_jetted, primals_and_series_avals)
  jaxpr_jet, _, consts = traced
  return core.ClosedJaxpr(jaxpr_jet, consts), out_tree_def


def _pjit_jet_rule(primals_in, series_in, **params):
  """Jet rule for ``pjit``.

  The inner jaxpr is replaced by its jet, which takes the flattened primals
  and series as inputs. The sharding and donation parameters are extended
  with ``UNSPECIFIED`` / ``False`` entries for the added series inputs and
  outputs.
  """
  flat_args, in_tree_def = tree_flatten((primals_in, series_in))
  # The order is read off the first argument's series.
  order = len(series_in[0])
  flat_avals = tuple(shaped_abstractify(arg) for arg in flat_args)
  jaxpr_jet, out_tree_def = _jet_jaxpr(params['jaxpr'], order,
                                       flat_avals, in_tree_def)
  num_series_in = len(primals_in) * order
  num_series_out = len(params['out_shardings']) * order
  extra_in_shardings = (sharding_impls.UNSPECIFIED,) * num_series_in
  extra_out_shardings = (sharding_impls.UNSPECIFIED,) * num_series_out
  extra_donated = (False,) * num_series_in

  new_params = dict(params)
  new_params.update({
      'jaxpr': jaxpr_jet,
      'in_shardings': params['in_shardings'] + extra_in_shardings,
      'out_shardings': params['out_shardings'] + extra_out_shardings,
      'donated_invars': params['donated_invars'] + extra_donated,
  })
  result = pjit.pjit_p.bind(*flat_args, **new_params)
  return tree_unflatten(out_tree_def(), result)

jet_rules[pjit.pjit_p] = _pjit_jet_rule
