Spatial Partitioning
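Code for spatial partitioning (neighbor lists and cell lists), which lets functions defined on individual tuples of particles be applied efficiently to whole sets of particles.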
Neighbor Lists
- jax_md.partition.neighbor_list(displacement_or_metric, box, r_cutoff, dr_threshold=0.0, capacity_multiplier=1.25, disable_cell_list=False, mask_self=True, custom_mask_function=None, fractional_coordinates=False, format=NeighborListFormat.Dense, **static_kwargs)
Returns a function that builds a list of neighbors for collections of points.
Neighbor lists must balance the need to be jit compatible with the fact that, under a jit, the maximum number of neighbors cannot change (owing to static shape requirements). To deal with this, neighbor_list returns a NeighborListFns object that contains two functions: 1) neighbor_fn.allocate creates a new neighbor list and 2) neighbor_fn.update updates an existing neighbor list. Neighbor lists themselves additionally have a convenience update member function.
Note that allocation of a new neighbor list cannot be jit compiled, since it uses the positions to infer the maximum number of neighbors (along with additional space specified by the capacity_multiplier). Updating the neighbor list can be jit compiled; if the neighbor list capacity is not sufficient to store all the neighbors, the did_buffer_overflow bit will be set to True and a new neighbor list will need to be reallocated.
Here is a typical example of a simulation loop with neighbor lists:
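    from jax import lax, random
    from jax_md import energy, simulate, space

    # Illustrative setup (assumed values): 64 soft spheres in a periodic 2D box.
    N, box_size = 64, 10.0
    R = random.uniform(random.PRNGKey(1), (N, 2), maxval=box_size)
    displacement, shift = space.periodic(box_size)
    neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(displacement, box_size)

    init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)

    nbrs = neighbor_fn.allocate(R)
    # kT (initial temperature) is an assumed value; the energy takes neighbor=nbrs.
    state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)

    def body_fn(i, state):
      state, nbrs = state
      nbrs = nbrs.update(state.position)
      state = apply_fn(state, neighbor=nbrs)
      return state, nbrs

    step = 0
    for _ in range(20):
      new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
      if nbrs.did_buffer_overflow:
        # The list overflowed: reallocate from the last good state and retry.
        nbrs = neighbor_fn.allocate(state.position)
      else:
        state = new_state
        step += 1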
- Parameters:
displacement_or_metric – A function d(R_a, R_b) that computes the displacement (or distance) between pairs of points.
box (Union[Array, ndarray, bool, number, int, float, complex]) – Either a float specifying the size of the box, an array of shape [spatial_dim] specifying the box size in each spatial dimension of a rectangular box, or a matrix of shape [spatial_dim, spatial_dim] that is upper triangular and specifies the lattice vectors of the box.
r_cutoff (Union[Array, ndarray, bool, number, int, float, complex]) – A scalar specifying the neighborhood radius.
dr_threshold (Union[Array, ndarray, bool, number, int, float, complex]) – A scalar specifying the maximum distance particles can move before rebuilding the neighbor list.
capacity_multiplier (float) – A floating-point multiplier applied to the maximum neighborhood occupancy found in the example positions to set the allocated capacity; the default of 1.25 leaves 25% headroom.
disable_cell_list (bool) – An optional boolean. If set to True, the neighbor list is constructed using only distances. This can be useful for debugging but should generally be left as False.
mask_self (bool) – An optional boolean. If True, each point is excluded from its own list of neighbors.
custom_mask_function (Optional[Callable[[Array], Array]]) – An optional function that takes the neighbor array and masks selected elements. Note: the input array to the function has shape (n_particles, m); the index of particle 1 is the index along the first dimension of the array, and the index of particle 2 is the value stored in the array.
fractional_coordinates (bool) – An optional boolean. Specifies whether positions will be supplied in fractional coordinates in the unit cube, \([0, 1]^d\). If this is set to True, then box_size will be set to 1.0 and the cell size used in the cell list will be set to cutoff / box_size.
format (NeighborListFormat) – The format of the neighbor list; see the NeighborListFormat enum for details about the different choices of format. Defaults to Dense.
**static_kwargs – kwargs that get threaded through the calculation of example positions.
- Return type:
NeighborListFns
- Returns:
A NeighborListFns object that contains a method to allocate a new neighbor list and a method to update an existing neighbor list.
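To make the allocate/update split concrete, here is a brute-force NumPy sketch of a Dense neighbor list in a square periodic box. It is not jax_md's implementation: it assumes padding slots hold the index N, ignores dr_threshold and cell lists, and uses hypothetical helper names (neighbor_mask, fill_dense, allocate, update).

```python
import math

import numpy as np


def neighbor_mask(position, box_size, r_cutoff):
    """Boolean (N, N) matrix of pairs closer than r_cutoff, self-pairs excluded."""
    n = len(position)
    dR = position[:, None, :] - position[None, :, :]
    dR -= box_size * np.round(dR / box_size)  # minimum-image convention
    dist = np.sqrt(np.sum(dR ** 2, axis=-1))
    return (dist < r_cutoff) & ~np.eye(n, dtype=bool)


def fill_dense(is_nbr, capacity):
    """Pack neighbors into an (N, capacity) array padded with N; flag overflow."""
    n = len(is_nbr)
    idx = np.full((n, capacity), n, dtype=np.int32)
    overflow = False
    for i in range(n):
        nbrs = np.flatnonzero(is_nbr[i])
        overflow |= len(nbrs) > capacity
        nbrs = nbrs[:capacity]  # neighbors beyond the capacity are dropped
        idx[i, :len(nbrs)] = nbrs
    return idx, overflow


def allocate(position, box_size, r_cutoff, capacity_multiplier=1.25):
    """Reads the data to choose a capacity, so the output shape is data dependent."""
    is_nbr = neighbor_mask(position, box_size, r_cutoff)
    max_occupancy = int(is_nbr.sum(axis=1).max())
    capacity = math.ceil(max_occupancy * capacity_multiplier)
    return fill_dense(is_nbr, capacity)


def update(position, idx, box_size, r_cutoff):
    """Keeps the allocated shape fixed and only reports whether it overflowed."""
    return fill_dense(neighbor_mask(position, box_size, r_cutoff), idx.shape[1])


# Sparse configuration: at most one neighbor each, so capacity = ceil(1 * 1.25) = 2.
R0 = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0], [2.5, 8.0]])
idx, overflow = allocate(R0, box_size=10.0, r_cutoff=1.2)
# Clustered configuration: every particle now has 4 neighbors, exceeding capacity.
R1 = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [0.5, 0.5], [0.25, 0.25]])
idx1, overflow1 = update(R1, idx, box_size=10.0, r_cutoff=1.2)
```

The update keeps idx.shape fixed and only sets the overflow flag, which is the property that lets the real update run under jit; recovering from overflow means calling allocate again outside of jit.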
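- jax_md.partition.neighbor_list_mask(neighbor, mask_self=False)
Computes a boolean mask for a neighbor list that is True for real entries and False for padding; if mask_self is True, self-pairs are masked out as well.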
- Return type:
Array
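A NumPy sketch of such a mask for both layouts, assuming (as for the cell list id_buffer) that padded slots hold the index N and that the first row of a sparse list identifies the central particle. The function name padding_mask is hypothetical.

```python
import numpy as np


def padding_mask(idx, n_particles, sparse, mask_self=False):
    """True for real neighbor entries, False for padding slots (index n_particles)."""
    if sparse:
        i, j = idx  # shape (2, capacity): one column per pair
        mask = i < n_particles
        if mask_self:
            mask &= i != j  # also drop (i, i) pairs
        return mask
    mask = idx < n_particles  # shape (N, capacity): row i lists i's neighbors
    if mask_self:
        mask &= idx != np.arange(n_particles)[:, None]
    return mask


dense = np.array([[1, 0, 3], [0, 3, 3], [3, 3, 3]])  # N = 3, padding value 3
sparse = np.array([[0, 1, 0, 3], [1, 0, 0, 3]])
```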
- jax_md.partition.to_jraph(neighbor, mask=None, nodes=None, edges=None, globals=None)
Convert a sparse neighbor list to a jraph.GraphsTuple. As in jraph, padding here is accomplished by adding a fictitious graph with a single node.
- Parameters:
neighbor (NeighborList) – A neighbor list that we will convert to the jraph format. Must be sparse.
mask (Array | None) – An optional mask on the edges.
- Return type:
GraphsTuple
- Returns:
A jraph.GraphsTuple that contains the topology of the neighbor list.
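The padding idea can be sketched in NumPy without jraph itself. jraph assigns edges to graphs in order according to n_edge, so padded edges must come last; they can then belong to a fictitious one-node graph whose node index is N. The helper pad_like_jraph, the plain dict standing in for a GraphsTuple, and the choice of row 0 as receivers are all assumptions for illustration, not the library's code.

```python
import numpy as np


def pad_like_jraph(idx, n_particles):
    """Reorder a (2, capacity) sparse list so real edges come first, jraph-style.

    Padded entries already point at index n_particles, which here plays the
    role of the single node of a fictitious second graph.
    """
    receivers, senders = idx
    valid = receivers < n_particles
    order = np.argsort(~valid, kind="stable")  # real edges first, order preserved
    return {
        "receivers": receivers[order],
        "senders": senders[order],
        "n_node": np.array([n_particles, 1]),  # real graph + fictitious graph
        "n_edge": np.array([valid.sum(), (~valid).sum()]),
    }


graph = pad_like_jraph(np.array([[0, 3, 1, 3], [1, 3, 0, 3]]), n_particles=3)
```

In this sketch any node-feature array paired with the graph would need n_particles + 1 rows so that the fictitious node exists.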
- jax_md.partition.to_dense(neighbor)
Converts a sparse neighbor list to dense ids. This cannot be JIT compiled.
- Return type:
Array
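The reason this conversion cannot be JIT compiled is that the width of the dense array is the largest neighbor count, which depends on the data. A NumPy sketch, assuming padded sparse entries hold index N and row 0 holds the central particle; sparse_to_dense is a hypothetical name.

```python
import numpy as np


def sparse_to_dense(idx, n_particles):
    """Convert (2, capacity) pairs into (N, max_neighbors) ids padded with N.

    The output width is the largest neighbor count, a data-dependent shape.
    """
    i, j = idx
    valid = i < n_particles
    i, j = i[valid], j[valid]
    width = int(np.bincount(i, minlength=n_particles).max())
    dense = np.full((n_particles, width), n_particles, dtype=np.int32)
    fill = np.zeros(n_particles, dtype=np.int64)  # next free column per row
    for a, b in zip(i, j):
        dense[a, fill[a]] = b
        fill[a] += 1
    return dense


dense = sparse_to_dense(np.array([[0, 0, 1, 2, 3], [1, 2, 0, 0, 3]]), n_particles=3)
```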
- class jax_md.partition.NeighborList(idx, reference_position, error, cell_list_capacity, max_occupancy, format, cell_size, cell_list_fn, update_fn)
A struct containing the state of a neighbor list.
- idx
For an N particle system in the Dense format, this is an [N, max_occupancy] array of integers such that idx[i, j] is the j-th neighbor of particle i.
- reference_position
The positions of particles when the neighbor list was constructed. This is used to decide whether the neighbor list ought to be updated.
- error
An error code that is used to identify errors that occurred during neighbor list construction. See PartitionError and PartitionErrorCode for details.
- cell_list_capacity
An optional integer specifying the capacity of the cell list used as an intermediate step in the creation of the neighbor list.
- max_occupancy
A static integer specifying the maximum size of the neighbor list. Changing this will invoke a recompilation.
- format
A NeighborListFormat enum specifying the format of the neighbor list.
- cell_size
A float specifying the current minimum size of the cells used in cell list construction.
- cell_list_fn
The function used to construct the cell list.
- update_fn
A static python function used to update the neighbor list.
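One way reference_position can drive the rebuild decision is the common Verlet-skin rule, sketched below in NumPy. We assume here, rather than quote from the source, that the list is built with cutoff r_cutoff + dr_threshold: if every particle has moved less than dr_threshold / 2, any pair's separation has changed by less than dr_threshold, so no pair within r_cutoff can be missing. needs_rebuild is a hypothetical name.

```python
import numpy as np


def needs_rebuild(position, reference_position, box_size, dr_threshold):
    """Assumed Verlet-skin rule: rebuild once any particle has moved more than
    dr_threshold / 2 (minimum-image displacement) since the list was built."""
    dR = position - reference_position
    dR -= box_size * np.round(dR / box_size)  # handle wrapping across the box
    max_sq = np.max(np.sum(dR ** 2, axis=-1))
    return bool(max_sq > (dr_threshold / 2) ** 2)


ref = np.array([[0.0, 0.0], [5.0, 5.0]])
```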
- class jax_md.partition.NeighborListFormat(*values)
An enum listing the different neighbor list formats.
- Dense
A dense neighbor list where the ids are a rectangular matrix of shape (N, max_neighbors_per_atom). Here the capacity of the neighbor list must scale with the number of neighbors of the most highly connected particle.
- Sparse
A sparse neighbor list where the ids are a rectangular matrix of shape (2, max_neighbors) specifying the start / end particle of each neighbor pair.
- OrderedSparse
A sparse neighbor list whose format is the same as Sparse, except that only bonds with i < j are included.
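Counting the integer slots each format needs shows why Dense capacity scales with the most connected particle. This NumPy sketch ignores capacity_multiplier headroom and padding, and format_slot_counts is a hypothetical helper; it uses a star-shaped neighborhood where one hub has many neighbors.

```python
import numpy as np


def format_slot_counts(is_nbr):
    """Integer slots needed by each format for a symmetric boolean adjacency matrix."""
    n = len(is_nbr)
    degree = is_nbr.sum(axis=1)
    n_directed = int(degree.sum())  # each bond appears as (i, j) and (j, i)
    return {
        "Dense": n * int(degree.max()),          # (N, max_neighbors_per_atom)
        "Sparse": 2 * n_directed,                # (2, max_neighbors)
        "OrderedSparse": 2 * (n_directed // 2),  # only pairs with i < j
    }


star = np.zeros((5, 5), dtype=bool)
star[0, 1:] = True  # particle 0 is a hub bonded to particles 1..4
star[1:, 0] = True
counts = format_slot_counts(star)
```

Here Dense pays for the hub's 4 neighbors in every row (20 slots), while Sparse needs 16 and OrderedSparse 8.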
- class jax_md.partition.NeighborListFns(allocate, update)
A struct containing functions to allocate and update neighbor lists.
- allocate
A function to allocate a new neighbor list. This function cannot be compiled, since it uses the values of positions to infer the shapes.
- update
A function to update a neighbor list given a new set of positions and a previously allocated neighbor list.
Cell Lists
- jax_md.partition.cell_list(box_size, minimum_cell_size, buffer_size_multiplier=1.25)
Returns a function that partitions point data spatially.
Given a set of points \(\{x_i \in R^d\}\) with associated data \(\{k_i \in R^m\}\) it is often useful to partition the points / data spatially. A simple partitioning that can be implemented efficiently within XLA is a dense partition into a uniform grid called a cell list.
Since XLA requires that shapes be statically specified inside of a JIT block, the cell list code can operate in two modes: allocation and update.
Allocation creates a new cell list that uses a set of input positions to estimate the capacity of the cell list. This capacity can be adjusted by setting the buffer_size_multiplier or the extra_capacity. Allocation cannot be JIT compiled.
Updating takes a previously allocated cell list and places a new set of particles in the cells. Updating cannot resize the cell list and is therefore compatible with JIT. However, if the configuration has changed substantially, the existing cell list may not be large enough to accommodate all of the particles. In this case the did_buffer_overflow bit will be set to True.
- Parameters:
box_size (Union[Array, ndarray, bool, number, int, float, complex]) – A float or an ndarray of shape [spatial_dimension] specifying the size of the system. Note, this code is written for the case where the boundaries are periodic. If this is not the case, then the current code will be slightly less efficient.
minimum_cell_size (Union[Array, ndarray, bool, number, int, float, complex]) – A float specifying the minimum side length of each cell. Cells are enlarged so that they exactly fill the box.
buffer_size_multiplier (float) – A floating-point multiplier applied to the estimated cell capacity to allow for fluctuations in the maximum cell occupancy.
- Return type:
CellListFns
- Returns:
A CellListFns object that contains two methods, one to allocate the cell list and one to update the cell list. The update function can be called with either a cell list from which the capacity can be inferred or with an explicit integer denoting the capacity. Note that an existing cell list can also be updated by calling cell_list.update(position).
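The sizing rules above (cells at least minimum_cell_size wide, enlarged to tile the box exactly, and a capacity scaled by buffer_size_multiplier) can be sketched in NumPy for a square or cubic periodic box. This is a hypothetical helper, cell_list_layout, not jax_md's code; it ignores extra_capacity and assumes the capacity is rounded up.

```python
import math

import numpy as np


def cell_list_layout(position, box_size, minimum_cell_size, buffer_size_multiplier=1.25):
    """Cells per side, enlarged cell size, per-cell capacity and per-particle cell index."""
    dim = position.shape[1]
    cells_per_side = int(box_size // minimum_cell_size)  # cells at least the minimum size
    cell_size = box_size / cells_per_side                # enlarged to tile the box exactly
    cell_idx = np.floor(position / cell_size).astype(int) % cells_per_side
    flat = np.ravel_multi_index(tuple(cell_idx.T), (cells_per_side,) * dim)
    occupancy = np.bincount(flat, minlength=cells_per_side ** dim)
    capacity = math.ceil(occupancy.max() * buffer_size_multiplier)
    return cells_per_side, cell_size, capacity, cell_idx


R = np.array([[0.5, 0.5], [1.0, 1.0], [2.0, 2.0], [9.0, 9.0], [5.0, 5.0]])
cells_per_side, cell_size, capacity, cell_idx = cell_list_layout(R, 10.0, 3.0)
```

With a box of 10 and a minimum cell size of 3, the box holds 3 cells per side of size 10/3; the busiest cell has 3 particles, so the capacity becomes ceil(3 * 1.25) = 4.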
- class jax_md.partition.CellList(position_buffer, id_buffer, particle_cell_id, named_buffer, did_buffer_overflow, cell_capacity, cell_size, update_fn)
Stores the spatial partition of a system into a cell list.
See cell_list() for details on the construction / specification. Cell list buffers all have a common shape, S, where:
S = [cell_count_x, cell_count_y, cell_capacity]
S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity]
in two and three dimensions respectively. It is assumed that each cell has the same capacity.
- position_buffer
An ndarray of floating point positions with shape S + [spatial_dimension].
- id_buffer
An ndarray of int32 particle ids of shape S. Note that empty slots are specified by id = N, where N is the number of particles in the system.
- particle_cell_id
An ndarray of int32 with shape [N, spatial_dimension] storing each particle’s dimension-wise cell index.
- named_buffer
A dictionary of ndarrays of shape S + [...]. This contains side data placed into the cell list.
- did_buffer_overflow
A boolean specifying whether or not the cell list exceeded the maximum allocated capacity.
- cell_capacity
An integer specifying the maximum capacity of each cell in the cell list.
- update_fn
A function that updates the cell list at a fixed capacity.
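The buffer layout described above can be illustrated by filling fixed-capacity buffers of shape S by hand. This NumPy sketch uses a hypothetical helper, fill_cell_buffers; it marks empty id slots with N as stated above, leaves empty position slots at zero (an assumption), and sets an overflow flag instead of resizing. A Python loop is used for readability; a jit-compatible version would need a vectorized scatter.

```python
import numpy as np


def fill_cell_buffers(position, cell_idx, cells_per_side, capacity):
    """Place particles into buffers of shape S = [cells_per_side] * dim + [capacity]."""
    n, dim = position.shape
    S = (cells_per_side,) * dim + (capacity,)
    id_buffer = np.full(S, n, dtype=np.int32)  # empty slots hold id N
    position_buffer = np.zeros(S + (dim,))     # shape S + [spatial_dimension]
    fill = np.zeros((cells_per_side,) * dim, dtype=np.int64)
    did_buffer_overflow = False
    for pid in range(n):
        c = tuple(cell_idx[pid])
        slot = fill[c]
        if slot >= capacity:
            did_buffer_overflow = True  # particle dropped; the list must be reallocated
            continue
        id_buffer[c + (slot,)] = pid
        position_buffer[c + (slot,)] = position[pid]
        fill[c] += 1
    return id_buffer, position_buffer, did_buffer_overflow


R = np.array([[0.5, 0.5], [1.0, 1.0], [5.0, 5.0]])
cells = np.array([[0, 0], [0, 0], [1, 1]])
ids, pos_buf, overflow = fill_cell_buffers(R, cells, cells_per_side=3, capacity=2)
ids_small, _, overflow_small = fill_cell_buffers(R, cells, cells_per_side=3, capacity=1)
```

With capacity 1, the second particle in cell (0, 0) no longer fits, so the overflow flag is set rather than the buffer growing, mirroring how an update at fixed capacity behaves.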