jax_md.interpolate module

Utilities for constructing various interpolating functions.

This code was adapted from the way learning rate schedules are built in JAX.
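As a rough illustration of the schedule pattern mentioned above, the sketch below treats a schedule as a function of one argument (for example, time or step) and shows how a plain scalar could be promoted to a constant schedule. The names `constant_schedule` and `as_schedule` are hypothetical and chosen for this example only; this is a pure-Python sketch of the general idea, not the implementation of the functions listed in this module.

```python
def constant_schedule(value):
    """Return a schedule that ignores its argument and always yields `value`."""
    def schedule(unused_t):
        return value
    return schedule


def as_schedule(scalar_or_schedule_fun):
    """Pass callables through unchanged; wrap anything else as a constant schedule.

    Hypothetical helper for illustration: no validation of the scalar is done.
    """
    if callable(scalar_or_schedule_fun):
        return scalar_or_schedule_fun
    return constant_schedule(scalar_or_schedule_fun)


# A scalar and a function can now be used interchangeably by calling code.
fixed = as_schedule(2.0)
ramp = as_schedule(lambda t: 0.1 * t)
```

With this pattern, downstream code can always call `schedule(t)` without checking whether the user supplied a number or a function.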

jax_md.interpolate.canonicalize(scalar_or_schedule_fun)
jax_md.interpolate.constant(f)
jax_md.interpolate.spline(y, dx, degree=3)

Spline fit a given scalar function.

Parameters:
  • y – The values of the scalar function evaluated on points starting at zero with the interval dx.

  • dx – The interval at which the scalar function is evaluated.

  • degree – Polynomial degree of the spline fit.

Returns:

A function that computes the spline function.
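To make the input convention concrete, the sketch below builds a degree-1 (piecewise-linear) analogue in pure Python: `y[i]` is taken to be the function value at `x = i * dx`, and the returned callable evaluates the fit at an arbitrary point. The name `linear_spline` is hypothetical, and this is not how `jax_md.interpolate.spline` is implemented; it only illustrates the uniform grid starting at zero and the "returns a function" shape of the API.

```python
import math


def linear_spline(y, dx):
    """Degree-1 analogue of a spline fit on the grid x_i = i * dx, i = 0..len(y)-1."""
    n = len(y)
    if n < 2:
        raise ValueError("need at least two samples")

    def fn(x):
        # Pick the interval [i*dx, (i+1)*dx] that contains x; the index is
        # clamped to the grid, so points outside it are extrapolated linearly
        # from the first or last interval.
        i = min(max(int(math.floor(x / dx)), 0), n - 2)
        t = x / dx - i  # fractional position within the chosen interval
        return (1.0 - t) * y[i] + t * y[i + 1]

    return fn


dx = 0.5
# Samples of x**2 at x = 0.0, 0.5, 1.0, 1.5, 2.0.
y = [(i * dx) ** 2 for i in range(5)]
f = linear_spline(y, dx)
```

Here `f(1.0)` reproduces the sample `y[2]`, while `f(0.75)` is the average of `y[1]` and `y[2]`. A higher `degree` replaces the straight segments with smoother polynomial pieces.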