Defining Spaces

Spaces in which particles are simulated.

A space is a pair of functions:
displacement_fn(Ra, Rb, **kwargs):

Computes displacements between pairs of particles. Ra and Rb should be ndarrays of shape [spatial_dim]. Returns an ndarray of shape [spatial_dim]. To compute the displacement over more than one particle at a time see the map_product(), map_bond(), and map_neighbor() functions.

shift_fn(R, dR, **kwargs):

Moves points at position R by an amount dR.
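Here is a minimal sketch of a free-space pair written in plain jax.numpy. It assumes the convention displacement_fn(Ra, Rb) = Ra - Rb. The names free_displacement and free_shift are hypothetical stand-ins for the pair that free() returns, not the library's implementation.

```python
import jax.numpy as jnp

def free_displacement(Ra, Rb, **unused_kwargs):
    # Displacement between two points in unbounded space (assumed convention: Ra - Rb).
    return Ra - Rb

def free_shift(R, dR, **unused_kwargs):
    # Move positions R by dR; free space needs no boundary handling.
    return R + dR

Ra = jnp.array([1.0, 2.0, 3.0])
Rb = jnp.array([0.5, 0.0, 4.0])
dR = free_displacement(Ra, Rb)   # shape [spatial_dim] = (3,)
R_new = free_shift(Rb, dR)       # moving Rb by (Ra - Rb) lands on Ra
```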

Spaces can accept keyword arguments allowing the space to be changed over the course of a simulation. For an example of this use see periodic_general().

Although displacement functions compute the displacement between two points, it is often useful to compute displacements between multiple particles in a vectorized fashion. To do this we provide three functions: map_product, map_bond, and map_neighbor:

map_product:

Computes displacements between all pairs of points such that if Ra has shape [n, spatial_dim] and Rb has shape [m, spatial_dim] then the output has shape [n, m, spatial_dim].

map_bond:

Computes displacements between corresponding pairs of points in two lists, such that if Ra and Rb both have shape [n, spatial_dim] then the output has shape [n, spatial_dim].

map_neighbor:

Computes displacements between points and all of their neighbors such that if Ra has shape [n, spatial_dim] and Rb has shape [n, neighbors, spatial_dim] then the output has shape [n, neighbors, spatial_dim].
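The three shape conventions can be reproduced with nested jax.vmap calls. This sketch only illustrates the input and output shapes. The names product_fn, bond_fn, and neighbor_fn are hypothetical, and this is not necessarily how map_product, map_bond, and map_neighbor are implemented.

```python
import jax
import jax.numpy as jnp

def displacement(Ra, Rb):
    # Single-pair displacement: [spatial_dim], [spatial_dim] -> [spatial_dim].
    return Ra - Rb

# map_product-style: every Ra against every Rb -> [n, m, spatial_dim].
product_fn = jax.vmap(jax.vmap(displacement, (None, 0)), (0, None))
# map_bond-style: i-th Ra against i-th Rb -> [n, spatial_dim].
bond_fn = jax.vmap(displacement, (0, 0))
# map_neighbor-style: i-th Ra against each of its own neighbors -> [n, neighbors, spatial_dim].
neighbor_fn = jax.vmap(jax.vmap(displacement, (None, 0)), (0, 0))

n, m, k, d = 4, 5, 3, 2
Ra = jnp.ones((n, d))
Rb = jnp.zeros((m, d))
print(product_fn(Ra, Rb).shape)                      # (4, 5, 2)
print(bond_fn(Ra, jnp.zeros((n, d))).shape)          # (4, 2)
print(neighbor_fn(Ra, jnp.zeros((n, k, d))).shape)   # (4, 3, 2)
```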

Spaces

jax_md.space.free()

Free boundary conditions.

Return type

Tuple[Callable[[Array, Array], Array], Callable[[Array, Array], Array]]

jax_md.space.periodic(side, wrapped=True)

Periodic boundary conditions on a hypercube of side length side (or on a rectangular box, if side is a vector).

Parameters
  • side (Array) – Either a float or an ndarray of shape [spatial_dimension] specifying the size of each side of the periodic box.

  • wrapped (bool) – A boolean specifying whether or not particle positions are remapped back into the box after each shift.

Return type

Tuple[Callable[[Array, Array], Array], Callable[[Array, Array], Array]]

Returns

(displacement_fn, shift_fn) tuple.
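For a rectangular box, the minimum image convention and the wrapping controlled by wrapped can both be sketched with jnp.mod. This assumes displacements are Ra - Rb. The helper names min_image_displacement and wrap_positions are hypothetical and only illustrate the idea.

```python
import jax.numpy as jnp

def min_image_displacement(side, dR):
    # Minimum image: map each component of dR into [-side/2, side/2).
    return jnp.mod(dR + 0.5 * side, side) - 0.5 * side

def wrap_positions(side, R):
    # What wrapped=True amounts to: remap positions back into [0, side).
    return jnp.mod(R, side)

side = 10.0
Ra = jnp.array([9.5, 1.0])
Rb = jnp.array([0.5, 2.0])
# Raw difference is [9.0, -1.0]; the nearest image of Rb lies across the x boundary.
dR = min_image_displacement(side, Ra - Rb)               # [-1.0, -1.0]
R_new = wrap_positions(side, Ra + jnp.array([1.0, 0.0]))  # [0.5, 1.0]
```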

jax_md.space.periodic_general(box, fractional_coordinates=True, wrapped=True)

Periodic boundary conditions on a parallelepiped.

This function defines a simulation on a parallelepiped, \(X\), formed by applying an affine transformation, \(T\), to the unit hypercube \(U = [0, 1]^d\) along with periodic boundary conditions across all of the faces.

Formally, the space is defined such that \(X = \{Tu : u \in [0, 1]^d\}\).

The affine transformation, \(T\), can be specified in several ways, depending on the shape of the parallelepiped: 1) for a cube of side length \(L\), \(T\) can simply be the scalar \(L\); 2) for an orthorhombic unit cell, \(T\) can be a vector [Lx, Ly, Lz] of side lengths along each axis; 3) for a general triclinic cell, \(T\) can be an upper-triangular matrix.
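The sketch below shows how the three forms of \(T\) act on a point \(u \in U\). The helper apply_box is hypothetical and plays the role that transform plays in this module. It assumes \(x = Tu\) with \(T\) acting on column vectors (\(x_i = \sum_j T_{ij} u_j\)). The cell values are arbitrary examples.

```python
import jax.numpy as jnp

def apply_box(box, u):
    # Hypothetical helper: map fractional coordinates u in [0, 1]^d to x = T u in X.
    box = jnp.asarray(box)
    if box.ndim <= 1:
        # Scalar (cube) or vector (orthorhombic): scale each axis independently.
        return u * box
    # Matrix (triclinic): x_i = sum_j T_ij u_j, applied to every leading index of u.
    return jnp.einsum('ij,...j->...i', box, u)

u = jnp.array([0.5, 0.5, 0.5])
cube = apply_box(10.0, u)                              # [5, 5, 5]
ortho = apply_box(jnp.array([10.0, 20.0, 30.0]), u)    # [5, 10, 15]
T = jnp.array([[10.0, 2.0, 0.0],
               [0.0, 10.0, 3.0],
               [0.0, 0.0, 10.0]])                      # upper triangular
tri = apply_box(T, u)                                  # [6, 6.5, 5]
```

A vector box is equivalent to the diagonal matrix with those entries, so the orthorhombic case is just a special triclinic cell.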

There are a number of ways to parameterize a simulation on \(X\). periodic_general supports two parameterizations of \(X\) that can be selected using the fractional_coordinates keyword argument.

  1. When fractional_coordinates=True, particle positions are stored in the unit cube, \(u\in U\). Here, the displacement function computes the displacement between \(x, y \in X\) as \(d_X(x, y) = Td_U(u, v)\), where \(d_U\) is the displacement function on the unit cube \(U\), \(x = Tu\), and \(y = Tv\) with \(u, v \in U\). The derivative of the displacement function is defined so that derivatives live in \(X\) (as opposed to being backpropagated to \(U\)). The shift function, shift_fn(R, dR), is defined so that \(R\) is expected to lie in \(U\) while \(dR\) should lie in \(X\). This combination enables code such as shift_fn(R, force_fn(R)) to work as intended.

  2. When fractional_coordinates=False, particle positions are stored in the parallelepiped \(X\). Here, for \(x, y \in X\), the displacement function is defined as \(d_X(x, y) = Td_U(T^{-1}x, T^{-1}y)\). Since there is an extra multiplication by \(T^{-1}\), this parameterization is typically slower than fractional_coordinates=True. As in 1), the displacement function is defined to compute derivatives in \(X\). The shift function is defined so that \(R\) and \(dR\) should both lie in \(X\).
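The two parameterizations can be compared directly in a small two-dimensional sketch. The cell values and the names d_U, disp_fractional, and disp_real are illustrative assumptions. The point is that both routes give the same displacement in \(X\), while the real-coordinate route pays for two extra products with \(T^{-1}\).

```python
import jax.numpy as jnp

T = jnp.array([[10.0, 2.0],
               [0.0, 8.0]])          # example upper-triangular cell
T_inv = jnp.linalg.inv(T)

def d_U(u, v):
    # Minimum-image displacement on the unit cube U = [0, 1]^d.
    return jnp.mod(u - v + 0.5, 1.0) - 0.5

def disp_fractional(u, v):
    # fractional_coordinates=True: inputs live in U, output lives in X.
    return T @ d_U(u, v)

def disp_real(x, y):
    # fractional_coordinates=False: map x, y back to U first (the extra T^{-1}).
    return T @ d_U(T_inv @ x, T_inv @ y)

u = jnp.array([0.1, 0.9])
v = jnp.array([0.8, 0.2])
x, y = T @ u, T @ v
print(disp_fractional(u, v), disp_real(x, y))  # agree up to round-off
```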

Example:

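from jax import random
import jax.numpy as jnp
from jax_md.space import periodic_general, transform

side_length = 10.0
disp_frac, shift_frac = periodic_general(side_length,
                                          fractional_coordinates=True)
disp_real, shift_real = periodic_general(side_length,
                                          fractional_coordinates=False)

# Instantiate random positions in both parameterizations.
pos_key, shift_key = random.split(random.PRNGKey(0))
R_frac = random.uniform(pos_key, (4, 3))
R_real = side_length * R_frac

# Make some shift vectors. In both parameterizations dR lives in real space X.
dR = random.normal(shift_key, (4, 3))

# The two parameterizations should agree up to floating-point round-off.
jnp.allclose(disp_real(R_real[0], R_real[1]),
             disp_frac(R_frac[0], R_frac[1]), atol=1e-5)
jnp.allclose(transform(side_length, shift_frac(R_frac, dR)),
             shift_real(R_real, dR), atol=1e-5)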

It is often desirable to deform a simulation cell, either with a finite deformation during a simulation or with an infinitesimal deformation when computing elastic constants. With fractional coordinates, this is done by supplying the new affine transformation as a keyword argument: displacement_fn(Ra, Rb, box=new_box). With real coordinates, positions are specified in a space \(X\) defined by an affine transformation \(T\), and displacements are computed in a deformed space \(X'\) defined by an affine transformation \(T'\); this is done by writing displacement_fn(Ra, Rb, new_box=new_box).
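Because the box is just another argument, overriding it per call is straightforward. The sketch below assumes fractional coordinates and a square cell. make_fractional_displacement is a hypothetical factory that mimics the box= keyword described above, not the library's implementation. The fractional positions u and v stay fixed and only \(T\) is deformed.

```python
import jax.numpy as jnp

def make_fractional_displacement(box):
    # Hypothetical factory: the default box is captured, but a caller may
    # override it per call with box=..., mirroring the keyword described above.
    def displacement_fn(u, v, box=box):
        d_u = jnp.mod(u - v + 0.5, 1.0) - 0.5   # minimum image in U
        return box @ d_u                         # map the result into X
    return displacement_fn

T = jnp.eye(2) * 10.0
displacement_fn = make_fractional_displacement(T)

u = jnp.array([0.2, 0.5])
v = jnp.array([0.1, 0.5])
strain = jnp.array([[1.01, 0.0],
                    [0.0, 1.0]])                  # 1% stretch along x
d_original = displacement_fn(u, v)                # about [1.0, 0.0]
d_deformed = displacement_fn(u, v, box=strain @ T)  # about [1.01, 0.0]
```

Since box is an ordinary array argument, jax.grad can differentiate through it, which is what makes the infinitesimal-deformation use case convenient.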

There are a few caveats when using periodic_general. Because it uses the minimum image convention, it will fail for potentials whose cutoff is longer than half the side length of the box. It will also fail to find the correct image when the box is too strongly deformed. We hope to add a more robust box for small simulations soon, along with better error checking. In the meantime, caution is recommended.
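The first caveat can be seen numerically. In the hypothetical setup below, the box side is 10 and the cutoff is 6. Two particles with a raw separation of 5.5 along x also have an image at distance 4.5. Both lie inside the cutoff, but the minimum image convention only ever reports the 4.5 image, so the interaction with the other image is silently dropped.

```python
import jax.numpy as jnp

def min_image(dR, side):
    # Minimum-image displacement in a cubic box of the given side length.
    return jnp.mod(dR + 0.5 * side, side) - 0.5 * side

side = 10.0
cutoff = 6.0                              # longer than side / 2 = 5.0
dR_raw = jnp.array([5.5, 0.0, 0.0])       # one image, inside the cutoff
dR_seen = min_image(dR_raw, side)         # [-4.5, 0, 0]: the other image
r_seen = jnp.linalg.norm(dR_seen)         # 4.5; the 5.5 image is never counted
```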

Parameters
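  • box (Array) – The affine transformation \(T\): a scalar (cube), a vector of shape (spatial_dim,) (orthorhombic cell), or a (spatial_dim, spatial_dim) matrix (triclinic cell).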

  • fractional_coordinates (bool) – If True, positions are stored in the unit cube \(U\); if False, they are stored in the parallelepiped \(X\).

  • wrapped (bool) – A boolean specifying whether or not particle positions are remapped back into the box after each shift.

Return type

Tuple[Callable[[Array, Array], Array], Callable[[Array, Array], Array]]

Returns

(displacement_fn, shift_fn) tuple.

Higher Order Functions

jax_md.space.metric(displacement)

Takes a displacement function and creates a metric.

Return type

Callable[[Array, Array], float]
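A metric maps a displacement to a scalar distance. The sketch below assumes the Euclidean norm, and metric_from is a hypothetical stand-in for metric(). Vectorizing a metric over all pairs, as map_product does, yields an [n, m] distance matrix with no trailing spatial_dim axis.

```python
import jax
import jax.numpy as jnp

def displacement(Ra, Rb):
    # Free-space displacement, assumed convention Ra - Rb.
    return Ra - Rb

def metric_from(displacement_fn):
    # Hypothetical stand-in for metric(): distance = Euclidean norm of the displacement.
    def metric_fn(Ra, Rb):
        dR = displacement_fn(Ra, Rb)
        return jnp.sqrt(jnp.sum(dR ** 2, axis=-1))
    return metric_fn

metric_fn = metric_from(displacement)
# All-pairs vectorization of a metric gives an [n, m] matrix of distances.
pairwise = jax.vmap(jax.vmap(metric_fn, (None, 0)), (0, None))

Ra = jnp.array([[0.0, 0.0], [3.0, 4.0]])
Rb = jnp.array([[0.0, 0.0], [0.0, 4.0], [3.0, 0.0]])
D = pairwise(Ra, Rb)   # [[0, 4, 3], [5, 3, 4]]
```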

jax_md.space.map_product(metric_or_displacement)

Vectorizes a metric or displacement function over all pairs.

Return type

Union[Callable[[Array, Array], Array], Callable[[Array, Array], float]]

jax_md.space.map_bond(metric_or_displacement)

Vectorizes a metric or displacement function over bonds.

Return type

Union[Callable[[Array, Array], Array], Callable[[Array, Array], float]]

jax_md.space.map_neighbor(metric_or_displacement)

Vectorizes a metric or displacement function over neighborhoods.

Return type

Union[Callable[[Array, Array], Array], Callable[[Array, Array], float]]

jax_md.space.canonicalize_displacement_or_metric(displacement_or_metric)

Checks whether or not a displacement or metric was provided.

Helper Functions

jax_md.space.transform(box, R)

Apply an affine transformation to positions.

See periodic_general for a description of the semantics of box.

Parameters
  • box (Array) – An affine transformation described in periodic_general.

  • R (Array) – Array of positions. Should have shape (..., spatial_dimension).

Return type

Array

Returns

An array of transformed positions of shape (..., spatial_dimension).

jax_md.space.square_distance(dR)

Computes squared distances.

Parameters

dR (Array) – Matrix of displacements; ndarray(shape=[..., spatial_dim]).

Return type

Array

Returns

Matrix of squared distances; ndarray(shape=[...]).

jax_md.space.distance(dR)

Computes distances.

Parameters

dR (Array) – Matrix of displacements; ndarray(shape=[..., spatial_dim]).

Return type

Array

Returns

Matrix of distances; ndarray(shape=[...]).
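One subtlety matters if you write your own distance function. Differentiating jnp.sqrt of a squared norm at zero displacement (for example, a particle paired with itself in an all-pairs computation) produces non-finite gradients. The sketch below shows the common "double where" workaround. It is a general JAX technique, and naive_distance and safe_distance are hypothetical names; this is not a claim about how distance is implemented here.

```python
import jax
import jax.numpy as jnp

def naive_distance(dR):
    # sqrt has an infinite derivative at 0, which becomes NaN in the chain rule.
    return jnp.sqrt(jnp.sum(dR ** 2, axis=-1))

def safe_distance(dR):
    # Double-where trick: never evaluate sqrt (or its derivative) at exactly zero.
    r2 = jnp.sum(dR ** 2, axis=-1)
    positive = r2 > 0
    safe_r2 = jnp.where(positive, r2, 1.0)   # placeholder value, never returned
    return jnp.where(positive, jnp.sqrt(safe_r2), 0.0)

zero = jnp.zeros(3)
g_naive = jax.grad(naive_distance)(zero)   # non-finite (NaN) at dR = 0
g_safe = jax.grad(safe_distance)(zero)     # zeros
```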

jax_md.space.pairwise_displacement(Ra, Rb)

Compute the displacement Ra - Rb between two positions.

Parameters
  • Ra (Array) – Vector of positions; ndarray(shape=[spatial_dim]).

  • Rb (Array) – Vector of positions; ndarray(shape=[spatial_dim]).

Return type

Array

Returns

Displacement vector; ndarray(shape=[spatial_dim]).

jax_md.space.periodic_shift(side, R, dR)

Shifts positions, wrapping them back within a periodic hypercube.

Return type

Array