Spatial Partitioning

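Code to spatially partition collections of particles into cell lists and neighbor lists, so that functions on individual tuples of particles can be applied efficiently to whole sets.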

Neighbor Lists

jax_md.partition.neighbor_list(displacement_or_metric, box, r_cutoff, dr_threshold=0.0, capacity_multiplier=1.25, disable_cell_list=False, mask_self=True, custom_mask_function=None, fractional_coordinates=False, format=NeighborListFormat.Dense, **static_kwargs)

Returns functions that build and update lists of neighbors for collections of points.

Neighbor lists must balance the need to be jit compatible with the fact that, under a jit, the maximum number of neighbors cannot change (owing to static shape requirements). To deal with this, neighbor_list returns a NeighborListFns object that contains two functions: 1) neighbor_fn.allocate creates a new neighbor list and 2) neighbor_fn.update updates an existing neighbor list. Neighbor lists themselves additionally have a convenience update member function.

Note that allocation of a new neighbor list cannot be jit compiled since it uses the positions to infer the maximum number of neighbors (along with additional space specified by the capacity_multiplier). Updating the neighbor list can be jit compiled; if the neighbor list capacity is not sufficient to store all the neighbors, the did_buffer_overflow bit will be set to True and a new neighbor list will need to be reallocated.
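
To make the allocate/update split concrete, here is a minimal brute-force sketch in plain jax.numpy (no cell list, free boundaries). The helpers pair_mask, allocate_capacity and update_dense are hypothetical and not part of jax_md; they only illustrate why allocation must read concrete position values while an update can run under jit at a fixed capacity and report overflow instead of resizing.

```python
from functools import partial

import jax
import jax.numpy as jnp


def pair_mask(position, r_cutoff):
    # Boolean [N, N] matrix: True where particles i != j are within r_cutoff.
    dr = position[:, None, :] - position[None, :, :]
    dist_sq = jnp.sum(dr ** 2, axis=-1)
    N = position.shape[0]
    return (dist_sq < r_cutoff ** 2) & ~jnp.eye(N, dtype=bool)


def allocate_capacity(position, r_cutoff, capacity_multiplier=1.25):
    # Not jittable: turns a data-dependent maximum into a Python int (a shape).
    occupancy = jnp.sum(pair_mask(position, r_cutoff), axis=1)
    return int(int(jnp.max(occupancy)) * capacity_multiplier)


@partial(jax.jit, static_argnums=(2,))
def update_dense(position, r_cutoff, capacity):
    # Jittable: the output shape [N, capacity] depends only on static values.
    mask = pair_mask(position, r_cutoff)
    N = position.shape[0]
    # Non-neighbors become the padding id N; sorting moves real ids to the front.
    candidates = jnp.where(mask, jnp.arange(N)[None, :], N)
    candidates = jnp.pad(candidates, ((0, 0), (0, capacity)), constant_values=N)
    idx = jnp.sort(candidates, axis=1)[:, :capacity]
    did_buffer_overflow = jnp.any(jnp.sum(mask, axis=1) > capacity)
    return idx, did_buffer_overflow
```

When the flag comes back True, the caller is expected to call allocate_capacity again outside of jit, which mirrors the reallocation step described above. The real library avoids the O(N²) distance matrix by going through a cell list.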

Here is a typical example of a simulation loop with neighbor lists:

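from jax import lax, random
from jax_md import energy, simulate, space

# A small example system: 128 soft spheres in a periodic 2D box.
N, box_size = 128, 10.0
displacement, shift = space.periodic(box_size)
R = random.uniform(random.PRNGKey(1), (N, 2)) * box_size

# Neighbor-list energy functions return both a neighbor_fn and an energy_fn.
neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(displacement, box_size)
init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)

nbrs = neighbor_fn.allocate(R)  # Not jittable: infers the capacity from R.
state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)

def body_fn(i, carry):
  state, nbrs = carry
  nbrs = nbrs.update(state.position)  # Jittable: the capacity stays fixed.
  state = apply_fn(state, neighbor=nbrs)
  return state, nbrs

step = 0
for _ in range(20):
  new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
  if nbrs.did_buffer_overflow:
    # Capacity was exceeded: discard these 100 steps and reallocate.
    nbrs = neighbor_fn.allocate(state.position)
  else:
    state = new_state
    step += 1
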
Parameters
  • displacement_or_metric – A function d(R_a, R_b) that computes the displacement (or, for a metric, the distance) between pairs of points.

  • box (Array) – Either a float specifying the side length of a cubic box, an array of shape [spatial_dim] specifying the side length of a rectangular box in each spatial dimension, or a matrix of shape [spatial_dim, spatial_dim] that is upper triangular and specifies the lattice vectors of the box.

  • r_cutoff (float) – A scalar specifying the neighborhood radius.

  • dr_threshold (float) – A scalar specifying the maximum distance particles can move before rebuilding the neighbor list.

  • capacity_multiplier (float) – A floating point scalar that multiplies the maximum neighborhood occupancy found in the example positions to set the allocated capacity; for example, the default of 1.25 reserves 25% extra space.

  • disable_cell_list (bool) – An optional boolean. If set to True then the neighbor list is constructed using only distances. This can be useful for debugging but should generally be left as False.

  • mask_self (bool) – An optional boolean. Determines whether points can consider themselves to be their own neighbors.

  • custom_mask_function (Optional[Callable[[Array], Array]]) – An optional function that takes the neighbor array and masks selected elements. Note: the input array to the function has shape (n_particles, m); the row index is the index of the first particle, and each value in that row is the index of a second particle.

  • fractional_coordinates (bool) – An optional boolean. Specifies whether positions will be supplied in fractional coordinates in the unit cube, \([0, 1]^d\). If this is set to True then the box_size will be set to 1.0 and the cell size used in the cell list will be set to cutoff / box_size.

  • format (NeighborListFormat) – The format of the neighbor list; see the NeighborListFormat() enum for details about the different choices for formats. Defaults to Dense.

  • **static_kwargs – kwargs that get threaded through the calculation of example positions.
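
The (n_particles, m) layout expected by custom_mask_function can be illustrated with a small, hypothetical mask that removes pairs of particles belonging to the same molecule. This sketch assumes a Dense neighbor list and assumes that masked entries should be overwritten with the padding id N, the same convention the cell list uses for empty slots; make_same_molecule_mask and molecule_id are illustrative names, not part of jax_md.

```python
import jax.numpy as jnp


def make_same_molecule_mask(molecule_id):
    # molecule_id: int array of shape (n_particles,), one label per particle.
    molecule_id = jnp.asarray(molecule_id)
    N = molecule_id.shape[0]

    def custom_mask_function(idx):
        # idx has shape (n_particles, m): row i belongs to particle i and each
        # entry idx[i, j] is the id of a neighboring particle (N means empty).
        first = jnp.arange(N)[:, None]
        # Clip so that padding entries (value N) can be looked up safely.
        second = jnp.clip(idx, 0, N - 1)
        same = (molecule_id[first] == molecule_id[second]) & (idx < N)
        # Hypothetical convention: masked entries become the padding id N.
        return jnp.where(same, N, idx)

    return custom_mask_function
```

The returned closure would then be passed as custom_mask_function=make_same_molecule_mask(labels) when constructing the neighbor list.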

Return type

NeighborListFns

Returns

A NeighborListFns object that contains a method to allocate a new neighbor list and a method to update an existing neighbor list.
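
A reasoning step behind dr_threshold, under the assumption that the list is built with the enlarged cutoff r_cutoff + dr_threshold: if no particle has moved more than dr_threshold / 2 since reference_position, then no pair separation can have shrunk by more than dr_threshold, so every pair now within r_cutoff is already in the list and no rebuild is needed. The hypothetical needs_rebuild below sketches that check with free-space distances; the factor of one half follows from this argument and is not a documented guarantee of the library.

```python
import jax.numpy as jnp


def needs_rebuild(position, reference_position, dr_threshold):
    # Largest squared displacement of any particle since the last build.
    max_dr_sq = jnp.max(jnp.sum((position - reference_position) ** 2, axis=-1))
    # Two particles each moving dr_threshold / 2 can close by dr_threshold.
    return max_dr_sq > (dr_threshold / 2) ** 2
```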

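jax_md.partition.neighbor_list_mask(neighbor, mask_self=False)

Compute a boolean mask that is True for the occupied (non-padding) entries of a neighbor list. If mask_self is True, entries that pair a particle with itself are masked out as well.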

Return type

Array
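
A minimal sketch of such a mask for a Dense neighbor list, assuming (as in the cell list) that unused slots hold the particle count N. dense_neighbor_mask is a hypothetical stand-in written with plain jax.numpy, not the library function itself.

```python
import jax.numpy as jnp


def dense_neighbor_mask(idx, mask_self=False):
    # idx: [N, max_occupancy] integer array; entries equal to N are padding.
    N = idx.shape[0]
    mask = idx < N
    if mask_self:
        # Also drop entries where particle i lists itself as a neighbor.
        mask = mask & (idx != jnp.arange(N)[:, None])
    return mask
```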

jax_md.partition.to_jraph(neighbor, mask=None, nodes=None, edges=None, globals=None)

Convert a sparse neighbor list to a jraph.GraphsTuple.

As in jraph, padding here is accomplished by adding a fictitious graph with a single node.

Parameters
  • neighbor (NeighborList) – A neighbor list that we will convert to the jraph format. Must be sparse.

  • mask (Optional[Array]) – An optional mask on the edges.

Return type

GraphsTuple

Returns

A jraph.GraphsTuple that contains the topology of the neighbor list.

jax_md.partition.to_dense(neighbor)

Converts a sparse neighbor list to a dense array of neighbor ids. Cannot be JIT compiled.

Return type

Array
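
The conversion cannot be JIT compiled because the width of the dense array depends on the data: it equals the largest number of neighbors any particle has. The NumPy sketch below (hypothetical sparse_to_dense) assumes that column k of the sparse array holds the pair (i, j) with i the particle and j its neighbor, and that padding columns have i equal to N.

```python
import numpy as np


def sparse_to_dense(sparse_idx, N):
    # sparse_idx: (2, M) int array; column k is the pair (i, j), i == N is padding.
    i, j = np.asarray(sparse_idx)
    valid = i < N
    i, j = i[valid], j[valid]
    counts = np.bincount(i, minlength=N)
    max_count = int(counts.max())  # data-dependent width: not a static shape
    dense = np.full((N, max_count), N, dtype=np.int32)
    fill = np.zeros(N, dtype=np.int64)  # next free slot in each row
    for a, b in zip(i, j):
        dense[a, fill[a]] = b
        fill[a] += 1
    return dense
```

Under jit, max_count would be a traced value and could not be used as a shape, which is why this step has to run eagerly.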

class jax_md.partition.NeighborList(idx, reference_position, error, cell_list_capacity, max_occupancy, format, cell_size, cell_list_fn, update_fn)

A struct containing the state of a Neighbor List.

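idx

For an N particle system in the Dense format, this is an [N, max_occupancy] array of integers such that idx[i, j] is the j-th neighbor of particle i. In the Sparse and OrderedSparse formats it is instead a [2, max_occupancy] array whose columns are neighbor pairs.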

Type

jax.Array

reference_position

The positions of particles when the neighbor list was constructed. This is used to decide whether the neighbor list ought to be updated.

Type

jax.Array

error

An error code that is used to identify errors that occurred during neighbor list construction. See PartitionError and PartitionErrorCode for details.

Type

jax_md.partition.PartitionError

cell_list_capacity

An optional integer specifying the capacity of the cell list used as an intermediate step in the creation of the neighbor list.

Type

Optional[int]

max_occupancy

A static integer specifying the maximum size of the neighbor list. Changing this will invoke a recompilation.

Type

int

format

A NeighborListFormat enum specifying the format of the neighbor list.

Type

jax_md.partition.NeighborListFormat

cell_size

A float specifying the current minimum size of the cells used in cell list construction.

Type

Optional[float]

cell_list_fn

The function used to construct the cell list.

Type

Callable[[jax.Array, jax_md.partition.CellList], jax_md.partition.CellList]

update_fn

A static Python function used to update the neighbor list.

Type

Callable[[jax.Array, jax_md.partition.NeighborList], jax_md.partition.NeighborList]

class jax_md.partition.NeighborListFormat(value)

An enum listing the different neighbor list formats.

Dense

A dense neighbor list where the ids are a rectangular matrix of shape (N, max_neighbors_per_atom). Here the capacity of the neighbor list must be large enough for the particle with the most neighbors.

Sparse

A sparse neighbor list where the ids are a rectangular matrix of shape (2, max_neighbors) specifying the start / end particle of each neighbor pair.

OrderedSparse

A sparse neighbor list with the same format as Sparse, except that only pairs with i < j are included.
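
For a concrete picture of the three formats, the NumPy sketch below builds each of them from a boolean adjacency matrix for a toy three-particle chain. It ignores spare capacity, and the function name neighbor_formats is hypothetical.

```python
import numpy as np


def neighbor_formats(adjacency):
    # adjacency: symmetric (N, N) bool array with a False diagonal.
    adjacency = np.asarray(adjacency, dtype=bool)
    N = adjacency.shape[0]
    width = int(adjacency.sum(axis=1).max())
    # Dense: row i lists the neighbors of particle i, padded with N.
    dense = np.full((N, width), N, dtype=np.int32)
    for p in range(N):
        nbrs = np.flatnonzero(adjacency[p])
        dense[p, : len(nbrs)] = nbrs
    # Sparse: one column per directed pair, so both (i, j) and (j, i) appear.
    rows, cols = np.nonzero(adjacency)
    sparse = np.stack([rows, cols]).astype(np.int32)
    # OrderedSparse: same layout as Sparse, keeping only pairs with i < j.
    ordered = sparse[:, sparse[0] < sparse[1]]
    return dense, sparse, ordered
```

With Sparse the storage scales with the total number of pairs rather than with N times the largest neighbor count, and OrderedSparse stores each unordered pair only once.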

class jax_md.partition.NeighborListFns(allocate, update)

A struct containing functions to allocate and update neighbor lists.

allocate

A function to allocate a new neighbor list. This function cannot be JIT compiled, since it uses the values of positions to infer the shapes.

Type

Callable[[…], jax_md.partition.NeighborList]

update

A function to update a neighbor list given a new set of positions and a previously allocated neighbor list.

Type

Callable[[jax.Array, jax_md.partition.NeighborList], jax_md.partition.NeighborList]

Cell Lists

jax_md.partition.cell_list(box_size, minimum_cell_size, buffer_size_multiplier=1.25)

Returns a function that partitions point data spatially.

Given a set of points \(\{x_i \in \mathbb{R}^d\}\) with associated data \(\{k_i \in \mathbb{R}^m\}\), it is often useful to partition the points / data spatially. A simple partitioning that can be implemented efficiently within XLA is a dense partition into a uniform grid called a cell list.

Since XLA requires that shapes be statically specified inside of a JIT block, the cell list code can operate in two modes: allocation and update.

Allocation creates a new cell list that uses a set of input positions to estimate the capacity of the cell list. This capacity can be adjusted by setting the buffer_size_multiplier or setting the extra_capacity. Allocation cannot be JIT compiled.

Updating takes a previously allocated cell list and places a new set of particles in the cells. Updating cannot resize the cell list and is therefore compatible with JIT. However, if the configuration has changed substantially it is possible that the existing cell list won’t be large enough to accommodate all of the particles. In this case the did_buffer_overflow bit will be set to True.
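
The allocation/update split for cell lists can be sketched in jax.numpy as follows. The hypothetical estimate_capacity reads concrete occupancies (so it cannot be traced), while bin_particles fills a fixed-capacity buffer of shape S, marks empty slots with N, and raises an overflow flag instead of resizing. The sketch assumes a periodic box with the same side length and number of cells in every dimension; it is not the jax_md implementation.

```python
import math

import jax.numpy as jnp


def cell_hash(position, box_size, cells_per_side):
    # Flat (row-major) index of the cell containing each particle.
    dim = position.shape[1]
    cell_size = box_size / cells_per_side
    cell_idx = jnp.floor(position / cell_size).astype(jnp.int32) % cells_per_side
    strides = cells_per_side ** jnp.arange(dim - 1, -1, -1)
    return jnp.sum(cell_idx * strides, axis=1)


def estimate_capacity(position, box_size, cells_per_side, buffer_size_multiplier=1.25):
    # Allocation: needs concrete values, so it cannot run under jit.
    n_cells = cells_per_side ** position.shape[1]
    h = cell_hash(position, box_size, cells_per_side)
    occupancy = jnp.bincount(h, length=n_cells)
    return math.ceil(int(occupancy.max()) * buffer_size_multiplier)


def bin_particles(position, box_size, cells_per_side, capacity):
    # Update: the output shape depends only on cells_per_side and capacity.
    N, dim = position.shape
    h = cell_hash(position, box_size, cells_per_side)
    order = jnp.argsort(h)
    sorted_h = h[order]
    # Slot inside the cell = rank minus the first rank with the same cell.
    slot = jnp.arange(N) - jnp.searchsorted(sorted_h, sorted_h, side="left")
    did_buffer_overflow = jnp.any(slot >= capacity)
    ids = jnp.full((cells_per_side ** dim, capacity), N, dtype=jnp.int32)
    # Particles that do not fit are dropped instead of written out of bounds.
    ids = ids.at[sorted_h, slot].set(order.astype(jnp.int32), mode="drop")
    return ids.reshape((cells_per_side,) * dim + (capacity,)), did_buffer_overflow
```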

Parameters
  • box_size (Array) – A float or an ndarray of shape [spatial_dimension] specifying the size of the system. Note that this code is written for the case where the boundaries are periodic; if they are not, the current code will be slightly less efficient.

  • minimum_cell_size (float) – A float specifying the minimum side length of each cell. Cells are enlarged so that they exactly fill the box.

  • buffer_size_multiplier (float) – A floating point multiplier that multiplies the estimated cell capacity to allow for fluctuations in the maximum cell occupancy.
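
Enlarging the cells so that they exactly fill the box amounts to the arithmetic below, sketched for a box with the same side length in every dimension. cell_geometry is a hypothetical helper, and the library's exact rounding may differ.

```python
def cell_geometry(box_size, minimum_cell_size, dim):
    # Largest whole number of cells per side whose width is still >= the minimum.
    cells_per_side = int(box_size // minimum_cell_size)
    if cells_per_side < 1:
        raise ValueError("box_size must be at least minimum_cell_size")
    # Stretch the cells so that cells_per_side * cell_size == box_size.
    cell_size = box_size / cells_per_side
    return cells_per_side, cell_size, cells_per_side ** dim
```

For example, a box of side 10.0 with a minimum cell size of 3.0 gives 3 cells per side, each of width 10/3 (about 3.33), and 27 cells in three dimensions.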

Return type

CellListFns

Returns

A CellListFns object that contains two methods, one to allocate the cell list and one to update the cell list. The update function can be called with either a cell list from which the capacity can be inferred or with an explicit integer denoting the capacity. Note that an existing cell list can also be updated by calling cell_list.update(position).

class jax_md.partition.CellList(position_buffer, id_buffer, named_buffer, did_buffer_overflow, cell_capacity, cell_size, update_fn)

Stores the spatial partition of a system into a cell list.

See cell_list() for details on the construction / specification. Cell list buffers all have a common shape, S, where

  • S = [cell_count_x, cell_count_y, cell_capacity]

  • S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity]

in two- and three-dimensions respectively. It is assumed that each cell has the same capacity.

position_buffer

An ndarray of floating point positions with shape S + [spatial_dimension].

Type

jax.Array

id_buffer

An ndarray of int32 particle ids of shape S. Note that empty slots are specified by id = N where N is the number of particles in the system.

Type

jax.Array

named_buffer

A dictionary of ndarrays of shape S + [...]. This contains side data placed into the cell list.

Type

Dict[str, jax.Array]

did_buffer_overflow

A boolean specifying whether or not the cell list exceeded the maximum allocated capacity.

Type

jax.Array

cell_capacity

An integer specifying the maximum capacity of each cell in the cell list.

Type

int

update_fn

A function that updates the cell list at a fixed capacity.

Type

Callable[[…], jax_md.partition.CellList]