# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to simulate rigid bodies in two- and three-dimensions.
This code contains a number of pieces that handle different parts of
rigid body simulations.
To start with, we include some quaternion utilities for representing oriented
bodies in three dimensions.
Rigid body simulations are split into two components.
1) The state of a rigid body is represented by a dataclass containing a
center-of-mass position and an orientation. Along with this type
representation, the core simulation functions are overloaded to automatically
allow deterministic NVE and NVT simulations to work with state composed of
RigidBody objects (see `simulation.py` for details). If you need any other
simulation environments, please raise a github issue.
One subtlety of the type system that we use here is that a host of related
quantities are represented by RigidBody objects. For example, the momentum
is represented by a RigidBody containing the linear momentum and angular
momentum, while the mass is a RigidBody containing the total mass and the
moment of inertia. This allows us to naturally use JAX's tree_map utilities
to jointly map over the different related quantities. Additionally, forces
inherit the RigidBody type with a center-of-mass force and torque.
2) Interactions between rigid bodies are specified. This is largely responsible
for dictating the shape of the rigid body. While arbitrary interactions are
possible, we include utility functions for producing rigid bodies that are
made by the union of point-like particles. This captures many common models
of rigid molecules and colloids. These functions work by providing a
RigidPointUnion object that specifies the location of point particles in the
body frame along with a pointwise energy function. This approach works with or
without neighbor lists and yields a function that computes the total energy on
a system of rigid bodies.
"""
from typing import Optional, Tuple, Any, Union, Callable
from absl import logging
import numpy as onp
import jax
from jax import vmap
from jax import ops
from jax import random
import jax.numpy as jnp
from jax_md import dataclasses, util, space, partition, quantity, simulate
from functools import partial, reduce
from jax.tree_util import tree_map, tree_reduce
import operator
# Type aliases used in annotations throughout this module.
DType = Any
Array = util.Array
PyTree = Any
f64 = util.f64
f32 = util.f32
# `jax.random.KeyArray` was deprecated and later removed from JAX. Fall back
# to the generic array alias so that importing this module does not fail on
# releases where the attribute no longer exists.
KeyArray = getattr(random, 'KeyArray', Array)
NeighborListFns = partition.NeighborListFns
ShiftFn = space.ShiftFn
"""Quaternion Utilities.
The quaternion utilities are divided into private helper functions and public
functions. The public versions of the function take `Quaternion` objects which
help enforce type safety. The private version of the functions take raw arrays,
but use JAX's vectorize utilities to automatically vectorize over any number
of additional dimensions.
To compute derivatives of quaternions as the tangent space of S^3 we follow the
perspective outlined in
New Langevin and Gradient Thermostats for Rigid Body Dynamics
R. L. Davidchack, T. E. Ouldridge, and M. V. Tretyakov
J. Chem. Phys. 142, 144114 (2015)
TODO: We should make sure that all publicly exposed quaternion operations have
correct derivatives, at some point.
"""
def _is_float64(x: Array) -> bool:
return x.dtype in [jnp.float64, onp.float64]
@partial(jnp.vectorize, signature='(q),(q)->(q)')
def _quaternion_multiply(lhs: Array, rhs: Array) -> Array:
  """Computes the quaternion product `lhs * rhs` for `(w, x, y, z)` arrays.

  The result is float64 if either operand is float64 and float32 otherwise.
  `jnp.vectorize` broadcasts the product over any leading dimensions.
  """
  a_w, a_x, a_y, a_z = lhs
  b_w, b_x, b_y, b_z = rhs

  promote = _is_float64(lhs) or _is_float64(rhs)
  out_dtype = f64 if promote else f32

  real = -a_x * b_x - a_y * b_y - a_z * b_z + a_w * b_w
  imag_i = a_x * b_w + a_y * b_z - a_z * b_y + a_w * b_x
  imag_j = -a_x * b_z + a_y * b_w + a_z * b_x + a_w * b_y
  imag_k = a_x * b_y - a_y * b_x + a_z * b_w + a_w * b_z
  return jnp.array([real, imag_i, imag_j, imag_k], dtype=out_dtype)
@partial(jnp.vectorize, signature='(q)->(q)')
def _quaternion_conjugate(q: Array) -> Array:
w, x, y, z = q
return jnp.array([w, -x, -y, -z], dtype=q.dtype)
def _quaternion_rotate_raw(q: Array, v: Array) -> Array:
  """Rotates a single 3-vector by a single quaternion.

  The vector is embedded as a quaternion with zero scalar part, conjugated by
  `q` (i.e. `q * v * conj(q)`), and the vector part of the result is returned.

  Raises:
    ValueError: if `q` does not have shape `(4,)` or `v` does not have shape
      `(3,)`.
  """
  if q.shape != (4,):
    raise ValueError('quaternion_rotate expects quaternion to have '
                     f'4-dimensions. Found {q.shape}.')
  if v.shape != (3,):
    raise ValueError('quaternion_rotate expects vector to have '
                     f'three-dimensions. Found {v.shape}.')

  zero_scalar = jnp.zeros((1,), v.dtype)
  pure = jnp.concatenate([zero_scalar, v])
  inner = _quaternion_multiply(pure, _quaternion_conjugate(q))
  rotated = _quaternion_multiply(q, inner)
  return rotated[1:]
@jax.custom_vjp
def _quaternion_rotate(q: Array, v: Array) -> Array:
  """Rotates `v` by `q`, with a custom reverse-mode derivative.

  The value is the same as `_quaternion_rotate_raw`; only the VJP rule
  (registered via `defvjp` below) differs.
  """
  return _quaternion_rotate_raw(q, v)


def _quaternion_rotate_fwd(q: Array, v: Array
                           ) -> Tuple[Array, Tuple[Array, Array]]:
  """Forward pass: returns the rotated vector and `(q, v)` as residuals."""
  return _quaternion_rotate(q, v), (q, v)


def _quaternion_rotate_bwd(res, g: Array) -> Tuple[Array, Array]:
  """Backward pass: VJP of the raw rotation with the `q` cotangent projected.

  Args:
    res: The `(q, v)` residuals saved by the forward pass.
    g: The cotangent of the rotated vector.

  Returns:
    A pair `(dq, dv)` of cotangents for `q` and `v`.
  """
  q, v = res
  _, vjp_fn = jax.vjp(_quaternion_rotate_raw, q, v)
  dq, dv = vjp_fn(g)
  # Subtract the component of `dq` along `q`, leaving the part orthogonal to
  # `q` (the tangent-space view described in the section header above).
  return dq - (q @ dq) * q, dv


_quaternion_rotate.defvjp(_quaternion_rotate_fwd, _quaternion_rotate_bwd)
def _random_quaternion(key: KeyArray, dtype: DType) -> Array:
"""Generate a random quaternion of a given dtype."""
rnd = random.uniform(key, (3,), minval=0.0, maxval=1.0, dtype=dtype)
r1 = jnp.sqrt(1.0 - rnd[0])
r2 = jnp.sqrt(rnd[0])
pi2 = jnp.pi * 2.0
t1 = pi2 * rnd[1]
t2 = pi2 * rnd[2]
return jnp.array([jnp.cos(t2) * r2,
jnp.sin(t1) * r1,
jnp.cos(t1) * r1,
jnp.sin(t2) * r2],
dtype)
@dataclasses.dataclass
class Quaternion:
  """An object representing a quaternion.

  Data is stored in a vector array, but this class exposes several
  convenience features including quaternion multiplication and conjugation. It
  also changes the size property to return the number of degrees of freedom
  of the quaternion (which is three since we expect the quaternion to be
  normalized).

  Attributes:
    vec: An array containing the underlying jax.numpy representation. The
      last axis holds the four quaternion components; any leading axes are
      batch dimensions.
  """
  vec: Array

  @property
  def size(self) -> int:
    """Number of degrees of freedom: three per quaternion in the batch."""
    return 3 * reduce(operator.mul, self.vec.shape[:-1], 1)

  @property
  def ndim(self) -> int:
    """Number of array dimensions of the underlying `vec`."""
    return self.vec.ndim

  def conj(self) -> 'Quaternion':
    """Returns the conjugate quaternion."""
    return Quaternion(_quaternion_conjugate(self.vec))

  def __mul__(self, qp: 'Quaternion') -> 'Quaternion':
    """Quaternion product `self * qp`."""
    return Quaternion(_quaternion_multiply(self.vec, qp.vec))

  def __rmul__(self, qp: 'Quaternion') -> 'Quaternion':
    """Quaternion product `qp * self`."""
    return Quaternion(_quaternion_multiply(qp.vec, self.vec))

  def __getitem__(self, idx) -> 'Quaternion':
    """Indexes into the batch dimensions, returning a `Quaternion`.

    Raises:
      ValueError: if `vec` is one-dimensional, since the only axis left to
        index would be the component axis.
    """
    if self.vec.ndim == 1:
      # NOTE: This only guards the unbatched case. An index that reaches the
      # component axis of a batched array (e.g. a multi-axis index) is not
      # caught here.
      raise ValueError('Quaternions do not support indexing into their '
                       'spatial dimension. If you want this behavior then '
                       'access the underying `vec` attribute directly.')
    return Quaternion(self.vec[idx])
def quaternion_rotate(q: Quaternion, v: Array) -> Array:
  """Rotates vectors `v` by quaternions `q`, broadcasting over batch axes."""
  batched_rotate = jnp.vectorize(_quaternion_rotate,
                                 signature='(q),(d)->(d)')
  return batched_rotate(q.vec, v)
def random_quaternion(key: KeyArray, dtype: DType) -> Quaternion:
  """Samples random quaternions, one for each key along the leading axes."""
  sample_one = partial(_random_quaternion, dtype=dtype)
  sample_batch = jnp.vectorize(sample_one, signature='(k)->(q)')
  return Quaternion(sample_batch(key))
def tree_map_no_quat(fn: Callable[..., Any], tree: Any, *rest: Any):
  """Like `tree_map`, but does not descend into `Quaternion` nodes."""
  def stop_at_quaternion(node) -> bool:
    return isinstance(node, Quaternion)

  return tree_map(fn, tree, *rest, is_leaf=stop_at_quaternion)
"""Rigid body simulation functions.
This section contains classes and functions to simulate rigid bodies and to
compute various observables for simulations of rigid bodies. The structure
of the code is as follows:
1) We have the rigid body dataclass which contains the data necessary to
describe the state of rigid bodies during a simulation.
2) We have a number of helper functions to transform between the world
space reference frame and the body reference frame along with corresponding
transformations for angular momentum.
3) We have code to compute various physical observables of rigid body
collections. These are analogous to several functions in the `quantity.py` file
namely: `kinetic_energy` and `temperature`.
4) Finally, we have functions that override the single-dispatch functions in
`simulate.py` that allow NVE and NVT simulations to work with rigid bodies.
"""
@dataclasses.dataclass
class RigidBody:
  """A pair of center-of-mass and orientational data for a body.

  `RigidBody` is used as a container for several related quantities, each of
  which has a center-of-mass part (`center`) and an orientational part
  (`orientation`):

    * the instantaneous state: center-of-mass position and orientation;
    * the momentum: linear momentum and angular momentum;
    * the force: linear force and torque;
    * the mass: total mass and moment of inertia.

  When used with automatic differentiation or the simulation environments,
  forces and velocities are therefore also of type `RigidBody`, and their
  `orientation` field should be read as torque and angular momentum
  respectively.

  Attributes:
    center: An array of two- or three-dimensional positions giving the center
      position of the body.
    orientation: In two-dimensions this will be an array of angles. In three-
      dimensions this will be a set of quaternions.
  """
  center: Array
  orientation: Union[Array, Quaternion]

  def __getitem__(self, idx):
    """Applies the same index to both `center` and `orientation`."""
    picked_center = self.center[idx]
    picked_orientation = self.orientation[idx]
    return RigidBody(picked_center, picked_orientation)


util.register_custom_simulation_type(RigidBody)
@partial(jnp.vectorize, signature='(d)->(k,k)')
def _space_to_body_rotation(q: Array) -> Array:
q2 = q ** 2
w, x, y, z = q
w2, x2, y2, z2 = q2
return jnp.array([
[w2 + x2 - y2 - z2, 2 * (x * y + w * z), 2 * (x * z - w * y)],
[2 * (x * y - w * z), w2 - x2 + y2 - z2, 2 * (y * z + w * x)],
[2 * (x * z + w * y), 2 * (y * z - w * x), w2 - x2 - y2 + z2]
], q.dtype)
def space_to_body_rotation(q: Quaternion) -> Array:
  """Returns the matrix taking world-space vectors to the body frame."""
  components = q.vec
  return _space_to_body_rotation(components)
@partial(jnp.vectorize, signature='(d)->(d,d)')
def _S(q: Array) -> Array:
return jnp.array([
[q[0], -q[1], -q[2], -q[3]],
[q[1], q[0], -q[3], q[2]],
[q[2], q[3], q[0], -q[1]],
[q[3], -q[2], q[1], q[0]]
], q.dtype)
def S(q: Quaternion) -> Array:
  r"""From Miller III et al., S(q) is defined so that \dot q = 1/2S(q)\omega.

  Thus S(q) is the linear map that relates time derivatives of quaternions to
  angular velocities.

  The docstring is a raw string so that the backslashes in the formula are
  not parsed as (invalid) escape sequences.

  Args:
    q: The quaternion(s) at which to evaluate S.

  Returns:
    An array whose last two axes are the 4x4 matrix S(q), with one matrix for
    each quaternion in `q.vec`.
  """
  return _S(q.vec)
def conjugate_momentum_to_angular_momentum(orientation: Quaternion,
                                           momentum: Quaternion
                                           ) -> Array:
  r"""Convert from the conjugate momentum of a quaternion to angular momentum.

  Simulations involving quaternions typically proceed by integrating Hamilton's
  equations with an extended Hamiltonian,

  .. math::
    H(p, q) = 1/8 p^T S(q) D S(q)^T p + \phi(q)

  where q is the orientation and p is the conjugate momentum variable.

  Note (!!) unlike in problems involving only positional degrees of freedom, it
  is not the case here that dq/dt = p / m. The conjugate momentum is defined
  only by the Legendre transformation.

  This means that you cannot compute the angular velocity by simply
  transforming the conjugate momentum as you would the time-derivative of q.
  Compare, for example equation (2.13) and (2.15) in [1].

  [1] Symplectic quaternion scheme for biophysical molecular dynamics
      Miller, Eleftheriou, Pattnaik, Ndirango, Newns, and Martyna
      J. Chem. Phys. 116 20 (2002)

  Args:
    orientation: The orientation quaternion(s) q.
    momentum: The conjugate momentum p, stored as a `Quaternion`.

  Returns:
    An array holding the last three components of `0.5 * S(q)^T p` for each
    quaternion.
  """
  # NOTE: Here we are stripping the zeroth component of the angular moment.
  # however, it would be good to add a test that this is explicitly zero.
  @partial(jnp.vectorize, signature='(d),(d)->(k)')
  def wrapped_fn(q: Array, m: Array) -> Array:
    return (0.5 * _S(q).T @ m)[1:]
  return wrapped_fn(orientation.vec, momentum.vec)
def angular_momentum_to_conjugate_momentum(orientation: Quaternion,
                                           omega: Array
                                           ) -> Quaternion:
  """Transforms angular momentum vector to a conjugate momentum quaternion.

  Each 3-vector in `omega` is padded with a leading zero and multiplied by
  `2 * S(q)` for the matching orientation `q`.
  """
  @partial(jnp.vectorize, signature='(d),(k)->(d)')
  def to_conjugate(quat: Array, angular: Array) -> Array:
    leading_zero = jnp.zeros((1,), dtype=quat.dtype)
    padded = jnp.concatenate((leading_zero, angular))
    return 2 * _S(quat) @ padded

  conjugate = to_conjugate(orientation.vec, omega)
  return Quaternion(conjugate)
"""Quantity Functions.
These functions are analogues of functions in `quantity.py` except that they
work with RigidBody objects rather than linear positions / velocities.
"""
def canonicalize_momentum(position: RigidBody, momentum: RigidBody
                          ) -> RigidBody:
  """Returns `momentum` with its orientational part as angular momentum.

  If the body orientation is a `Quaternion`, the orientational momentum is
  converted from conjugate momentum to angular momentum; otherwise it is
  passed through unchanged. The linear momentum is never modified.
  """
  angular = momentum.orientation
  if isinstance(position.orientation, Quaternion):
    angular = conjugate_momentum_to_angular_momentum(position.orientation,
                                                     angular)
  return RigidBody(momentum.center, angular)
def kinetic_energy(position: RigidBody, momentum: RigidBody, mass: RigidBody
                   ) -> float:
  """Computes the kinetic energy of a system with some momenta.

  The momentum is first canonicalized; then `0.5 * sum(p**2 / m)` is
  evaluated for the center and orientation parts and the two are added.
  """
  def half_p2_over_m(m, p):
    return 0.5 * util.high_precision_sum(p**2 / m)

  canonical = canonicalize_momentum(position, momentum)
  per_part = tree_map(half_p2_over_m, mass, canonical)
  return tree_reduce(operator.add, per_part, 0.0)
def temperature(position: RigidBody, momentum: RigidBody, mass: RigidBody
                ) -> float:
  """Computes the temperature of a system with some momenta.

  The degrees of freedom are counted on the momentum as passed in (before
  canonicalization); `sum(p**2 / m) / dof` is then evaluated for the center
  and orientation parts of the canonicalized momentum and added.
  """
  dof = quantity.count_dof(momentum)

  def p2_over_m_per_dof(m, p):
    return util.high_precision_sum(p**2 / m) / dof

  canonical = canonicalize_momentum(position, momentum)
  per_part = tree_map(p2_over_m_per_dof, mass, canonical)
  return tree_reduce(operator.add, per_part, 0.0)
def get_moment_of_inertia_diagonal(I: Array, eps=1e-5):
  """Returns the diagonals of a batch of moment of inertia tensors.

  Args:
    I: A batch of square moment of inertia tensors (leading axis is the
      batch).
    eps: Largest absolute off-diagonal entry that is still accepted.

  Returns:
    An array with the diagonal of each tensor in the batch.

  Raises:
    ValueError: if any off-diagonal entry exceeds `eps` in absolute value.
      The check is skipped (and a message is logged) when the values are not
      concrete, e.g. inside of a JIT.
  """
  I_diag = vmap(jnp.diag)(I)

  # NOTE: Here epsilon has to take into account numerical error from
  # diagonalization. Maybe there's a more systematic way to figure this
  # out. It might also be worth always trying to diagonalize at float64.

  # NOTE: This will not work if the moment of inertia tensor is not known
  # ahead of a JIT. Maybe worth removing this check and relying on the fact
  # that helper functions always diagonalize the moment of inertia.
  try:
    # Rebuilding diagonal matrices from `I_diag` and subtracting leaves only
    # the off-diagonal entries of `I`.
    deviation = jnp.abs(I - vmap(jnp.diag)(I_diag))
    if jnp.any(deviation > eps):
      max_dev = jnp.max(deviation)
      raise ValueError('Expected diagonal moment of inertia. '
                       f'Maximum deviation: {max_dev}. Tolerance: {eps}.')
  except jax.errors.ConcretizationTypeError:
    logging.info('Skipping moment of inertia diagonalization check inside of '
                 'JIT. Make sure your moment of inertia is diagonal.')
  return I_diag
"""Simulation Single Dispatch Extension Functions.
This code overrides the core simulation functions in `simulate.py` to allow
simulations to work with RigidBody objects. See `simulate.py` for a detailed
description of the use of single dispatch in simulation functions.
These functions are based on Miller III et al [1], which uses the
Suzuki-Trotter decomposition to identify a factorization of the Liouville
operator for Rigid Body motion. This factorization is compatible with either
the NVE or NVT ensemble (but is not compatible with NPT).
"""
@quantity.count_dof.register
def _(position: RigidBody) -> int:
  """Counts degrees of freedom by summing `.size` over center and orientation.

  Quaternions are treated as leaves, so their own `size` property is used.
  """
  def leaf_size(leaf):
    return leaf.size

  per_leaf = tree_map_no_quat(leaf_size, position)
  return tree_reduce(operator.add, per_leaf, 0)
@simulate.initialize_momenta.register(RigidBody)
def _(state, key: Array, kT: float):
  """Draws random linear and orientational momenta for a rigid body state.

  Linear momenta are Gaussian with scale `sqrt(mass.center * kT)` and have
  their mean over axis 0 subtracted. Orientational momenta are Gaussian with
  scale `sqrt(mass.orientation * kT)`; for quaternion orientations the sample
  is drawn as an angular momentum and then converted to a conjugate momentum.
  """
  body, mass = state.position, state.mass
  center_key, angular_key = random.split(key)

  center = body.center
  center_noise = random.normal(center_key, center.shape, dtype=center.dtype)
  P_center = jnp.sqrt(mass.center * kT) * center_noise
  P_center = P_center - jnp.mean(P_center, axis=0, keepdims=True)

  # At the moment we assume that rigid body objects are either 2d or 3d. At
  # some point it might be worth expanding this definition to include other
  # kinds of oriented bodies.
  orientation = body.orientation
  scale = jnp.sqrt(mass.orientation * kT)
  if isinstance(orientation, Quaternion):
    P_angular = scale * random.normal(angular_key,
                                      center.shape,
                                      dtype=center.dtype)
    P_orientation = angular_momentum_to_conjugate_momentum(orientation,
                                                           P_angular)
  else:
    P_orientation = scale * random.normal(angular_key,
                                          orientation.shape,
                                          dtype=orientation.dtype)

  return state.set(momentum=RigidBody(P_center, P_orientation))
class EmptyLeaf:
pass
_EMPTY_LEAF = EmptyLeaf()
def split_center_and_orientation(dc):
  """Splits a tree containing `RigidBody` nodes into three parallel trees.

  Returns:
    A tuple `(rest, center, orientation)`. `center` and `orientation` hold the
    corresponding field at each `RigidBody` position and `_EMPTY_LEAF`
    elsewhere; `rest` holds the non-`RigidBody` leaves and `_EMPTY_LEAF` at
    each `RigidBody` position.
  """
  def is_rigid(node) -> bool:
    return isinstance(node, RigidBody)

  def project(on_rigid, on_other):
    def pick(node):
      return on_rigid(node) if is_rigid(node) else on_other(node)
    return tree_map(pick, dc, is_leaf=is_rigid)

  def empty(_):
    return _EMPTY_LEAF

  center = project(lambda body: body.center, empty)
  orientation = project(lambda body: body.orientation, empty)
  rest = project(empty, lambda leaf: leaf)
  return rest, center, orientation
def merge_center_and_orientation(rest, center, orientation):
  """Inverse of `split_center_and_orientation`.

  Wherever `rest` holds the `_EMPTY_LEAF` marker a `RigidBody` is rebuilt from
  the matching `center` and `orientation` entries; other leaves of `rest` are
  kept as they are.
  """
  def rebuild(leftover, c, o):
    return RigidBody(c, o) if leftover is _EMPTY_LEAF else leftover

  return tree_map_no_quat(rebuild, rest, center, orientation)
def _signed_permutation_1(q):
  """Maps (q0, q1, q2, q3) to (-q1, q0, q3, -q2)."""
  return jnp.array([-q[1], q[0], q[3], -q[2]])


def _signed_permutation_2(q):
  """Maps (q0, q1, q2, q3) to (-q2, -q3, q0, q1)."""
  return jnp.array([-q[2], -q[3], q[0], q[1]])


def _signed_permutation_3(q):
  """Maps (q0, q1, q2, q3) to (-q3, q2, -q1, q0)."""
  return jnp.array([-q[3], q[2], -q[1], q[0]])


# Three signed permutations of the quaternion components, each mapped over a
# leading batch axis.
MOMENTUM_PERMUTATION = [
    vmap(_signed_permutation_1),
    vmap(_signed_permutation_2),
    vmap(_signed_permutation_3),
]
def _rigid_body_3d_position_step(state, shift_fn: ShiftFn, dt, m_rot: int,
                                 **kwargs):
  """A symplectic update function for 3d rigid bodies.

  The state is split into its center-of-mass and orientation parts. Centers
  are advanced with `simulate.position_step`; orientations (quaternions and
  their conjugate momenta) are advanced by a symmetric sequence of free-rotor
  sub-steps that is repeated `m_rot` times.

  Args:
    state: A simulation state whose `position` is a `RigidBody` with a
      `Quaternion` orientation.
    shift_fn: Shift function used for the center-of-mass update.
    dt: The time step.
    m_rot: Number of repetitions of the rotational sub-step sequence.
    **kwargs: Forwarded to `simulate.position_step` for the centers.

  Returns:
    The state with updated positions and orientational momenta.
  """
  def free_rotor(k, dt, quat, p_quat, M):
    # Each call advances by `dt / m_rot`; `orientation_update` repeats the
    # whole sequence `m_rot` times.
    delta = dt / m_rot
    P = MOMENTUM_PERMUTATION
    # Pick the k-th moment of inertia: shared across bodies (1d) or one row
    # per body (2d, kept as a column so it broadcasts against `quat`).
    if M.ndim == 1:
      Mk = M[k]
    elif M.ndim == 2:
      Mk = M[:, [k]]
    else:
      raise NotImplementedError()
    # Per-body rotation angle for this sub-step.
    zeta = delta * jnp.einsum('ij,ij->i', p_quat, P[k](quat) / (4 * Mk))
    zeta = zeta[:, None]
    # Rotate both the quaternion and its momentum by `zeta` towards their
    # k-th signed permutation.
    quat = jnp.cos(zeta) * quat + jnp.sin(zeta) * P[k](quat)
    p_quat = jnp.cos(zeta) * p_quat + jnp.sin(zeta) * P[k](p_quat)
    return quat, p_quat

  def orientation_update(state, dt, **kwargs):
    dt_2 = dt / 2

    R = state.position
    P = state.momentum

    if not (isinstance(R, Quaternion) and isinstance(P, Quaternion)):
      raise ValueError('For 3d rigid bodies, orientations must be quaternions.'
                       f'Found {type(R)} for positions and {type(P)} for '
                       'momenta.')

    R = R.vec
    P = P.vec
    M = state.mass

    # Symmetric splitting: half steps about permutations 2 and 1, a full step
    # about 0, then half steps about 1 and 2.
    for _ in range(m_rot):
      R, P = free_rotor(2, dt_2, R, P, M)
      R, P = free_rotor(1, dt_2, R, P, M)
      R, P = free_rotor(0, dt, R, P, M)
      R, P = free_rotor(1, dt_2, R, P, M)
      R, P = free_rotor(2, dt_2, R, P, M)

    return state.set(position=Quaternion(R), momentum=Quaternion(P))

  rest, center, orientation = split_center_and_orientation(state)
  center = simulate.position_step(center, shift_fn, dt, **kwargs)
  orientation = orientation_update(orientation, dt, **kwargs)
  return merge_center_and_orientation(rest, center, orientation)
def _rigid_body_2d_position_step(state, shift_fn: ShiftFn, dt,
                                 **kwargs):
  """A symplectic update function for 2d rigid bodies.

  Centers are advanced with `shift_fn`; orientations are advanced with a
  plain additive shift. Both go through `simulate.position_step`.
  """
  def additive_shift(r, dr, **_):
    return r + dr

  rest, center, orientation = split_center_and_orientation(state)
  new_center = simulate.position_step(center, shift_fn, dt, **kwargs)
  new_orientation = simulate.position_step(orientation, additive_shift, dt,
                                           **kwargs)
  return merge_center_and_orientation(rest, new_center, new_orientation)
@simulate.position_step.register(RigidBody)
def _(state, shift_fn, dt, m_rot=1, **kwargs):
  """Dispatches to the 3d step for quaternion orientations, else the 2d step.

  `m_rot` is only used by the 3d step.
  """
  is_3d = isinstance(state.position.orientation, Quaternion)
  if not is_3d:
    return _rigid_body_2d_position_step(state, shift_fn, dt, **kwargs)
  return _rigid_body_3d_position_step(state, shift_fn, dt, m_rot=m_rot,
                                      **kwargs)
@simulate.stochastic_step.register(RigidBody)
def _(state, dt: float, kT: float, gamma: float):
  """Stochastic momentum update for rigid bodies with quaternion orientation.

  The center-of-mass part is delegated to `simulate.stochastic_step`; the
  orientational conjugate momentum is resampled from a `simulate.Normal`
  whose two parameters are assembled below.

  NOTE(review): although annotated as `float`, `gamma` is accessed through
  `.center` and `.orientation`, so callers presumably pass a `RigidBody`-like
  object; the orientation code also reads `.vec`, so it assumes quaternion
  orientations. Confirm against callers.
  """
  # One key for the returned state, one for each of the two updates.
  key, center_key, orientation_key = random.split(state.rng, 3)

  rest, center, orientation = split_center_and_orientation(state)

  center = simulate.stochastic_step(
    center.set(rng=center_key),
    dt,
    kT,
    gamma.center)

  Pi = orientation.momentum.vec
  I = orientation.mass
  G = gamma.orientation

  # Per-body scalar built from the inverse moments of inertia.
  M = 4 / jnp.sum(1 / I, axis=-1)
  Q = orientation.position.vec
  P = MOMENTUM_PERMUTATION

  # First evaluate PI term
  # Accumulates, for each permutation l, a decay factor times the outer
  # product of P_l(Q) with itself; the summed matrix is then applied to Pi.
  Pi_mean = 0
  for l in range(3):
    I_l = I[:, [l], None]
    M_l = M[:, None, None]
    PP = P[l](Q)[:, None, :] * P[l](Q)[:, :, None]
    Pi_mean += jnp.exp(-G * M_l * dt / (4 * I_l)) * PP
  Pi_mean = jnp.einsum('nij,nj->ni', Pi_mean, Pi)

  # Then evaluate Q term
  # Accumulates the squared, scaled P_l(Q) for each permutation l.
  Pi_var = 0
  for l in range(3):
    scale = jnp.sqrt(4 * kT * I[:, l] *
                     (1 - jnp.exp(-M * G * dt / (2 * I[:, l]))))
    Pi_var += (scale[:, None] * P[l](Q))**2

  momentum_dist = simulate.Normal(Pi_mean, Pi_var)
  new_momentum = Quaternion(momentum_dist.sample(orientation_key))

  orientation = orientation.set(momentum=new_momentum)

  return merge_center_and_orientation(rest.set(rng=key), center, orientation)
@simulate.canonicalize_mass.register(RigidBody)
def _(state):
  """Reshapes the center-of-mass mass so it broadcasts against momenta.

  A length-one mass becomes a scalar; a longer mass vector gets a trailing
  axis. The orientational mass is left untouched.

  Raises:
    NotImplementedError: if the center mass has length zero.
  """
  mass = state.mass
  count = len(mass.center)

  if count == 1:
    center_mass = mass.center[0]
  elif count > 1:
    center_mass = mass.center[:, None]
  else:
    raise NotImplementedError(
      'Center of mass must be either a scalar or a vector. Found an array of '
      f'shape {mass.center.shape}.')

  return state.set(mass=RigidBody(center_mass, mass.orientation))
@simulate.kinetic_energy.register(RigidBody)
def _(state) -> Array:
  """Kinetic energy of a simulation state made of rigid bodies."""
  position, momentum, mass = state.position, state.momentum, state.mass
  return kinetic_energy(position, momentum, mass)
@simulate.temperature.register(RigidBody)
def _(state) -> Array:
  """Temperature of a simulation state made of rigid bodies."""
  position, momentum, mass = state.position, state.momentum, state.mass
  return temperature(position, momentum, mass)
"""Rigid bodies as unions of point-like particles.
All of the preceding code is valid for any rigid body. Now, we provide a set
of tools for easily constructing energy functions for one class of rigid
bodies. In particular, we provide utilities for defining rigid bodies as
rigid unions of point-like particles. These point-like particles can have
arbitrary interactions between them (which we refer to as the point-species).
Additionally, different rigid point unions can be put into the same simulation.
The rigid point union is synonymous with the shape of the body. Of course this
represents a small subset of the total possible set of rigid body potentials
and it would be interesting to explore other possibilities.
"""
@dataclasses.dataclass
class RigidPointUnion:
  """.. _rigid_body_union:

  Defines a rigid collection of point-like masses glued together.

  This class describes a rigid body as a collection of point-like particles
  rigidly arranged in space. These points can have variable masses. Rigid
  bodies interact by specifying well-defined pair potentials between the
  different points. This is a common model for rigid molecules and colloids.

  To avoid a singularity in the case of a rigid body with a single point, the
  particles are represented by disks in two-dimensions and spheres in
  three-dimensions so that each point-mass has a moment of inertia,
  :math:`I_{disk} = r^2/2` in two-dimensions and :math:`I_{sphere} = 2r^2/5`
  in three-dimensions.

  Each point can optionally be described by an integer specifying its species
  (that we will refer to as a "point species"). Different point species
  typically interact with different potential parameters.

  Additionally, this class can store multiple different shapes packed together
  that get referenced by a "shape species". In this case `total_points` refers
  to the total number of points among all the shapes while `shape_count` refers
  to the number of different kinds of shapes.

  Attributes:
    points: An array of shape `(total_points, spatial_dim)` specifying the
      position of the points making up each rigid shape.
    masses: An array of shape `(total_points,)` specifying the mass of each
      point in the union.
    point_count: An array of shape `(shape_count,)` specifying the number of
      points in each shape.
    point_offset: An array of shape `(shape_count,)` specifying the starting
      index in the `points` array for each shape.
    point_species: An optional array of shape `(total_points,)` specifying
      the species of each point making up the rigid shape.
    point_radius: A float specifying the radius for the disk / sphere used
      in computing the moment of inertia for each point-like particle.
  """
  points: Array
  masses: Array
  point_count: Array
  point_offset: Array
  point_species: Optional[Array] = None
  point_radius: float = dataclasses.field(default_factory=lambda: f32(0.5))

  def dimension(self) -> int:
    """Returns the spatial dimension of the shape."""
    return self.points.shape[-1]

  def _sum_over_shapes(self, x):
    """Sums a per-point quantity over the points that belong to each shape."""
    shape_count = len(self.point_count)
    # Label every point with the index of the shape that owns it.
    shape_idx = jnp.repeat(jnp.arange(shape_count), self.point_count,
                           total_repeat_length=len(self.points))
    return ops.segment_sum(x, shape_idx, shape_count)

  def moment_of_inertia(self) -> Array:
    """Compute the moment of inertia for each shape in the collection.

    Returns:
      In two dimensions, one scalar per shape. In three dimensions, one 3x3
      tensor per shape.

    Raises:
      ValueError: if the points are not two- or three-dimensional.
    """
    ndim = self.dimension()
    dtype = self.points.dtype

    if ndim == 2:
      I_disk = 1 / 2 * self.point_radius ** 2

      @vmap
      def per_particle(point, mass):
        return mass * ((point[0] ** 2 + point[1] ** 2) + I_disk)
      return self._sum_over_shapes(per_particle(self.points, self.masses))
    elif ndim == 3:
      I_sphere = 2 / 5 * self.point_radius ** 2

      @vmap
      def per_particle(point, mass):
        Id = jnp.eye(3, dtype=dtype)
        diagonal = jnp.sum(point**2) * Id
        off_diagonal = point[:, None] * point[None, :]
        return mass * ((diagonal - off_diagonal) + Id * I_sphere)
      return self._sum_over_shapes(per_particle(self.points, self.masses))
    else:
      raise ValueError('Rigid bodies are only defined in two- and three-'
                       'dimensions.')

  def mass(self, shape_species: Optional[Array]=None) -> RigidBody:
    """Get a RigidBody with the mass and moment of inertia for each shape.

    Arguments:
      shape_species: An optional array of integers specifying a mixture of
        different shapes. If specified then the mass object will contain
        a mass and moment of inertia for each shape in the collection.

    Returns:
      A `RigidBody` whose `center` is the total mass and whose `orientation`
      is the moment of inertia (its diagonal, in three dimensions).

    Raises:
      ValueError: if the points are not two- or three-dimensional, or if a
        three-dimensional moment of inertia is not diagonal.
    """
    ndim = self.dimension()

    if ndim == 2:
      inertia = self.moment_of_inertia()
    elif ndim == 3:
      # In three-dimensions, we grab the diagonal of the moment of inertia
      # assuming (and checking) that it is properly diagonalized.
      inertia = get_moment_of_inertia_diagonal(self.moment_of_inertia())
    else:
      raise ValueError(
        'Rigid bodies only defined for two- and three-dimensions.'
        f' Found {ndim}.')

    total_mass = self._sum_over_shapes(self.masses)
    if shape_species is not None:
      total_mass = total_mass[shape_species]
      inertia = inertia[shape_species]
    return RigidBody(total_mass, inertia)

  def __getitem__(self, idx: int) -> 'RigidPointUnion':
    """Extract a single shape from the collection of shapes.

    The extracted shape keeps the `point_radius` of the collection.
    """
    start = self.point_offset[idx]
    end = start + self.point_count[idx]
    species = (None if self.point_species is None
               else self.point_species[start : end])
    return RigidPointUnion(self.points[start : end],
                           self.masses[start : end],
                           jnp.array([self.point_count[idx]]),
                           jnp.array([0]),
                           species,
                           self.point_radius)
def _transform_to_diagonal_frame(shape: RigidPointUnion) -> RigidPointUnion:
  """Transform points to zero center of mass and diagonal moment of inertia.

  The points are first shifted so that the mass-weighted mean position is at
  the origin. In three dimensions they are then rotated into the eigenbasis
  of the moment of inertia tensor (computed about the new origin). All other
  fields of `shape` are carried over unchanged.

  Arguments:
    shape: A `RigidPointUnion` holding exactly one shape.

  Returns:
    A `RigidPointUnion` with the transformed points.

  Raises:
    ValueError: if the shape is not two- or three-dimensional.
  """
  ndim = shape.dimension()
  assert len(shape.point_count) == 1

  if ndim not in (2, 3):
    raise ValueError('Rigid bodies only defined for two- or three-dimensions'
                     f' found shape of dimension={ndim}.')

  # The center of mass is the mass-weighted mean, so normalize by the total
  # mass rather than by the number of points.
  total_mass = jnp.sum(shape.masses)
  weighted = jnp.sum(shape.masses[:, None] * shape.points, axis=0)
  com = weighted / total_mass
  centered = shape.set(points=shape.points - com)

  if ndim == 2:
    return centered

  I, = centered.moment_of_inertia()
  # Only the eigenvectors are needed: they rotate the points into the frame
  # in which the moment of inertia is diagonal.
  _, U = jnp.linalg.eigh(I)
  rotated = jnp.einsum('ni,ij->nj', centered.points, U)
  return centered.set(points=rotated)
def point_union_shape(points: Array, masses: Array) -> RigidPointUnion:
  """Construct a rigid body out of points and masses.

  See :ref:`rigid_body_union` for details.

  Arguments:
    points: An array of point positions.
    masses: An array of particle masses. A scalar is broadcast to one mass
      per point.

  Returns:
    A RigidPointUnion holding a single shape, passed through
    `_transform_to_diagonal_frame`.
  """
  n_points = len(points)

  is_scalar_mass = jnp.isscalar(masses) or masses.shape == ()
  if is_scalar_mass:
    masses = masses * jnp.ones((n_points,), points.dtype)

  single_shape = RigidPointUnion(points=points,
                                 masses=masses,
                                 point_count=jnp.array([n_points]),
                                 point_offset=jnp.array([0]))
  return _transform_to_diagonal_frame(single_shape)
def concatenate_shapes(*shapes) -> RigidPointUnion:
  """Concatenate a list of RigidPointUnions into a single RigidPointUnion.

  Points, masses, counts and (if present) point species are concatenated in
  the order given, and offsets are recomputed from the concatenated counts.
  The `point_radius` of the inputs is not carried over.

  Raises:
    ValueError: if only some shapes define point species, or if point species
      are not arrays.
  """
  per_shape_fields = [dataclasses.astuple(s) for s in shapes]
  (points, masses, point_count,
   point_offset, point_species, _) = zip(*per_shape_fields)

  has_species = [x is not None for x in point_species]
  any_point_species = any(has_species)

  if any_point_species and not all(has_species):
    raise ValueError('Either all shapes should have point species or none '
                     'should have point species.')

  if any_point_species:
    species_are_arrays = all(isinstance(x, (Array, onp.ndarray))
                             for x in point_species)
    if not species_are_arrays:
      raise ValueError('All point species should be specified as `onp.ndarray` '
                       'since the species must be known statically at compile '
                       'time.')

  counts = jnp.concatenate(point_count)
  all_points = jnp.concatenate(points)
  all_masses = jnp.concatenate(masses)
  if point_species[0] is None:
    all_species = None
  else:
    all_species = jnp.concatenate(point_species)
  # Each shape starts where the previous ones end.
  offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(counts)[:-1]])

  return RigidPointUnion(
    points=all_points,
    masses=all_masses,
    point_species=all_species,
    point_count=counts,
    point_offset=offsets
  )
# Change of Basis Transformations (Rigid Body Frame to World Frame)
@partial(jnp.vectorize, signature='()->(d,d)')
def rotation2d(theta: Array) -> Array:
  """Builds a two-dimensional rotation matrix from an angle."""
  cos_t = jnp.cos(theta)
  sin_t = jnp.sin(theta)
  top_row = [cos_t, -sin_t]
  bottom_row = [sin_t, cos_t]
  return jnp.array([top_row, bottom_row])
def union_to_points(body: RigidBody,
                    shape: RigidPointUnion,
                    shape_species: Optional[onp.ndarray]=None,
                    **kwargs,
                    ) -> Tuple[Array, Optional[Array]]:
  """Transforms points in a RigidPointUnion to world space.

  Arguments:
    body: The rigid bodies to place.
    shape: The shape (or collection of shapes) the bodies are made of.
    shape_species: Optional NumPy array assigning a shape index to each body.
      When given, the output is grouped by shape index rather than following
      the order of `body`.
    **kwargs: Accepted but not used.

  Returns:
    A pair `(positions, point_species)`. `positions` is a flat array with one
    row per point; `point_species` is a flat array of matching length, or
    `None` if the shapes carry no point species.

  Raises:
    NotImplementedError: if `shape_species` is neither `None` nor an
      `onp.ndarray`.
  """
  def place(bodies, union):
    # Map every body through `transform`, then flatten the (body, point) axes.
    world = vmap(transform, (0, None))(bodies, union)
    species = union.point_species
    if species is not None:
      species = union.point_species[None, :]
      species = jnp.broadcast_to(species, world.shape[:-1])
      species = jnp.reshape(species, (-1,))
    world = jnp.reshape(world, (-1, world.shape[-1]))
    return world, species

  if shape_species is None:
    return place(body, shape)

  if not isinstance(shape_species, onp.ndarray):
    raise NotImplementedError('Shape species must either be None or of type '
                              'onp.ndarray since it must be specified ahead '
                              f'of compilation. Found {type(shape_species)}.')

  shape_species_types = onp.unique(shape_species)
  shape_species_count = len(shape_species_types)
  # The species must be exactly 0..count-1, one per stored shape.
  assert (len(shape.point_count) == shape_species_count and
          onp.max(shape_species_types) == shape_species_count - 1 and
          onp.min(shape_species_types) == 0)

  shape = tree_map(lambda x: onp.array(x), shape)

  position_chunks = []
  species_chunks = []
  for s in range(shape_species_count):
    pos, ps = place(body[shape_species == s], shape[s])
    if ps is not None:
      species_chunks.append(ps)
    position_chunks.append(pos)

  point_position = jnp.concatenate(position_chunks)
  point_species = jnp.concatenate(species_chunks) if species_chunks else None
  return point_position, point_species
# Energy Functions
def point_energy(energy_fn: Callable[..., Array],
                 shape: RigidPointUnion,
                 shape_species: Optional[onp.ndarray]=None
                 ) -> Callable[..., Array]:
  """Produces a RigidBody energy given a pointwise energy and a point union.

  This function takes a pointwise energy function that computes the energy of
  a set of particle positions along with a RigidPointUnion (optionally with
  shape species information) and produces a new energy function that computes
  the energy of a collection of rigid bodies.

  Args:
    energy_fn: An energy function that takes point positions and produces a
      scalar energy.
    shape: A RigidPointUnion shape that contains one or more shapes defined as
      a union of point masses.
    shape_species: An optional array specifying the composition of the system
      in terms of shapes.

  Returns:
    An energy function that takes a `RigidBody` and produces a scalar energy.
    The point species are forwarded as `species=` only when the shape defines
    them.
  """
  def wrapped_energy_fn(body, **kwargs):
    positions, species = union_to_points(body, shape, shape_species)
    if species is not None:
      return energy_fn(positions, species=species, **kwargs)
    return energy_fn(positions, **kwargs)

  return wrapped_energy_fn
def point_energy_neighbor_list(energy_fn: Callable[..., Array],
                               neighbor_fn: NeighborListFns,
                               shape: RigidPointUnion,
                               shape_species: Optional[onp.ndarray]=None
                               ) -> Tuple[NeighborListFns,
                                          Callable[..., Array]]:
  """Produces a RigidBody energy given a pointwise energy and a point union.

  This function takes a pointwise energy function that computes the energy of
  a set of particle positions using neighbor lists, a `neighbor_fn` that
  builds and updates neighbor lists (see `partition.py` for details), along
  with a RigidPointUnion (optionally with shape species information). It
  produces a new energy function that computes the energy of a collection of
  rigid bodies using neighbor lists, and a `neighbor_fn` that is responsible
  for building and updating those neighbor lists from rigid bodies.

  Args:
    energy_fn: An energy function that takes point positions along with a set
      of neighbors and produces a scalar energy.
    neighbor_fn: A neighbor list function that creates and updates a neighbor
      list among points.
    shape: A RigidPointUnion shape that contains one or more shapes defined as
      a union of point masses.
    shape_species: An optional array specifying the composition of the system
      in terms of shapes.

  Returns:
    A pair `(neighbor_fns, energy_fn)`: neighbor list functions that accept a
    `RigidBody`, and an energy function that takes a `RigidBody` plus a
    neighbor list and produces a scalar energy. The energy function always
    receives `species=` (possibly `None`).
  """
  def to_points(body):
    return union_to_points(body, shape, shape_species)

  def wrapped_energy_fn(body, neighbor, **kwargs):
    positions, species = to_points(body)
    return energy_fn(positions, neighbor=neighbor, species=species, **kwargs)

  def neighbor_update_fn(body, neighbor, **kwargs):
    positions, _ = to_points(body)
    return neighbor_fn.update(positions, neighbor, **kwargs)

  def neighbor_allocate_fn(body, **kwargs):
    positions, _ = to_points(body)
    allocated = neighbor_fn.allocate(positions, **kwargs)
    # Make the list update itself from rigid bodies rather than raw points.
    return dataclasses.replace(allocated, update_fn=neighbor_update_fn)

  wrapped_neighbor_fns = partition.NeighborListFns(neighbor_allocate_fn,
                                                   neighbor_update_fn)
  return wrapped_neighbor_fns, wrapped_energy_fn
# Predefined RigidPointUnion shapes.
# 2D Shapes.
# Each shape below is built with `point_union_shape` from float32 point
# coordinates and a scalar mass of 1.0 that is applied to every point.

# A single point at the origin.
monomer = point_union_shape(onp.array([[0.0, 0.0]], f32), f32(1.0))

# Two points a unit distance apart along the y-axis.
dimer = point_union_shape(onp.array([[0.0, 0.5], [0.0, -0.5]], f32), f32(1.0))

# Three points with unit separation between each pair.
trimer = point_union_shape(
  onp.array([[0, onp.sqrt(1 - 0.5 ** 2) - 0.5],
             [0.5, -0.5],
             [-0.5, -0.5]], f32),
  f32(1.0))

# Four points on the corners of a unit square centered at the origin.
square = point_union_shape(
  onp.array([[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]], f32),
  f32(1.0))

# 3D Shapes.

# Four alternating corners of a unit cube centered at the origin.
tetrahedron = point_union_shape(
  onp.array([[1.0, 1.0, 1.0],
             [ 1.0, -1.0, -1.0],
             [-1.0, 1.0, -1.0],
             [-1.0, -1.0, 1.0]], f32) * f32(0.5),
  f32(1.0))

# Six points at distance 0.5 from the origin along each coordinate axis.
octohedron = point_union_shape(
  onp.array([[1.0, 0.0, 0.0],
             [-1.0, 0.0, 0.0],
             [0.0, 1.0, 0.0],
             [0.0, -1.0, 0.0],
             [0.0, 0.0, 1.0],
             [0.0, 0.0, -1.0]], f32) * f32(0.5),
  f32(1.0))