# Source code for jax_md.interpolate

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for constructing various interpolating functions.

This code was adapted from the way learning rate schedules are built in JAX.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
from scipy.interpolate import splrep, PPoly

from jax_md import util


# Typing

f32 = util.f32
f64 = util.f64

#


def constant(f):
  """Builds a schedule whose value never changes.

  Args:
    f: The value the schedule should produce.

  Returns:
    A one-argument callable that ignores its argument and returns `f`.
  """
  def schedule(unused_t):
    del unused_t  # The schedule is independent of its argument.
    return f
  return schedule
def canonicalize(scalar_or_schedule_fun):
  """Normalizes a scalar or a schedule function into a schedule function.

  Args:
    scalar_or_schedule_fun: Either a callable, which is passed through
      unchanged, or a zero-dimensional value, which is wrapped with
      `constant`.

  Returns:
    A callable schedule.

  Raises:
    TypeError: If the argument is neither callable nor zero-dimensional. The
      exception's argument is the offending value's type.
  """
  if callable(scalar_or_schedule_fun):
    return scalar_or_schedule_fun
  if np.ndim(scalar_or_schedule_fun) != 0:
    raise TypeError(type(scalar_or_schedule_fun))
  return constant(scalar_or_schedule_fun)
def spline(y, dx, degree=3):
  """Spline fit a given scalar function.

  Args:
    y: The values of the scalar function evaluated on points starting at zero
      with the interval dx.
    dx: The interval at which the scalar function is evaluated.
    degree: Polynomial degree of the spline fit.

  Returns:
    A function that computes the spline function. Inputs with
    `x / dx >= len(y)` evaluate to zero; inputs below zero are extrapolated
    from the first polynomial piece.
  """
  num_points = len(y)
  dx = f32(dx)
  grid = np.arange(num_points, dtype=f32) * dx
  # Create a spline fit using the scipy function.
  fn = splrep(grid, y, s=0, k=degree)  # Turn off smoothing by setting s to zero.
  params = PPoly.from_spline(fn)
  # `params.c[:, j]` holds the Taylor coefficients (highest power first) of
  # the spline expanded about breakpoint `params.x[j]`. The breakpoints are
  # the knot vector from splrep. That vector repeats each end point
  # `degree + 1` times, and its interior knots are not one-per-sample. So the
  # piece containing a query must be found by searching the breakpoints, not
  # by adding a fixed offset to `x / dx`.
  coeffs = np.array(params.c)
  breakpoints = np.array(params.x)
  # Columns `degree` .. `num_points - 1` are the non-degenerate (non-zero
  # width) intervals. There are `num_points + degree + 1` knots, with
  # `degree + 1` copies at each end.
  first_interval = degree
  last_interval = num_points - 1

  def spline_fn(x):
    """Evaluates the spline fit for values of x."""
    x = np.asarray(x)
    ind = np.array(x / dx, dtype=np.int64)
    # Locate the polynomial piece containing each x. Queries outside the
    # fitted domain are clamped to the first/last piece.
    interval = np.searchsorted(breakpoints, x, side='right') - 1
    interval = np.clip(interval, first_interval, last_interval)
    dX = x - np.array(breakpoints[interval], x.dtype)
    # Horner's rule. Row 0 of `coeffs` multiplies the highest power of dX.
    result = np.array(coeffs[0, interval], x.dtype)
    for i in range(1, degree + 1):
      result = result * dX + np.array(coeffs[i, interval], x.dtype)
    # For x values beyond the end of the sampled domain, return zeros.
    result = np.where(ind < num_points, result, np.array(0.0, x.dtype))
    return result

  return spline_fn