Multi-Image Neighbor Lists
This module provides neighbor list construction for small periodic boxes where the cutoff radius exceeds half the box size (\(r_\text{cut} > L/2\)). In such cases, the standard minimum image convention (MIC) fails, and particles may interact with multiple periodic images of their neighbors.
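To see why MIC is not enough in this regime, the following is a minimal 1D sketch in plain Python (not part of jax_md; the helper names are hypothetical). It compares the single minimum image displacement with the full set of image displacements inside the cutoff for one pair of atoms in a box with \(r_\text{cut} > L/2\).

```python
import math


def mic_displacement(xi, xj, length):
    """Minimum image displacement from atom i to atom j in a 1D periodic box."""
    dx = xj - xi
    return dx - length * round(dx / length)


def all_image_displacements(xi, xj, length, r_cutoff):
    """Every displacement from atom i to a periodic image of atom j shorter than r_cutoff."""
    n = math.ceil(r_cutoff / length)  # number of images to scan on each side
    candidates = (xj - xi + s * length for s in range(-n, n + 1))
    return [dx for dx in candidates if abs(dx) < r_cutoff]


length, r_cutoff = 1.0, 0.9  # cutoff exceeds half the box length
print(mic_displacement(0.1, 0.4, length))                   # one displacement, roughly 0.3
print(all_image_displacements(0.1, 0.4, length, r_cutoff))  # two displacements, roughly -0.7 and 0.3
```

The same pair contributes two interactions within the cutoff, but MIC reports only the nearer one.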
Neighbor List Construction
- jax_md.custom_partition.neighbor_list_multi_image(displacement_or_metric, box, r_cutoff, dr_threshold=0.0, capacity_multiplier=1.25, pbc=None, fractional_coordinates=True, ordered=False, format=NeighborListFormat.Sparse, **kwargs)
Returns functions to build neighbor lists for small periodic boxes.
This function mirrors the API of jax_md.partition.neighbor_list but correctly handles small boxes where \(r_\text{cut} > L/2\) by explicitly enumerating periodic images. Works for any dimension.
Algorithm:
For each lattice direction \(k\), computes the number of shifts needed:
\[n_k = \lceil r_\text{cut} / h_k \rceil\]
where \(h_k\) is the perpendicular height of the box along direction \(k\). Then enumerates all integer shift vectors \(\mathbf{s} \in [-n_1, n_1] \times \ldots \times [-n_d, n_d]\) and finds atom pairs \((i, j)\) with:
\[\|\mathbf{r}_j + \mathbf{s} \cdot \mathbf{T} - \mathbf{r}_i\| < r_\text{cut}\]
Usage:
    from jax_md.custom_partition import neighbor_list_multi_image

    neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff)
    nbrs = neighbor_fn.allocate(R)

    for _ in range(steps):
        # Update in place; the capacity stays fixed between allocations.
        nbrs = nbrs.update(state.position)
        if nbrs.did_buffer_overflow:
            # Reallocate when more edges were found than the buffer can hold.
            nbrs = neighbor_fn.allocate(state.position)
        state = apply_fn(state, nbrs)
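The algorithm described above can be sketched as a brute-force reference in plain Python. This is an illustration only, not the library's implementation: it assumes a 3D box whose columns are the lattice vectors, Cartesian positions, and hypothetical helper names.

```python
import itertools
import math


def cross(u, v):
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])


def norm(u):
    return math.sqrt(sum(x * x for x in u))


def perpendicular_heights(box):
    """Distance between opposite faces of the cell, one value per lattice direction."""
    a = [tuple(box[r][c] for r in range(3)) for c in range(3)]  # columns are lattice vectors
    volume = abs(sum(x * y for x, y in zip(a[0], cross(a[1], a[2]))))
    return [volume / norm(cross(a[(k + 1) % 3], a[(k + 2) % 3])) for k in range(3)]


def shift_counts(box, r_cutoff):
    """Number of image shifts to scan on each side, per lattice direction."""
    return [math.ceil(r_cutoff / h) for h in perpendicular_heights(box)]


def brute_force_edges(positions, box, r_cutoff):
    """All (i, j, shift) with |r_j + box @ shift - r_i| < r_cutoff, excluding i == j at zero shift."""
    ranges = [range(-n, n + 1) for n in shift_counts(box, r_cutoff)]
    edges = []
    for s in itertools.product(*ranges):
        t = [sum(box[r][c] * s[c] for c in range(3)) for r in range(3)]  # real-space shift
        for i, ri in enumerate(positions):
            for j, rj in enumerate(positions):
                if i == j and not any(s):
                    continue  # an atom is not its own neighbor at zero shift
                if norm([rj[k] + t[k] - ri[k] for k in range(3)]) < r_cutoff:
                    edges.append((i, j, s))
    return edges


unit_cube = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
edges = brute_force_edges([(0.0, 0.0, 0.0)], unit_cube, r_cutoff=1.2)
print(len(edges))  # a single atom sees its six nearest periodic images
```

A single atom in a unit cube with a cutoff of 1.2 already has six neighbors, all of them images of itself, which a minimum image search cannot represent.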
- Parameters:
  - displacement_or_metric – Ignored. Accepted for API compatibility with partition.neighbor_list. Multi-image computes displacements using explicit lattice shifts.
  - box (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Affine transformation (see jax_md.space.periodic_general). Shape [dim, dim]. Columns are lattice vectors.
  - r_cutoff (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Interaction cutoff distance (scalar).
  - dr_threshold (float) – Maximum distance atoms can move before rebuilding. Set to 0 to always rebuild. Uses \(d_\text{max} < d_\text{thresh}/2\) as the skip condition.
  - capacity_multiplier (float) – Safety factor for neighbor list capacity.
  - pbc (Array | ndarray | None) – Boolean array indicating periodic directions. Shape [dim]. Default: all True.
  - fractional_coordinates (bool) – If True, positions are in fractional coordinates.
  - ordered (bool) – If True, use OrderedSparse format (one direction per pair). Uses 2x less memory. Ignored for Dense format.
  - format (NeighborListFormat) – Neighbor list format:
    - Sparse: Edge list (receivers, senders). Shape [capacity].
    - OrderedSparse: Like Sparse but only \(i < j\) pairs.
    - Dense: Per-atom neighbors. Shape [N, max_neighbors].
  - **kwargs – Additional arguments (ignored, for API compatibility).
- Returns:
  - allocate(position, box=None): Create a new neighbor list from positions [N, dim]. Pass box=new_box to rebuild for a different box geometry (e.g. during cell optimization).
  - update(position, neighbors): Update an existing neighbor list.
- Return type:
  NeighborListMultiImageFns
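The skip condition attached to dr_threshold can be sketched as follows. This is a plain Python illustration with a hypothetical helper name; it assumes Cartesian positions that are not wrapped across the boundary between the reference build and the current step.

```python
import math


def needs_rebuild(positions, reference_positions, dr_threshold):
    """Rebuild unless the largest displacement since the last build is below dr_threshold / 2."""
    d_max = max(math.dist(p, q) for p, q in zip(positions, reference_positions))
    return not (d_max < dr_threshold / 2)


reference = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
moved = [(0.1, 0.0, 0.0), (1.0, 0.0, 0.0)]
print(needs_rebuild(moved, reference, dr_threshold=0.5))  # False: 0.1 < 0.25
print(needs_rebuild(moved, reference, dr_threshold=0.0))  # True: a zero threshold always rebuilds
```

The factor of one half reflects the usual skin argument: two atoms that each move by at most half the threshold change their separation by at most the full threshold.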
- jax_md.custom_partition.neighbor_list_multi_image_mask(neighbors)
Compute a boolean mask for valid edges in a neighbor list.
This is equivalent to jax_md.partition.neighbor_list_mask. An edge is valid if its receiver index is less than N (invalid edges are padded with index N).
- Parameters:
  neighbors (NeighborListMultiImage) – A NeighborListMultiImage (Sparse or OrderedSparse format).
- Return type:
  Array
- Returns:
  Boolean mask. Shape [capacity]. True indicates a valid edge.
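The masking rule can be illustrated with plain Python lists. The helper below is hypothetical and only mirrors the rule stated above (receiver index less than N); it is not the library function.

```python
def edge_mask(receivers, n_atoms):
    """True for real edges, False for padding entries that carry index n_atoms."""
    return [r < n_atoms for r in receivers]


n_atoms = 3
receivers = [0, 1, 2, 3, 3]               # the last two slots are padding
edge_energy = [0.5, 0.25, 0.25, 9.0, 9.0]  # padded slots may hold arbitrary values

mask = edge_mask(receivers, n_atoms)
total = sum(e for e, valid in zip(edge_energy, mask) if valid)
print(mask)   # [True, True, True, False, False]
print(total)  # 1.0
```

Applying the mask before any reduction keeps padded slots from contributing to per-edge sums.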
Data Structures
- class jax_md.custom_partition.NeighborListMultiImage(idx, shifts, reference_position, reference_box, box, search_shifts, format, max_occupancy, update_fn, did_buffer_overflow=False)
A struct containing the state of a multi-image neighbor list.
This data structure is compatible with jax_md.partition.to_jraph and jax_md.partition.neighbor_list_mask. It stores edges between atoms, including all periodic images within the cutoff (not just the nearest).
Supports two storage formats:
- Sparse/OrderedSparse:
idx: Tuple (receivers, senders), each shape (capacity,)
shifts: Shape (capacity, dim)
- Dense:
idx: Shape (N, max_neighbors), neighbor indices for each atom
shifts: Shape (N, max_neighbors, dim), shift for each neighbor
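The relationship between the two layouts can be sketched by converting a padded edge list into a padded per-atom table. This is a plain Python illustration with a hypothetical helper; it assumes that each Dense row belongs to the receiver atom, which the text above does not state explicitly. The shifts array would be rearranged in the same way.

```python
def sparse_to_dense(receivers, senders, n_atoms, max_neighbors):
    """Convert a padded edge list into an [n_atoms, max_neighbors] table padded with n_atoms."""
    dense = [[n_atoms] * max_neighbors for _ in range(n_atoms)]
    fill = [0] * n_atoms  # number of slots already used in each row
    overflow = False
    for r, s in zip(receivers, senders):
        if r >= n_atoms:
            continue  # skip padding entries
        if fill[r] == max_neighbors:
            overflow = True  # row is full; report it instead of resizing
            continue
        dense[r][fill[r]] = s
        fill[r] += 1
    return dense, overflow


# Atom 0 sees atom 1 twice (two different periodic images); the last edge is padding.
dense, overflow = sparse_to_dense([0, 0, 1, 2], [1, 1, 0, 2], n_atoms=2, max_neighbors=2)
print(dense)     # [[1, 1], [0, 2]]
print(overflow)  # False
```

Because several images of the same atom can fall inside the cutoff, a neighbor index may appear more than once in a row; the accompanying shift distinguishes the entries.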
- idx
For Sparse: tuple (receivers, senders). For Dense: array (N, max_neighbors). Invalid entries are padded with index N (number of atoms).
- shifts
Integer shift vectors. Sparse: (capacity, dim). Dense: (N, max_neighbors, dim). The real-space shift is shifts @ box.T.
- reference_position
Positions when the list was built, shape (N, dim).
- reference_box
Box when the list was built, shape (dim, dim).
- box
An affine transformation; see jax_md.space.periodic_general.
- search_shifts
Integer image stencil used to build this allocation.
- format
NeighborListFormat.Sparse, OrderedSparse, or Dense.
- max_occupancy
For Sparse: total capacity. For Dense: max_neighbors.
- update_fn
Function to update the neighbor list.
- did_buffer_overflow
True if more edges/neighbors were found than capacity allows.
- box: Array
- format: NeighborListFormat
- property receivers: Array
Receiver atom indices (Sparse format only).
- reference_box: Array
- reference_position: Array
- search_shifts: Array
- property senders: Array
Sender atom indices (Sparse format only).
- set(**kwargs)
- shifts: Array
- update_fn: Callable[[...], NeighborListMultiImage]
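The conversion from integer shifts to real-space shifts, shifts @ box.T, can be written out explicitly. The following plain Python sketch uses a hypothetical helper and nested lists in place of arrays; each row of shifts is mapped to an integer combination of the box columns.

```python
def real_space_shifts(shifts, box):
    """Row-wise shifts @ box.T: each shift s maps to sum_k s[k] * (column k of box)."""
    dim = len(box)
    return [[sum(s[k] * box[r][k] for k in range(dim)) for r in range(dim)]
            for s in shifts]


# 2D box with lattice vectors (2, 0) and (1, 3) stored as columns.
box = [[2, 1],
       [0, 3]]
shifts = [(1, 0), (0, 1), (-1, 2)]
print(real_space_shifts(shifts, box))  # [[2, 0], [1, 3], [0, 6]]
```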
- class jax_md.custom_partition.NeighborListMultiImageFns(allocate, update)
A struct containing functions to allocate and update neighbor lists.
This mirrors the jax_md.partition.NeighborListFns interface.
- allocate
A function to allocate a new neighbor list. This function cannot be compiled, since it uses the values of positions to infer the shapes. Signature:
(position: Array[N, dim], **kwargs) -> NeighborListMultiImage
- update
A function to update a neighbor list given a new set of positions and a previously allocated neighbor list. Signature:
(position: Array[N, dim], neighbors: NeighborListMultiImage, **kwargs) -> NeighborListMultiImage
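The split between allocate and update can be illustrated with a toy 1D, non-periodic model in plain Python. All names here are hypothetical and the sizing rule (edge count times capacity_multiplier, rounded up) is an assumption made for the sketch, not a statement about how the library sizes its buffers.

```python
import math


def find_edges(xs, r_cutoff):
    """All ordered pairs (i, j), i != j, closer than r_cutoff (1D, no periodicity)."""
    return [(i, j) for i, xi in enumerate(xs) for j, xj in enumerate(xs)
            if i != j and abs(xj - xi) < r_cutoff]


def update(xs, r_cutoff, capacity):
    """Refill a buffer of fixed capacity; flag overflow instead of resizing."""
    edges = find_edges(xs, r_cutoff)
    pad = (len(xs), len(xs))  # padding entries use index N
    buffer = (edges + [pad] * capacity)[:capacity]
    return {"edges": buffer, "capacity": capacity,
            "did_buffer_overflow": len(edges) > capacity}


def allocate(xs, r_cutoff, capacity_multiplier=1.25):
    """Size the buffer from the data, so its shape depends on the position values."""
    capacity = math.ceil(len(find_edges(xs, r_cutoff)) * capacity_multiplier)
    return update(xs, r_cutoff, capacity)


nbrs = allocate([0.0, 1.0, 5.0], r_cutoff=1.5)
print(nbrs["capacity"], nbrs["did_buffer_overflow"])  # 3 False

nbrs = update([0.0, 1.0, 2.0], r_cutoff=1.5, capacity=nbrs["capacity"])
print(nbrs["did_buffer_overflow"])  # True: four edges no longer fit in three slots
```

Since the buffer shape produced by allocate depends on the data, it has to run outside compiled code, while update keeps the shape fixed and reports overflow so the caller can reallocate.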
Capacity Estimation
- jax_md.custom_partition.estimate_max_neighbors(r_cutoff, atomic_density=0.1, safety_factor=2.0, dim=3)
Estimate maximum neighbors per atom from atomic density.
Quick estimation when you don’t have box/n_atoms information. Uses:
\[N_\text{neighbors} = \text{safety\_factor} \cdot \rho \cdot V_\text{sphere}\]
where \(\rho\) is the atomic density and \(V_\text{sphere}\) is the volume of a sphere with radius r_cutoff.
- Parameters:
  - r_cutoff (float) – Interaction cutoff distance (scalar).
  - atomic_density (float) – Atomic density in atoms per unit volume. Default: 0.1. Common values: ~0.03 for gases, ~0.1 for liquids, ~0.15 for solids.
  - safety_factor (float) – Multiplier for safety margin. Default: 2.0.
  - dim (int) – Spatial dimension. Default: 3.
- Return type:
- Returns:
Estimated maximum neighbors per atom.
Example
    from jax_md.custom_partition import (
        neighbor_list_multi_image,
        estimate_max_neighbors,
    )

    max_nbrs = estimate_max_neighbors(r_cutoff=2.5, atomic_density=0.1)
    neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff, max_neighbors=max_nbrs)
See also
estimate_max_neighbors_from_box: More accurate estimation when box and n_atoms are known, accounts for multiple periodic images.
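As a worked instance of the density formula, the sketch below evaluates safety_factor · ρ · V_sphere in plain Python. The helper names are hypothetical. Using the d-dimensional ball volume for dim other than 3 is an assumption, and the value is left unrounded because the library's rounding is not documented here.

```python
import math


def ball_volume(radius, dim=3):
    """Volume of a dim-dimensional ball: pi^(dim/2) / Gamma(dim/2 + 1) * radius^dim."""
    return math.pi ** (dim / 2) / math.gamma(dim / 2 + 1) * radius ** dim


def neighbor_estimate(r_cutoff, atomic_density=0.1, safety_factor=2.0, dim=3):
    """Raw (unrounded) value of safety_factor * density * cutoff volume."""
    return safety_factor * atomic_density * ball_volume(r_cutoff, dim)


print(neighbor_estimate(2.5))  # about 13.09 for the default liquid-like density
```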
- jax_md.custom_partition.estimate_max_neighbors_from_box(box, r_cutoff, n_atoms, safety_factor=2.0, pbc=None)
Estimate maximum neighbors per atom from box and atom count.
Accurate estimation for multi-image neighbor lists, supporting both fully periodic and mixed boundary conditions.
\[N_\text{neighbors} = \text{safety\_factor} \cdot \rho \cdot V_\text{eff}\]
where \(\rho = N / V_\text{box}\) is the number density and \(V_\text{eff}\) is the effective cutoff volume (sphere for fully periodic, capped for non-periodic directions).
For non-periodic directions, the cutoff is capped at the box length since atoms beyond the box boundary don’t exist. This reduces the effective cutoff volume and thus the expected neighbor count.
- Parameters:
  - box (Array | ndarray) – Affine transformation (see jax_md.space.periodic_general). Shape [dim, dim]. Columns are lattice vectors.
  - r_cutoff (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Interaction cutoff distance (scalar).
  - n_atoms (int) – Number of atoms in the system.
  - safety_factor (float) – Multiplier for safety margin. Default: 2.0.
  - pbc (Array | None) – Boolean array indicating periodic directions. Shape [dim]. Default: all True. For non-periodic directions, the cutoff is capped at the box length in that direction.
- Return type:
- Returns:
Estimated maximum neighbors per atom.
- Raises:
ValueError – If spatial dimension is not 1, 2, or 3.
Example
    from jax_md.custom_partition import (
        neighbor_list_multi_image,
        estimate_max_neighbors_from_box,
    )

    # Fully periodic
    max_nbrs = estimate_max_neighbors_from_box(box, r_cutoff, n_atoms=N)

    # Mixed: periodic in x,y but not z
    max_nbrs = estimate_max_neighbors_from_box(
        box, r_cutoff, n_atoms=N, pbc=[True, True, False]
    )
    neighbor_fn = neighbor_list_multi_image(
        None, box, r_cutoff, max_neighbors=max_nbrs, pbc=[True, True, False]
    )
See also
estimate_max_neighbors: Quick estimation when box is not available.
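For the fully periodic 3D case, the box-based formula reduces to the density N / |det(box)| times the cutoff sphere volume. The plain Python sketch below covers only that case, with hypothetical helper names; the capped volume used for non-periodic directions is not reproduced, since its exact form is not given here.

```python
import math


def det3(m):
    """Determinant of a 3x3 matrix given as nested lists."""
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))


def periodic_estimate(box, r_cutoff, n_atoms, safety_factor=2.0):
    """Raw (unrounded) safety_factor * (N / V_box) * V_sphere for a fully periodic 3D box."""
    density = n_atoms / abs(det3(box))
    return safety_factor * density * 4.0 / 3.0 * math.pi * r_cutoff ** 3


box = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
print(periodic_estimate(box, r_cutoff=2.5, n_atoms=8))  # about 16.36
```

With only 8 atoms in the cell, the estimate exceeds the 7 other atoms present, which is the expected signature of counting several periodic images of the same atom.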
Graph Neural Network Support
- jax_md.custom_partition.graph_featurizer(displacement_fn=None)
Create graph featurizer for multi-image neighbor lists.
Converts a NeighborListMultiImage to a jraph.GraphsTuple with displacement vectors as edge features. Uses explicit lattice shifts instead of MIC.
This has the same signature as jax_md._nn.util.neighbor_list_featurizer so it can be used as a drop-in replacement for multi-image neighbor lists.
- Parameters:
  displacement_fn – Displacement function from space.free(). Must use free boundary conditions since periodicity is handled via explicit shifts. If None, defaults to space.free()[0].
- Returns:
  featurize(atoms, positions, neighbor, **kwargs) -> GraphsTuple
- Raises:
  ValueError – If displacement_fn is not from space.free().
Example
>>> from jax_md import space
>>> from jax_md.custom_partition import neighbor_list_multi_image, graph_featurizer
>>> displacement_fn, _ = space.free()
>>> neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff)
>>> featurizer = graph_featurizer(displacement_fn)
>>> nbrs = neighbor_fn.allocate(positions)
>>> graph = featurizer(atoms, positions, nbrs)
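The edge features described above can be sketched with a free-space displacement plus an explicit real-space shift. This plain Python illustration uses a hypothetical helper and makes two assumptions that are not confirmed by the text: the receiver plays the role of atom i and the sender the role of atom j, and the sign convention is r_j + shift − r_i. Padded edges are given zero features.

```python
def edge_displacements(positions, receivers, senders, real_shifts):
    """Per-edge displacement r_j + shift - r_i, with zeros for padded edges."""
    n_atoms = len(positions)
    dim = len(positions[0])
    features = []
    for i, j, t in zip(receivers, senders, real_shifts):
        if i >= n_atoms:
            features.append([0.0] * dim)  # padding entry (index N)
            continue
        features.append([positions[j][k] + t[k] - positions[i][k] for k in range(dim)])
    return features


positions = [(0.0, 0.0), (0.5, 0.0)]
receivers = [0, 0, 2]  # last edge is padding
senders = [1, 1, 2]
real_shifts = [(0.0, 0.0), (-1.0, 0.0), (0.0, 0.0)]
print(edge_displacements(positions, receivers, senders, real_shifts))
# [[0.5, 0.0], [-0.5, 0.0], [0.0, 0.0]]
```

The two edges between the same pair of atoms differ only in their shift, so they carry different displacement vectors even though the node indices are identical.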