Rigid Body Simulations

Code to simulate rigid bodies in two and three dimensions.

The module contains several pieces, each handling a different part of a rigid body simulation.

It begins with quaternion utilities for representing the orientation of bodies in three dimensions.

Rigid body simulations are split into two components.

1) The state of rigid bodies is represented by a dataclass containing a center-of-mass position and an orientation. Along with this type, the core simulation functions are overloaded so that deterministic NVE and NVT simulations work automatically with states composed of RigidBody objects (see simulation.py for details). If you need any other simulation environments, please raise a GitHub issue.

One subtlety of the type system used here is that a host of related quantities are represented by RigidBody objects. For example, the momentum is a RigidBody containing the linear momentum and the angular momentum, while the mass is a RigidBody containing the total mass and the moment of inertia. This lets us use JAX's tree_map utilities to map jointly over the different related quantities. Likewise, forces are RigidBody objects containing a center-of-mass force and a torque.

2) Interactions between rigid bodies are specified separately, and these interactions largely determine the shape of each rigid body. While arbitrary interactions are possible, we include utility functions for building rigid bodies as unions of point-like particles, which captures many common models of rigid molecules and colloids. These functions take a RigidPointUnion object, which specifies the locations of the point particles in the body frame, along with a pointwise energy function. The approach works with or without neighbor lists and yields a function that computes the total energy of a system of rigid bodies.
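To illustrate why sharing one type across related quantities is convenient, here is a minimal NumPy sketch for two-dimensional bodies. The Body2D NamedTuple and the map_fields helper are hypothetical stand-ins, not part of jax_md: map_fields plays the role that jax.tree_util.tree_map plays on RigidBody objects, so a single division turns momenta into both linear and angular velocities.

```python
from typing import NamedTuple

import numpy as np


class Body2D(NamedTuple):
    """Hypothetical stand-in for a 2D RigidBody: a center part and an orientation part."""
    center: np.ndarray       # shape (N, 2)
    orientation: np.ndarray  # shape (N,)


def map_fields(fn, *bodies):
    """Apply fn field by field, mimicking jax.tree_util.tree_map over a RigidBody."""
    return Body2D(*(fn(*leaves) for leaves in zip(*bodies)))


# Position: center-of-mass positions and angles.
position = Body2D(center=np.array([[0.0, 0.0], [1.0, 2.0]]),
                  orientation=np.array([0.0, np.pi / 2]))
# Momentum: linear momenta and angular momenta (scalars in 2D).
momentum = Body2D(center=np.array([[2.0, 0.0], [0.0, -4.0]]),
                  orientation=np.array([1.0, 3.0]))
# Mass: total masses (shaped to broadcast) and moments of inertia.
mass = Body2D(center=np.array([[2.0], [4.0]]),
              orientation=np.array([0.5, 1.5]))

# One call yields linear velocities and angular velocities together.
velocity = map_fields(lambda p, m: p / m, momentum, mass)

# A simple explicit Euler step is written the same way for both parts.
dt = 0.1
new_position = map_fields(lambda x, v: x + dt * v, position, velocity)
```

In JAX, NamedTuples and registered dataclasses are pytrees, so tree_map performs this field-wise mapping without a custom helper.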

Quaternion Utilities

class jax_md.rigid_body.Quaternion(vec)

An object representing a quaternion.

The data is stored in a vector array, but this class exposes several convenience features, including quaternion multiplication and conjugation. It also changes the size property to return the number of degrees of freedom of the quaternion (three, since the quaternion is expected to be normalized).

vec

An array containing the underlying jax.numpy representation.

Type

jax.Array
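The quaternion operations mentioned above can be sketched in plain NumPy. This assumes a scalar-first (w, x, y, z) layout and the Hamilton product convention; the function names are hypothetical and do not reproduce the Quaternion class itself.

```python
import numpy as np

# Assumed layout: scalar-first (w, x, y, z).


def quat_multiply(a, b):
    """Hamilton product a * b of two quaternions."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return np.array([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ])


def quat_conjugate(q):
    """Negate the vector part; for a unit quaternion this is the inverse."""
    return q * np.array([1.0, -1.0, -1.0, -1.0])


def quat_degrees_of_freedom(vec):
    """Each normalized quaternion stores 4 numbers but carries 3 degrees of freedom."""
    return 3 * (vec.size // 4)


q = np.array([1.0, 2.0, 3.0, 4.0])
q = q / np.linalg.norm(q)
identity = quat_multiply(q, quat_conjugate(q))  # (1, 0, 0, 0) up to rounding
k = quat_multiply(np.array([0.0, 1.0, 0.0, 0.0]),
                  np.array([0.0, 0.0, 1.0, 0.0]))  # i * j = k
```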

jax_md.rigid_body.quaternion_rotate(q, v)

Rotates a vector by a given quaternion.

Return type

Array
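A common way to rotate a vector by a unit quaternion is the sandwich product v' = q (0, v) q*. The sketch below assumes the scalar-first Hamilton convention; quat_rotate is a hypothetical name, and jax_md may compute the same rotation differently internally.

```python
import numpy as np


def quat_multiply(a, b):
    """Hamilton product of scalar-first quaternions (w, x, y, z)."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return np.array([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ])


def quat_rotate(q, v):
    """Rotate the 3-vector v by the unit quaternion q via q * (0, v) * conj(q)."""
    qv = np.concatenate([[0.0], v])
    q_conj = q * np.array([1.0, -1.0, -1.0, -1.0])
    return quat_multiply(quat_multiply(q, qv), q_conj)[1:]


half = np.pi / 4  # half of a 90 degree rotation angle
q_z90 = np.array([np.cos(half), 0.0, 0.0, np.sin(half)])  # 90 degrees about z
rotated = quat_rotate(q_z90, np.array([1.0, 0.0, 0.0]))    # x axis maps to y axis
```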

jax_md.rigid_body.random_quaternion(key, dtype)

Generate a random quaternion of a given dtype.

Return type

Quaternion
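One standard way to draw a uniformly random orientation is to normalize a four-dimensional standard Gaussian sample, which is uniform on the unit 3-sphere. This NumPy sketch uses a Generator in place of a JAX key and is not necessarily the method random_quaternion uses; note that q and -q describe the same rotation.

```python
import numpy as np


def random_unit_quaternion(rng, dtype=np.float64):
    """Uniform random unit quaternion from a normalized 4D Gaussian sample."""
    q = rng.standard_normal(4)
    q = q / np.linalg.norm(q)
    return q.astype(dtype)


q = random_unit_quaternion(np.random.default_rng(0), np.float32)
```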

Rigid Body Simulation

class jax_md.rigid_body.RigidBody(center, orientation)

Defines a body described by a position and orientation.

One subtlety about the use of RigidBody objects in JAX MD is that they are used to describe several different but related concepts. In general, a RigidBody object contains two pieces of data: center, containing information about the center of mass of the body, and orientation, containing information about the orientation of the body. In practice, this means that RigidBody objects describe a number of quantities that all have a center-of-mass component and an orientational component.

For example, the instantaneous state of a rigid body might be described by a RigidBody containing the center-of-mass position and the orientation. The momentum of the body is described by a RigidBody containing the linear momentum and the angular momentum. The force on the body is described by a RigidBody containing the linear force and the torque. Finally, the mass of the body is described by a RigidBody containing the total mass and the moment of inertia.

When used in conjunction with automatic differentiation or simulation environments, forces and momenta will also be of type RigidBody. In these cases the orientation should be interpreted as the torque and the angular momentum, respectively.

center

An array of two- or three-dimensional positions giving the center position of the body.

Type

jax.Array

orientation

In two dimensions this is an array of angles. In three dimensions it is a set of quaternions.

Type

Union[jax.Array, jax_md.rigid_body.Quaternion]

jax_md.rigid_body.kinetic_energy(position, momentum, mass)

Computes the kinetic energy of a system with some momenta.

Return type

float

jax_md.rigid_body.temperature(position, momentum, mass)

Computes the temperature of a system with some momenta.

Return type

float
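For two-dimensional bodies, the kinetic energy splits into a translational part, |P|^2 / 2M, and a rotational part, L^2 / 2I. The sketch below computes both and then applies equipartition with k_B = 1, assuming 2 + 1 quadratic degrees of freedom per body and no removal of center-of-mass motion. The function names are hypothetical, and jax_md's exact degree-of-freedom counting may differ.

```python
import numpy as np


def kinetic_energy_2d(linear_momentum, angular_momentum, mass, inertia):
    """Translational plus rotational kinetic energy of N rigid bodies in 2D."""
    translational = 0.5 * np.sum(linear_momentum ** 2 / mass[:, None])
    rotational = 0.5 * np.sum(angular_momentum ** 2 / inertia)
    return translational + rotational


def temperature_2d(linear_momentum, angular_momentum, mass, inertia):
    """Equipartition estimate with k_B = 1 and 3 degrees of freedom per body."""
    dof = 3 * linear_momentum.shape[0]
    ke = kinetic_energy_2d(linear_momentum, angular_momentum, mass, inertia)
    return 2.0 * ke / dof


P = np.array([[1.0, 0.0], [0.0, 2.0]])  # linear momenta
L = np.array([1.0, 1.0])                # angular momenta
M = np.array([1.0, 2.0])                # total masses
I = np.array([0.5, 1.0])                # moments of inertia
ke = kinetic_energy_2d(P, L, M, I)
T = temperature_2d(P, L, M, I)
```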

jax_md.rigid_body.angular_momentum_to_conjugate_momentum(orientation, omega)

Transforms an angular momentum vector into a conjugate momentum quaternion.

Return type

Quaternion

jax_md.rigid_body.conjugate_momentum_to_angular_momentum(orientation, momentum)

Convert from the conjugate momentum of a quaternion to angular momentum.

Simulations involving quaternions typically proceed by integrating Hamilton’s equations with an extended Hamiltonian,

\[H(p, q) = 1/8 p^T S(q) D S(q)^T p + \phi(q)\]

where q is the orientation, p is the conjugate momentum variable, S(q) is the orthogonal 4×4 matrix associated with q, and D = diag(0, 1/I_1, 1/I_2, 1/I_3) holds the inverse principal moments of inertia (see [1]). Note (!!) that, unlike in problems involving only positional degrees of freedom, it is not the case here that \(dq/dt = p/m\). The conjugate momentum is defined only through the Legendre transformation.

This means that you cannot compute the angular velocity by simply transforming the conjugate momentum as you would the time derivative of q. Compare, for example, equations (2.13) and (2.15) in [1].

[1] Miller, Eleftheriou, Pattnaik, Ndirango, Newns, and Martyna. Symplectic quaternion scheme for biophysical molecular dynamics. J. Chem. Phys. 116(20), 2002.

Return type

Array
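The conversion can be checked numerically. Taking S(q) as the matrix of left quaternion multiplication (S(q) x = q * x, scalar-first) and L as the body-frame angular momentum, differentiating the rotational kinetic energy gives p = 2 S(q) (0, L), and because S(q) is orthogonal for a unit quaternion, (0, L) = S(q)^T p / 2. This NumPy sketch assumes those conventions following [1]; the function names are hypothetical, and jax_md's conventions for omega may differ. It also confirms that the kinetic term of H reduces to the familiar sum of L_k^2 / 2I_k, and that dq/dt is not p.

```python
import numpy as np


def S(q):
    """4x4 matrix with S(q) @ x equal to the quaternion product q * x."""
    q0, q1, q2, q3 = q
    return np.array([
        [q0, -q1, -q2, -q3],
        [q1,  q0, -q3,  q2],
        [q2,  q3,  q0, -q1],
        [q3, -q2,  q1,  q0],
    ])


def angular_to_conjugate(q, L_body):
    """p = 2 S(q) (0, L) for body-frame angular momentum L."""
    return 2.0 * S(q) @ np.concatenate([[0.0], L_body])


def conjugate_to_angular(q, p):
    """Invert the map above using S(q)^T S(q) = 1 for unit q."""
    return 0.5 * (S(q).T @ p)[1:]


def rotational_kinetic(q, p, inertia):
    """Kinetic term (1/8) p^T S(q) D S(q)^T p with D = diag(0, 1/I1, 1/I2, 1/I3)."""
    D = np.diag(np.concatenate([[0.0], 1.0 / inertia]))
    return 0.125 * p @ S(q) @ D @ S(q).T @ p


q = np.array([1.0, 2.0, -1.0, 0.5])
q = q / np.linalg.norm(q)
inertia = np.array([1.0, 2.0, 3.0])
L_body = np.array([0.3, -1.2, 0.7])
p = angular_to_conjugate(q, L_body)

# dq/dt = S(q) (0, omega) / 2 with omega = L / I, which is not p.
q_dot = 0.5 * S(q) @ np.concatenate([[0.0], L_body / inertia])
```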

Rigid Collections of Points

class jax_md.rigid_body.RigidPointUnion(points, masses, point_count, point_offset, point_species=None, point_radius=<factory>)

Defines a rigid collection of point-like masses glued together.

This class describes a rigid body as a collection of point-like particles rigidly arranged in space. These points can have variable masses. Rigid bodies interact by specifying well-defined pair potentials between the different points. This is a common model for rigid molecules and colloids.

To avoid a singularity in the case of a rigid body with a single point, the particles are represented by disks in two dimensions and spheres in three dimensions, so that each point mass m has its own moment of inertia, \(I_{disk} = m r^2/2\) in two dimensions and \(I_{sphere} = 2 m r^2/5\) in three dimensions, where r is point_radius.

Each point can optionally be described by an integer specifying its species (that we will refer to as a “point species”). Different point species typically interact with different potential parameters.

Additionally, this class can store multiple different shapes packed together that get referenced by a “shape species”. In this case total_points refers to the total number of points among all the shapes while shape_count refers to the number of different kinds of shapes.

points

An array of shape (total_points, spatial_dim) specifying the position of the points making up each rigid shape.

Type

jax.Array

masses

An array of shape (total_points,) specifying the mass of each point in the union.

Type

jax.Array

point_count

An array of shape (shape_count,) specifying the number of points in each shape.

Type

jax.Array

point_offset

An array of shape (shape_count,) specifying the starting index in the points array for each shape.

Type

jax.Array

point_species

An optional array of shape (total_points,) specifying the species of each point making up the rigid shape.

Type

Optional[jax.Array]

point_radius

A float specifying the radius for the disk / sphere used in computing the moment of inertia for each point-like particle.

Type

float
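The role of the disk radius is easiest to see in two dimensions, where the moment of inertia about the center of mass is a scalar: each point contributes m |r - r_com|^2 from its position plus m r^2 / 2 from its own disk. The sketch below is a hypothetical helper, not the library's code; it shows that a single-point body still gets a nonzero moment of inertia.

```python
import numpy as np


def moment_of_inertia_2d(points, masses, radius):
    """Scalar moment of inertia about the center of mass for a union of disks."""
    com = masses @ points / masses.sum()
    r2 = np.sum((points - com) ** 2, axis=1)
    # Parallel-axis term for each point plus each disk's own m r^2 / 2.
    return np.sum(masses * (r2 + 0.5 * radius ** 2))


single = moment_of_inertia_2d(np.array([[0.0, 0.0]]), np.array([2.0]), radius=0.5)
dimer = moment_of_inertia_2d(np.array([[-1.0, 0.0], [1.0, 0.0]]),
                             np.array([1.0, 1.0]), radius=0.5)
```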

jax_md.rigid_body.point_union_shape(points, masses)

Construct a rigid body out of points and masses.

See rigid_body_union for details.

Parameters
  • points (Array) – An array of point positions.

  • masses (Array) – An array of particle masses.

Return type

RigidPointUnion

Returns

A RigidPointUnion shape object specifying the shape rotated so that the moment of inertia tensor is diagonal.
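Rotating a shape so that its inertia tensor is diagonal amounts to shifting the points to the center of mass, diagonalizing the inertia tensor, and expressing the points in its eigenbasis. This three-dimensional NumPy sketch includes each sphere's 2 m r^2 / 5 term; the helper names and the default radius are assumptions, not point_union_shape itself. The sign flip keeps the eigenvector matrix a proper rotation, so chiral shapes are not mirrored.

```python
import numpy as np


def inertia_tensor(points, masses, radius):
    """Inertia tensor about the origin for a union of equal-radius solid spheres."""
    eye = np.eye(3)
    r2 = np.sum(points ** 2, axis=1)
    point_term = np.sum(masses * r2) * eye - (points * masses[:, None]).T @ points
    sphere_term = 0.4 * radius ** 2 * np.sum(masses) * eye  # 2 m r^2 / 5 per sphere
    return point_term + sphere_term


def principal_frame(points, masses, radius=0.5):
    """Center at the center of mass and rotate onto the principal axes."""
    centered = points - masses @ points / masses.sum()
    moments, axes = np.linalg.eigh(inertia_tensor(centered, masses, radius))
    if np.linalg.det(axes) < 0:
        axes[:, 0] = -axes[:, 0]  # use a proper rotation, not a reflection
    return centered @ axes, moments


points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0],
                   [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
masses = np.array([1.0, 2.0, 1.0, 1.0, 3.0])
body_points, moments = principal_frame(points, masses)
```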

jax_md.rigid_body.concatenate_shapes(*shapes)

Concatenate a list of RigidPointUnions into a single RigidPointUnion.

Return type

RigidPointUnion
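Packing several shapes into one union follows from the point_count and point_offset fields described above: points and masses are concatenated, counts are kept per shape, and offsets are the running totals of the counts. The PointUnion dataclass and both helpers are hypothetical simplifications of RigidPointUnion (no species or radius) meant only to show the bookkeeping.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PointUnion:
    """Hypothetical, simplified stand-in for RigidPointUnion."""
    points: np.ndarray        # (total_points, spatial_dim)
    masses: np.ndarray        # (total_points,)
    point_count: np.ndarray   # (shape_count,)
    point_offset: np.ndarray  # (shape_count,)


def single_shape(points, masses):
    points = np.asarray(points, dtype=float)
    return PointUnion(points, np.asarray(masses, dtype=float),
                      np.array([len(points)]), np.array([0]))


def concatenate(*shapes):
    points = np.concatenate([s.points for s in shapes])
    masses = np.concatenate([s.masses for s in shapes])
    point_count = np.concatenate([s.point_count for s in shapes])
    # Offsets are exclusive running sums of the per-shape point counts.
    point_offset = np.concatenate([[0], np.cumsum(point_count)[:-1]])
    return PointUnion(points, masses, point_count, point_offset)


dimer = single_shape([[0.0, 0.0], [1.0, 0.0]], [1.0, 1.0])
trimer = single_shape([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], [1.0, 1.0, 1.0])
monomer = single_shape([[0.0, 0.0]], [2.0])
combined = concatenate(dimer, trimer, monomer)
```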

jax_md.rigid_body.point_energy(energy_fn, shape, shape_species=None)

Produces a RigidBody energy given a pointwise energy and a point union.

This function takes a pointwise energy function that computes the energy of a set of particle positions, along with a RigidPointUnion (optionally with shape species information), and produces a new energy function that computes the energy of a collection of rigid bodies.

Parameters
  • energy_fn (Callable[…, Array]) – An energy function that takes point positions and produces a scalar energy.

  • shape (RigidPointUnion) – A RigidPointUnion shape that contains one or more shapes defined as a union of point masses.

  • shape_species (Optional[ndarray]) – An optional array specifying the composition of the system in terms of shapes.

Return type

Callable[…, Array]

Returns

An energy function that takes a RigidBody and produces a scalar energy.
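Conceptually, a rigid-body energy built from a point energy transforms each body's points into the world frame and sums a pair potential over the resulting points. This two-dimensional NumPy sketch uses a harmonic soft-sphere pair energy and, as a simplifying assumption, skips pairs within the same body: those distances are fixed, so they would only add a constant. All names are hypothetical, and whether jax_md includes intra-body pairs is not stated here.

```python
import numpy as np


def rotation_2d(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def soft_sphere(r, sigma=1.0, epsilon=1.0):
    """Harmonic soft-sphere pair energy, zero beyond sigma."""
    return np.where(r < sigma, 0.5 * epsilon * (1.0 - r / sigma) ** 2, 0.0)


def rigid_point_energy(centers, angles, body_points, pair_energy=soft_sphere):
    """Total pair energy between points that belong to different bodies."""
    rotations = np.stack([rotation_2d(a) for a in angles])  # (N, 2, 2)
    world = centers[:, None, :] + np.einsum('nij,pj->npi', rotations, body_points)
    flat = world.reshape(-1, 2)
    body_id = np.repeat(np.arange(len(centers)), len(body_points))
    r = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
    # Count each pair once and skip pairs on the same body.
    mask = np.triu(body_id[:, None] != body_id[None, :], k=1)
    return np.sum(np.where(mask, pair_energy(r), 0.0))


dimer = np.array([[-0.5, 0.0], [0.5, 0.0]])  # body-frame points
centers = np.array([[0.0, 0.0], [1.5, 0.0]])
aligned = rigid_point_energy(centers, np.array([0.0, 0.0]), dimer)
turned = rigid_point_energy(centers, np.array([0.0, np.pi / 2]), dimer)
```

With both dimers aligned, only one pair overlaps (distance 0.5); turning the second dimer by 90 degrees moves its points out of range.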

jax_md.rigid_body.point_energy_neighbor_list(energy_fn, neighbor_fn, shape, shape_species=None)

Produces a RigidBody energy given a pointwise energy and a point union.

This function takes a pointwise energy function that computes the energy of a set of particle positions using neighbor lists, a neighbor_fn that builds and updates neighbor lists among points (see partition.py for details), and a RigidPointUnion (optionally with shape species information). It produces a new energy function that computes the energy of a collection of rigid bodies using neighbor lists, along with the neighbor list functions used to build and update those lists.

Parameters
  • energy_fn (Callable[…, Array]) – An energy function that takes point positions along with a set of neighbors and produces a scalar energy.

  • neighbor_fn (NeighborListFns) – A neighbor list function that creates and updates a neighbor list among points.

  • shape (RigidPointUnion) – A RigidPointUnion shape that contains one or more shapes defined as a union of point masses.

  • shape_species (Optional[ndarray]) – An optional array specifying the composition of the system in terms of shapes.

Return type

Tuple[NeighborListFns, Callable[…, Array]]

Returns

A pair consisting of neighbor list functions and an energy function that takes a RigidBody and produces a scalar energy.
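The neighbor-list variant separates building a list of nearby point pairs from evaluating the energy over that list. This NumPy sketch uses a brute-force O(N^2) build with a skin distance, so the list stays valid while bodies move by less than about half the skin. It does not reproduce partition.py, and all names and parameters are hypothetical.

```python
import numpy as np


def world_points(centers, angles, body_points):
    """Flattened world-frame points and the index of the body each belongs to."""
    c, s = np.cos(angles), np.sin(angles)
    rot = np.stack([np.stack([c, -s], -1), np.stack([s, c], -1)], -2)  # (N, 2, 2)
    world = centers[:, None, :] + np.einsum('nij,pj->npi', rot, body_points)
    body_id = np.repeat(np.arange(len(centers)), len(body_points))
    return world.reshape(-1, 2), body_id


def build_neighbors(points, body_id, cutoff, skin):
    """(i, j) pairs with i < j, on different bodies, closer than cutoff + skin."""
    r = np.linalg.norm(points[:, None] - points[None, :], axis=-1)
    close = (r < cutoff + skin) & (body_id[:, None] != body_id[None, :])
    return np.argwhere(np.triu(close, k=1))


def energy_from_neighbors(points, pairs, sigma=1.0, epsilon=1.0):
    """Harmonic soft-sphere energy summed over the listed pairs only."""
    r = np.linalg.norm(points[pairs[:, 0]] - points[pairs[:, 1]], axis=-1)
    return np.sum(np.where(r < sigma, 0.5 * epsilon * (1.0 - r / sigma) ** 2, 0.0))


dimer = np.array([[-0.5, 0.0], [0.5, 0.0]])
centers = np.array([[0.0, 0.0], [1.5, 0.0]])
angles = np.zeros(2)
points, body_id = world_points(centers, angles, dimer)
pairs = build_neighbors(points, body_id, cutoff=1.0, skin=0.3)
energy = energy_from_neighbors(points, pairs)

# Reuse the same list after a small move of the second body.
moved, _ = world_points(centers + np.array([[0.0, 0.0], [0.1, 0.0]]), angles, dimer)
moved_energy = energy_from_neighbors(moved, pairs)
```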

jax_md.rigid_body.transform(body, shape)

Transform a rigid point union from body frame to world frame.

Return type

Array
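Transforming from the body frame to the world frame means applying x_world = center + R(q) x_body, where R(q) is the rotation matrix of the body's orientation quaternion. This three-dimensional NumPy sketch assumes scalar-first unit quaternions; quat_to_matrix and body_to_world are hypothetical names, not the transform function itself.

```python
import numpy as np


def quat_to_matrix(q):
    """Rotation matrix of a unit quaternion stored as (w, x, y, z)."""
    w, x, y, z = q
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


def body_to_world(centers, quaternions, body_points):
    """World positions center + R(q) x_body, shape (num_bodies, num_points, 3)."""
    return np.stack([c + body_points @ quat_to_matrix(q).T
                     for c, q in zip(centers, quaternions)])


half = np.pi / 4
q_z90 = np.array([np.cos(half), 0.0, 0.0, np.sin(half)])  # 90 degrees about z
body_points = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
world = body_to_world(np.array([[1.0, 1.0, 1.0]]), np.array([q_z90]), body_points)
```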

jax_md.rigid_body.union_to_points(body, shape, shape_species=None, **kwargs)

Transforms points in a RigidPointUnion to world space.

Return type

Tuple[Array, Optional[Array]]
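With several shapes packed together, producing world-space points means looking up each body's shape through its shape species, slicing that shape's points with point_offset and point_count, and transforming them. This two-dimensional sketch returns the positions together with the matching point species. The function name and argument list are hypothetical, and the Python loop is for clarity only; code meant for jax.jit would need fixed array shapes instead.

```python
import numpy as np


def union_to_world(centers, angles, shape_species, points, point_count,
                   point_offset, point_species):
    """World positions and point species for bodies drawn from packed shapes."""
    out_pos, out_species = [], []
    for center, theta, s in zip(centers, angles, shape_species):
        start, n = point_offset[s], point_count[s]
        local = points[start:start + n]
        c, sn = np.cos(theta), np.sin(theta)
        rot = np.array([[c, -sn], [sn, c]])
        out_pos.append(center + local @ rot.T)
        out_species.append(point_species[start:start + n])
    return np.concatenate(out_pos), np.concatenate(out_species)


# Shape 0 is a dimer of species-0 points; shape 1 is a single species-1 point.
points = np.array([[-0.5, 0.0], [0.5, 0.0], [0.0, 0.0]])
point_count = np.array([2, 1])
point_offset = np.array([0, 2])
point_species = np.array([0, 0, 1])

positions, species = union_to_world(
    centers=np.array([[0.0, 0.0], [3.0, 3.0]]),
    angles=np.array([np.pi / 2, 0.0]),
    shape_species=np.array([0, 1]),
    points=points, point_count=point_count,
    point_offset=point_offset, point_species=point_species)
```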