# Source code for jax_md.interpolate

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for constructing various interpolating functions.

This code was adapted from the way learning rate schedules are built in JAX.
"""

import jax.numpy as np
from scipy.interpolate import splrep, PPoly

from jax_md import util


# Typing aliases taken from jax_md.util.
f32, f64 = util.f32, util.f64


def constant(f):
  """Builds a schedule that always yields the same value.

  Args:
    f: The value to return on every call.

  Returns:
    A callable taking a single (ignored) argument and returning `f`.
  """
  return lambda unused_t: f
def canonicalize(scalar_or_schedule_fun):
  """Coerces a scalar or a schedule function into a schedule function.

  Args:
    scalar_or_schedule_fun: Either a callable, which is returned unchanged, or
      a zero-dimensional value, which is wrapped with `constant`.

  Returns:
    A callable schedule.

  Raises:
    TypeError: If the argument is neither callable nor zero-dimensional. The
      exception's argument is the offending value's type.
  """
  value = scalar_or_schedule_fun
  if callable(value):
    return value
  if np.ndim(value) != 0:
    raise TypeError(type(value))
  return constant(value)
def spline(y, dx, degree=3):
  """Spline fit a given scalar function.

  Args:
    y: The values of the scalar function evaluated on points starting at zero
      with the interval dx.
    dx: The interval at which the scalar function is evaluated.
    degree: Polynomial degree of the spline fit.

  Returns:
    A function that computes the spline function. It returns zero for inputs
    with `x / dx >= len(y)`. Inputs below zero, or between the last sample and
    `len(y) * dx`, are evaluated by extrapolating the nearest boundary
    polynomial of the fit.
  """
  num_points = len(y)
  dx = f32(dx)
  x = np.arange(num_points, dtype=f32) * dx
  # Create an interpolating spline fit using the scipy function (setting s to
  # zero turns off smoothing).
  fn = splrep(x, y, s=0, k=degree)
  params = PPoly.from_spline(fn)
  # `params.c[:, j]` holds the polynomial coefficients (highest power first)
  # on the interval [params.x[j], params.x[j + 1]], expanded about params.x[j].
  coeffs = np.array(params.c)
  breakpoints = np.array(params.x)
  # splrep places `degree + 1` coincident knots at each end of the data, so the
  # first and last `degree` intervals of the piecewise polynomial have zero
  # width. The interior knots are generally not one per sample either (e.g. a
  # cubic interpolating fit uses not-a-knot end conditions), so an interval
  # cannot be found from `x / dx` plus a fixed offset; search the breakpoints.
  first_interval = degree
  last_interval = len(params.x) - degree - 2

  def spline_fn(x):
    """Evaluates the spline fit for values of x."""
    x = np.asarray(x)
    interval = np.searchsorted(breakpoints, x, side='right') - 1
    # Restrict to intervals of non-zero width; values outside the fitted range
    # use the first or last polynomial piece.
    interval = np.clip(interval, first_interval, last_interval)
    dX = x - np.array(breakpoints[interval], x.dtype)
    result = np.array(0, x.dtype)
    for i in range(degree + 1):  # Sum over the polynomial terms up to degree.
      result = result + np.array(
          coeffs[degree - i, interval], x.dtype) * dX ** np.array(i, x.dtype)
    # For x values beyond the domain of the spline fit, return zeros. This is
    # the same condition as `int(x / dx) < num_points` but avoids an integer
    # cast that can overflow for large x.
    return np.where(x / dx < num_points, result, np.array(0.0, x.dtype))

  return spline_fn