Spatial Partitioning
Code to transform functions on individual tuples of particles to sets.
Neighbor Lists
- jax_md.partition.neighbor_list(displacement_or_metric, box, r_cutoff, dr_threshold=0.0, capacity_multiplier=1.25, disable_cell_list=False, mask_self=True, custom_mask_function=None, fractional_coordinates=False, format=NeighborListFormat.Dense, **static_kwargs)
Returns a function that builds a list of neighbors for collections of points.
Neighbor lists must be jit compatible, yet under jit the maximum number of neighbors cannot change, because XLA requires static array shapes. To deal with this, our neighbor_list returns a NeighborListFns object that contains two functions: 1) neighbor_fn.allocate creates a new neighbor list and 2) neighbor_fn.update updates an existing neighbor list. Neighbor lists themselves additionally have a convenience update member function.
Note that allocation of a new neighbor list cannot be jit compiled, since it uses the positions to infer the maximum number of neighbors (along with additional space specified by the capacity_multiplier). Updating the neighbor list can be jit compiled; if the neighbor list capacity is not sufficient to store all the neighbors, the did_buffer_overflow bit will be set to True and a new neighbor list will need to be allocated.
Here is a typical example of a simulation loop with neighbor lists:
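    from jax import lax, random
    from jax_md import simulate

    # energy_fn, shift, neighbor_fn and the initial positions R are assumed to be
    # defined already; energy_fn is assumed to accept the neighbor ids through
    # the neighbor_idx keyword.
    init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)

    nbrs = neighbor_fn.allocate(R)  # not jit compiled: infers the capacity from R
    state = init_fn(random.PRNGKey(0), R, neighbor_idx=nbrs.idx)

    def body_fn(i, state_and_nbrs):
        state, nbrs = state_and_nbrs
        nbrs = nbrs.update(state.position)  # jit compatible, fixed capacity
        state = apply_fn(state, neighbor_idx=nbrs.idx)
        return state, nbrs

    step = 0
    for _ in range(20):
        new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
        if nbrs.did_buffer_overflow:
            # Capacity was too small: reallocate and redo this chunk from the
            # last good state.
            nbrs = neighbor_fn.allocate(state.position)
        else:
            state = new_state
            step += 1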
- Parameters
displacement_or_metric – A function d(R_a, R_b) that computes the displacement between pairs of points.
box (Array) – Either a float specifying the side length of a cubic box, an array of shape [spatial_dim] specifying the side length of the box in each spatial dimension, or a matrix of shape [spatial_dim, spatial_dim] that is upper triangular and specifies the lattice vectors of the box.
r_cutoff (float) – A scalar specifying the neighborhood radius.
dr_threshold (float) – A scalar specifying the maximum distance particles can move before the neighbor list is rebuilt.
capacity_multiplier (float) – A floating point multiplier applied to the maximum neighborhood occupancy found in the example positions to set the allocated capacity (e.g. 1.25 allocates 25% extra space).
disable_cell_list (bool) – An optional boolean. If set to True, the neighbor list is constructed from pairwise distances alone, without a cell list. This can be useful for debugging but should generally be left as False.
mask_self (bool) – An optional boolean. Determines whether points can be listed as their own neighbors; if True (the default), self pairs are masked out.
custom_mask_function (Optional[Callable[[Array], Array]]) – An optional function that takes the neighbor array and masks selected elements. Note: the input array to the function has shape (n_particles, m); the row index gives the index of the first particle, and the value stored in each entry gives the index of the second particle.
fractional_coordinates (bool) – An optional boolean. Specifies whether positions will be supplied in fractional coordinates in the unit cube, \([0, 1]^d\). If this is set to True, then box_size will be set to 1.0 and the cell size used in the cell list will be set to cutoff / box_size.
format (NeighborListFormat) – The format of the neighbor list; see the NeighborListFormat enum for details about the different choices of format. Defaults to Dense.
**static_kwargs – Keyword arguments that are threaded through the calculation of the example positions.
- Return type
Callable[[Array, Optional[NeighborList], Optional[int]], NeighborList]
- Returns
A NeighborListFns object that contains a method to allocate a new neighbor list and a method to update an existing neighbor list.
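To make the allocate / update split concrete, here is a minimal plain-Python sketch. It is not the jax_md implementation: toy_allocate, toy_update and ToyNeighborList are hypothetical names, the search is brute force O(N^2) with no periodic boundaries and no dr_threshold, and padding uses the sentinel id N (the convention documented for CellList.id_buffer, assumed here for neighbor lists too). Allocation inspects the positions to choose max_occupancy, which is why it cannot be traced under jit; update keeps that width fixed and raises an overflow flag instead of resizing.

```python
import math
from typing import List, NamedTuple, Sequence, Tuple

Point = Tuple[float, ...]


class ToyNeighborList(NamedTuple):
    """Hypothetical stand-in for a Dense-format neighbor list."""
    idx: List[List[int]]        # [N, max_occupancy], padded with the sentinel N
    max_occupancy: int          # fixed width; changing it means a new shape
    did_buffer_overflow: bool


def _neighbor_rows(positions: Sequence[Point], cutoff: float) -> List[List[int]]:
    """Brute-force neighbor search (no self pairs, no periodic boundaries)."""
    cutoff_sq = cutoff * cutoff
    rows = []
    for i, p in enumerate(positions):
        rows.append([j for j, q in enumerate(positions)
                     if j != i and sum((a - b) ** 2 for a, b in zip(p, q)) < cutoff_sq])
    return rows


def toy_update(positions: Sequence[Point], cutoff: float,
               max_occupancy: int) -> ToyNeighborList:
    """Refill at a fixed width; report overflow instead of resizing."""
    n = len(positions)
    rows = _neighbor_rows(positions, cutoff)
    overflow = any(len(row) > max_occupancy for row in rows)
    # Truncate rows that do not fit and pad short rows with the sentinel n.
    idx = [(row + [n] * max_occupancy)[:max_occupancy] for row in rows]
    return ToyNeighborList(idx, max_occupancy, overflow)


def toy_allocate(positions: Sequence[Point], cutoff: float,
                 capacity_multiplier: float = 1.25) -> ToyNeighborList:
    """Use the actual positions to choose the capacity (data-dependent shape)."""
    rows = _neighbor_rows(positions, cutoff)
    max_occupancy = math.ceil(max(len(row) for row in rows) * capacity_multiplier)
    return toy_update(positions, cutoff, max_occupancy)


spread = [(0.0,), (0.5,), (1.0,), (5.0,), (9.0,)]
nbrs = toy_allocate(spread, cutoff=0.75)
print(nbrs.max_occupancy, nbrs.idx[1])   # 3 [0, 2, 5]

packed = [(0.0,), (0.1,), (0.2,), (0.3,), (0.4,)]
nbrs = toy_update(packed, 0.75, nbrs.max_occupancy)
if nbrs.did_buffer_overflow:
    nbrs = toy_allocate(packed, 0.75)     # reallocate with a larger capacity
print(nbrs.max_occupancy)                 # 5
```

With the spread-out points, allocation sees at most two neighbors and reserves ceil(2 * 1.25) = 3 slots. Packing the same particles closer gives each one four neighbors, so the fixed-width update overflows and a fresh allocation is needed, mirroring the did_buffer_overflow branch of the simulation loop above.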
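- jax_md.partition.neighbor_list_mask(neighbor, mask_self=False)
Computes a boolean mask marking which entries of a neighbor list are real neighbors rather than padding. If mask_self is True, entries that pair a particle with itself are also masked out.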
- Return type
Array
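The masking idea can be sketched in plain Python, assuming that padding slots hold the out-of-range id N (the convention CellList.id_buffer documents for empty slots). dense_mask and sparse_mask are hypothetical helpers, not the library function; they follow the Dense and Sparse layouts described under NeighborListFormat.

```python
from typing import List, Sequence, Tuple


def dense_mask(idx: Sequence[Sequence[int]], n: int,
               mask_self: bool = False) -> List[List[bool]]:
    """Dense layout: idx[i][k] is the k-th neighbor of particle i, or n if padding."""
    return [[j < n and not (mask_self and j == i) for j in row]
            for i, row in enumerate(idx)]


def sparse_mask(idx: Tuple[Sequence[int], Sequence[int]], n: int,
                mask_self: bool = False) -> List[bool]:
    """Sparse layout: column k is the pair (idx[0][k], idx[1][k]); padding uses n."""
    first, second = idx
    return [i < n and not (mask_self and i == j) for i, j in zip(first, second)]


dense = [[1, 0, 3], [0, 3, 3], [3, 3, 3]]  # N = 3; row 0 lists itself as a neighbor
print(dense_mask(dense, 3, mask_self=True))
# [[True, False, False], [True, False, False], [False, False, False]]
```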
- jax_md.partition.to_jraph(neighbor, mask=None, nodes=None, edges=None, globals=None)
Converts a sparse neighbor list to a jraph.GraphsTuple. As in jraph, padding here is accomplished by adding a fictitious graph with a single node.
- Parameters
neighbor (NeighborList) – A neighbor list that we will convert to the jraph format. Must be sparse.
mask (Optional[Array]) – An optional mask on the edges.
- Return type
GraphsTuple
- Returns
A jraph.GraphsTuple that contains the topology of the neighbor list.
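The padding convention can be illustrated without jraph itself. to_graph_fields below is a hypothetical helper that returns plain lists named after GraphsTuple fields; it assumes that the first row of the sparse ids holds receivers and the second holds senders, and that padding pairs carry the id N. Valid edges come first, and every padding edge is attached to an extra node N, which forms the single-node fictitious graph, so n_node = [N, 1].

```python
from typing import Dict, List, Optional, Sequence, Tuple


def to_graph_fields(idx: Tuple[Sequence[int], Sequence[int]], n: int,
                    mask: Optional[Sequence[bool]] = None) -> Dict[str, List[int]]:
    """Hypothetical sketch: sparse (2, M) ids -> GraphsTuple-style fields."""
    receivers_in, senders_in = idx          # assumed row order
    valid = [r < n for r in receivers_in]   # padding pairs carry the id n
    if mask is not None:
        valid = [v and m for v, m in zip(valid, mask)]
    real = [(s, r) for s, r, v in zip(senders_in, receivers_in, valid) if v]
    num_pad = len(valid) - len(real)
    # Real edges first, then padding edges attached to the extra node n.
    return {
        "senders": [s for s, _ in real] + [n] * num_pad,
        "receivers": [r for _, r in real] + [n] * num_pad,
        "n_node": [n, 1],                   # real graph + fictitious padding graph
        "n_edge": [len(real), num_pad],
    }


fields = to_graph_fields(([0, 1, 3, 2], [1, 0, 3, 0]), n=3)
print(fields["senders"], fields["receivers"])  # [1, 0, 0, 3] [0, 1, 2, 3]
```

Keeping the edge count fixed (real plus padding) is what lets a padded graph keep a static shape under jit.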
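- jax_md.partition.to_dense(neighbor)
Converts a sparse neighbor list to dense ids. This function cannot be JIT compiled, since the width of the dense array depends on the largest per-particle neighbor count in the data.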
- Return type
Array
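Why the conversion cannot be compiled is easiest to see in a sketch: the output width is the largest per-particle neighbor count, which is only known after the ids are inspected. sparse_to_dense is a hypothetical plain-Python helper, not the library code; it assumes padding pairs use the id N and fills short rows with N.

```python
from typing import List, Sequence, Tuple


def sparse_to_dense(idx: Tuple[Sequence[int], Sequence[int]], n: int) -> List[List[int]]:
    """Group (i, j) pairs by i; pad every row to the largest count with n."""
    rows: List[List[int]] = [[] for _ in range(n)]
    for i, j in zip(*idx):
        if i < n:                      # skip padding pairs
            rows[i].append(j)
    width = max((len(row) for row in rows), default=0)  # data-dependent shape
    return [row + [n] * (width - len(row)) for row in rows]


print(sparse_to_dense(([0, 0, 1, 2, 3], [1, 2, 0, 0, 3]), n=3))
# [[1, 2], [0, 3], [0, 3]]
```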
- class jax_md.partition.NeighborList(idx, reference_position, error, cell_list_capacity, max_occupancy, format, cell_size, cell_list_fn, update_fn)
A struct containing the state of a neighbor list.
- idx
For an N particle system in the Dense format, this is an [N, max_occupancy] array of integers such that idx[i, j] is the j-th neighbor of particle i.
- Type
jax.Array
- reference_position
The positions of particles when the neighbor list was constructed. This is used to decide whether the neighbor list ought to be updated.
- Type
jax.Array
- error
An error code that is used to identify errors that occurred during neighbor list construction. See PartitionError and PartitionErrorCode for details.
- Type
jax_md.partition.PartitionError
- cell_list_capacity
An optional integer specifying the capacity of the cell list used as an intermediate step in the creation of the neighbor list.
- Type
Optional[int]
- max_occupancy
A static integer specifying the maximum size of the neighbor list. Changing this will trigger recompilation.
- Type
int
- format
A NeighborListFormat enum specifying the format of the neighbor list.
- cell_size
A float specifying the current minimum size of the cells used in cell list construction.
- Type
Optional[float]
- cell_list_fn
The function used to construct the cell list.
- Type
Callable[[jax.Array, jax_md.partition.CellList], jax_md.partition.CellList]
- update_fn
A static Python function used to update the neighbor list.
- Type
Callable[[jax.Array, jax_md.partition.NeighborList], jax_md.partition.NeighborList]
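reference_position supports the usual Verlet-skin test. The sketch below assumes that neighbors are gathered out to r_cutoff + dr_threshold and that the list is rebuilt only once some particle has moved more than dr_threshold / 2 from its reference position: two particles that each moved less than that cannot have closed a gap larger than dr_threshold, so no pair within r_cutoff can have been missed. needs_rebuild is a hypothetical helper, and the exact criterion inside jax_md may differ.

```python
from typing import Sequence, Tuple

Point = Tuple[float, ...]


def needs_rebuild(position: Sequence[Point], reference_position: Sequence[Point],
                  dr_threshold: float) -> bool:
    """True if any particle moved more than dr_threshold / 2 since the last build.

    Uses plain Euclidean distance (no periodic wrapping) and compares squared
    distances to avoid square roots.
    """
    threshold_sq = (dr_threshold / 2) ** 2
    return any(sum((a - b) ** 2 for a, b in zip(p, q)) > threshold_sq
               for p, q in zip(position, reference_position))


reference = [(0.0, 0.0), (1.0, 1.0)]
print(needs_rebuild([(0.2, 0.0), (1.0, 1.0)], reference, dr_threshold=0.5))  # False
print(needs_rebuild([(0.3, 0.0), (1.0, 1.0)], reference, dr_threshold=0.5))  # True
```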
- class jax_md.partition.NeighborListFormat(value)
An enum listing the different neighbor list formats.
- Dense
A dense neighbor list where the ids are a matrix of shape (N, max_neighbors_per_atom). Every row is padded to the same width, so the capacity of the neighbor list is set by the particle with the most neighbors.
- Sparse
A sparse neighbor list where the ids are a matrix of shape (2, max_neighbors) whose columns specify the start / end particle of each neighbor pair.
- OrderedSparse
A sparse neighbor list with the same format as Sparse, except that only bonds with i < j are included.
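The relationship between the three formats can be sketched by converting a Dense id matrix into (2, M) pair lists. dense_to_sparse is a hypothetical helper; it assumes padding entries hold N and, unlike the library layout, does not pad its result to a fixed capacity. Keeping only i < j stores each symmetric pair once, which is what OrderedSparse saves over Sparse.

```python
from typing import List, Sequence, Tuple


def dense_to_sparse(idx: Sequence[Sequence[int]], n: int,
                    ordered: bool = False) -> Tuple[List[int], List[int]]:
    """Dense (N, m) ids -> (first, second) pair lists; ordered keeps only i < j."""
    first: List[int] = []
    second: List[int] = []
    for i, row in enumerate(idx):
        for j in row:
            # Drop padding, and with ordered=True also drop pairs with j <= i.
            if j < n and (not ordered or i < j):
                first.append(i)
                second.append(j)
    return first, second


dense = [[1, 2], [0, 3], [0, 3]]                # N = 3, width 2; 3 marks padding
print(dense_to_sparse(dense, 3))                # ([0, 0, 1, 2], [1, 2, 0, 0])
print(dense_to_sparse(dense, 3, ordered=True))  # ([0, 0], [1, 2])
```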
- class jax_md.partition.NeighborListFns(allocate, update)
A struct containing functions to allocate and update neighbor lists.
- allocate
A function to allocate a new neighbor list. This function cannot be compiled, since it uses the values of positions to infer the shapes.
- Type
Callable[[…], jax_md.partition.NeighborList]
- update
A function to update a neighbor list given a new set of positions and a previously allocated neighbor list.
- Type
Callable[[jax.Array, jax_md.partition.NeighborList], jax_md.partition.NeighborList]
Cell Lists
- jax_md.partition.cell_list(box_size, minimum_cell_size, buffer_size_multiplier=1.25)
Returns a function that partitions point data spatially.
Given a set of points \(\{x_i \in \mathbb{R}^d\}\) with associated data \(\{k_i \in \mathbb{R}^m\}\), it is often useful to partition the points / data spatially. A simple partitioning that can be implemented efficiently within XLA is a dense partition into a uniform grid called a cell list.
Since XLA requires that shapes be statically specified inside of a JIT block, the cell list code can operate in two modes: allocation and update.
Allocation creates a new cell list that uses a set of input positions to estimate the capacity of the cell list. This capacity can be adjusted by setting the buffer_size_multiplier or the extra_capacity. Allocation cannot be JIT compiled.
Updating takes a previously allocated cell list and places a new set of particles in the cells. Updating cannot resize the cell list and is therefore compatible with JIT. However, if the configuration has changed substantially, it is possible that the existing cell list won’t be large enough to accommodate all of the particles. In this case the did_buffer_overflow bit will be set to True.
- Parameters
box_size (Array) – A float or an ndarray of shape [spatial_dimension] specifying the size of the system. Note: this code is written for the case where the boundaries are periodic. If this is not the case, then the current code will be slightly less efficient.
minimum_cell_size (float) – A float specifying the minimum side length of each cell. Cells are enlarged so that they exactly fill the box.
buffer_size_multiplier (float) – A floating point multiplier that multiplies the estimated cell capacity to allow for fluctuations in the maximum cell occupancy.
- Return type
CellListFns
- Returns
A CellListFns object that contains two methods, one to allocate the cell list and one to update the cell list. The update function can be called with either a cell list from which the capacity can be inferred or with an explicit integer denoting the capacity. Note that an existing cell list can also be updated by calling cell_list.update(position).
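The allocation step can be sketched in plain Python. cell_grid, cell_of and estimate_capacity are hypothetical helpers, not the library's code. Cells per side are taken as floor(L / minimum_cell_size), so each cell is at least minimum_cell_size wide and the cells tile the box exactly; particles are binned with periodic wrapping; and the capacity is assumed to be the fullest cell's occupancy times buffer_size_multiplier. Together these give the common buffer shape S described for CellList below.

```python
import math
from collections import Counter
from typing import Sequence, Tuple

Point = Tuple[float, ...]


def cell_grid(box_size: Sequence[float], minimum_cell_size: float):
    """Cells per side, and actual cell sizes (>= minimum) that tile the box exactly."""
    counts = [math.floor(side / minimum_cell_size) for side in box_size]
    sizes = [side / count for side, count in zip(box_size, counts)]
    return counts, sizes


def cell_of(x: Point, box_size: Sequence[float], counts: Sequence[int]) -> Tuple[int, ...]:
    """Integer cell coordinates of a point, wrapped periodically."""
    return tuple(math.floor(xi / side * c) % c for xi, side, c in zip(x, box_size, counts))


def estimate_capacity(points: Sequence[Point], box_size: Sequence[float],
                      counts: Sequence[int], buffer_size_multiplier: float = 1.25) -> int:
    """Fullest cell in the example positions, scaled up for later fluctuations."""
    occupancy = Counter(cell_of(x, box_size, counts) for x in points)
    return math.ceil(max(occupancy.values()) * buffer_size_multiplier)


box = (10.0, 10.0)
counts, sizes = cell_grid(box, minimum_cell_size=3.0)  # 3 x 3 cells of side 10/3
points = [(0.5, 0.5), (1.0, 1.0), (2.0, 2.5), (9.0, 9.0)]
capacity = estimate_capacity(points, box, counts)      # ceil(3 * 1.25) = 4
print(counts + [capacity])                             # S = [3, 3, 4]
```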
- class jax_md.partition.CellList(position_buffer, id_buffer, named_buffer, did_buffer_overflow, cell_capacity, cell_size, update_fn)
Stores the spatial partition of a system into a cell list.
See cell_list() for details on the construction / specification. Cell list buffers all have a common shape, S, where S = [cell_count_x, cell_count_y, cell_capacity] in two dimensions and S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity] in three dimensions. It is assumed that each cell has the same capacity.
- position_buffer
An ndarray of floating point positions with shape S + [spatial_dimension].
- Type
jax.Array
- id_buffer
An ndarray of int32 particle ids of shape S. Note that empty slots are specified by id = N, where N is the number of particles in the system.
- Type
jax.Array
- named_buffer
A dictionary of ndarrays of shape S + [...]. This contains side data placed into the cell list.
- Type
Dict[str, jax.Array]
- did_buffer_overflow
A boolean specifying whether or not the cell list exceeded the maximum allocated capacity.
- Type
jax.Array
- update_fn
A function that updates the cell list at a fixed capacity.
- Type
Callable[[…], jax_md.partition.CellList]
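Updating at a fixed capacity can be sketched the same way. fill_id_buffer is a hypothetical helper that builds only the id_buffer, as a dictionary keyed by cell coordinates rather than an array of shape S. It writes N into empty slots, as documented for id_buffer, and sets an overflow flag instead of growing a cell when a particle does not fit. Positions are binned with a floor-and-wrap rule on a uniform periodic grid.

```python
import itertools
import math
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, ...]
Cell = Tuple[int, ...]


def fill_id_buffer(points: Sequence[Point], box_size: Sequence[float],
                   counts: Sequence[int], capacity: int) -> Tuple[Dict[Cell, List[int]], bool]:
    """Place particle ids into fixed-capacity cells; empty slots hold N."""
    n = len(points)
    buffer = {cell: [n] * capacity
              for cell in itertools.product(*(range(c) for c in counts))}
    filled = {cell: 0 for cell in buffer}
    overflow = False
    for pid, x in enumerate(points):
        cell = tuple(math.floor(xi / side * c) % c
                     for xi, side, c in zip(x, box_size, counts))
        if filled[cell] < capacity:
            buffer[cell][filled[cell]] = pid
        else:
            overflow = True  # this cell is full; the cell list needs reallocation
        filled[cell] += 1
    return buffer, overflow


box, counts = (10.0, 10.0), [3, 3]
buffer, overflow = fill_id_buffer([(0.5, 0.5), (1.0, 1.0), (9.0, 9.0)], box, counts, capacity=2)
print(buffer[(0, 0)], buffer[(1, 1)], overflow)  # [0, 1] [3, 3] False
```

The dictionary keeps the sketch short; an array of shape S, as CellList uses, is what gives the update a static shape under jit.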