Multi-Image Neighbor Lists

This module provides neighbor list construction for small periodic boxes where the cutoff radius exceeds half the box size (\(r_\text{cut} > L/2\)). In such cases, the standard minimum image convention (MIC) fails, and particles may interact with multiple periodic images of their neighbors.

Neighbor List Construction

jax_md.custom_partition.neighbor_list_multi_image(displacement_or_metric, box, r_cutoff, dr_threshold=0.0, capacity_multiplier=1.25, pbc=None, fractional_coordinates=True, ordered=False, format=NeighborListFormat.Sparse, **kwargs)

Returns functions to build neighbor lists for small periodic boxes.

This function mirrors the API of jax_md.partition.neighbor_list but correctly handles small boxes where \(r_\text{cut} > L/2\) by explicitly enumerating periodic images. It works in any spatial dimension.

Algorithm:

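For each lattice direction \(k\), the number of image shifts needed is

\[n_k = \lceil r_\text{cut} / h_k \rceil\]

where \(h_k\) is the perpendicular height of the box along direction \(k\) (the distance between the lattice planes spanned by the other lattice vectors). The function then enumerates all integer shift vectors \(\mathbf{s} \in [-n_1, n_1] \times \ldots \times [-n_d, n_d]\) and finds pairs \((i, j)\) with:

\[\|\mathbf{r}_j + \mathbf{T}\mathbf{s} - \mathbf{r}_i\| < r_\text{cut}\]

where \(\mathbf{T}\) is the box matrix whose columns are the lattice vectors \(\mathbf{a}_k\), so \(\mathbf{T}\mathbf{s} = \sum_k s_k \mathbf{a}_k\) is the real-space offset of the periodic image.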
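The search described above can be sketched as a brute-force NumPy routine. This is an illustration, not the library's implementation: `perpendicular_heights` and `multi_image_pairs` are hypothetical names, all directions are assumed periodic, positions are assumed fractional and are wrapped into the cell first (which is what makes \(n_k = \lceil r_\text{cut}/h_k \rceil\) sufficient), and the trivial self pair \(i = j,\ \mathbf{s} = 0\) is assumed to be excluded.

```python
import itertools
import numpy as np


def perpendicular_heights(box):
    """Perpendicular height h_k of the cell along each lattice direction.

    `box` holds lattice vectors as columns. Row k of inv(box) is the
    (unscaled) reciprocal vector b_k, and h_k = 1 / |b_k|.
    """
    return 1.0 / np.linalg.norm(np.linalg.inv(box), axis=1)


def multi_image_pairs(frac_positions, box, r_cutoff):
    """Brute-force multi-image search returning a sparse edge list.

    Finds all (i, j, s) with |r_j + T s - r_i| < r_cutoff, where T = box,
    skipping only the self pair (i == j, s == 0). Assumes full periodicity.
    """
    frac = np.mod(frac_positions, 1.0)                 # wrap into [0, 1)
    real = frac @ box.T                                 # r = T f, row-wise
    n = np.ceil(r_cutoff / perpendicular_heights(box)).astype(int)
    shifts = np.array(list(itertools.product(*[range(-k, k + 1) for k in n])))
    offsets = shifts @ box.T                            # T s for every shift
    # d[i, j, m] = r_j + T s_m - r_i, shape (N, N, S, dim)
    d = real[None, :, None, :] + offsets[None, None, :, :] - real[:, None, None, :]
    within = np.linalg.norm(d, axis=-1) < r_cutoff
    zero = int(np.flatnonzero(np.all(shifts == 0, axis=1))[0])
    n_atoms = len(real)
    within[np.arange(n_atoms), np.arange(n_atoms), zero] = False  # drop self pair
    i, j, m = np.nonzero(within)
    return i, j, shifts[m]


# One atom in a unit cube with r_cut = 1.5 > L/2: the minimum image
# convention finds no neighbors, but 18 periodic images lie within the cutoff
# (6 face images at distance 1 and 12 edge images at distance sqrt(2)).
i, j, s = multi_image_pairs(np.zeros((1, 3)), np.eye(3), r_cutoff=1.5)
```

The memory cost is \(O(N^2 S)\) for \(S\) shifts, which is acceptable precisely in the small-box regime this module targets.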

Usage:

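import jax.numpy as jnp
from jax_md.custom_partition import neighbor_list_multi_image

box = 3.0 * jnp.eye(3)  # small cubic box, L = 3
r_cutoff = 2.5          # r_cut > L/2, so the minimum image convention fails

neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff)
# `state`, `apply_fn`, and `steps` come from your simulation setup.
nbrs = neighbor_fn.allocate(state.position)

for _ in range(steps):
  nbrs = nbrs.update(state.position)
  if nbrs.did_buffer_overflow:
    # Capacity exceeded: reallocate with a larger buffer.
    nbrs = neighbor_fn.allocate(state.position)
  state = apply_fn(state, nbrs)
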
Parameters:
  • displacement_or_metric – Ignored. Accepted for API compatibility with partition.neighbor_list. Displacements are computed from explicit lattice shifts instead.

  • box (jax.typing.ArrayLike) – Affine transformation (see jax_md.space.periodic_general). Shape [dim, dim]. Columns are lattice vectors.

  • r_cutoff (jax.typing.ArrayLike) – Interaction cutoff distance (scalar).

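  • dr_threshold (float) – Controls when update rebuilds the list. The rebuild is skipped while the largest displacement of any atom since the last build, \(d_\text{max}\), satisfies \(d_\text{max} < d_\text{thresh}/2\), where \(d_\text{thresh}\) is dr_threshold. Set to 0 to rebuild on every update.

  • capacity_multiplier (float) – Safety factor applied when sizing the neighbor list buffer at allocation, leaving room for additional edges found by later updates.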

  • pbc (Array | ndarray | None) – Boolean array indicating periodic directions. Shape [dim]. Default: all True.

  • fractional_coordinates (bool) – If True, positions are given in fractional coordinates of box; otherwise they are in real space.

  • ordered (bool) – If True, use the OrderedSparse format (each pair stored in one direction only), which uses half the memory of Sparse. Ignored for the Dense format.

  • format (NeighborListFormat) –

    Neighbor list format:

    • Sparse: Edge list (receivers, senders), each of shape [capacity].

    • OrderedSparse: Like Sparse but only \(i < j\) pairs.

    • Dense: Per-atom neighbors. Shape [N, max_neighbors].

  • **kwargs – Additional arguments (ignored, for API compatibility).
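The dr_threshold skip condition can be written out as a small predicate. This is a sketch under stated assumptions, not the library's code: `needs_rebuild` is a hypothetical helper, the displacement is measured in real space against the positions stored at the last build, and wrapped fractional coordinates are not unwrapped (so an atom crossing a cell face looks like a large jump and conservatively forces a rebuild).

```python
import numpy as np


def needs_rebuild(position, reference_position, box, dr_threshold,
                  fractional_coordinates=True):
    """True when update should rebuild rather than reuse the list.

    The list is kept while d_max < dr_threshold / 2, where d_max is the largest
    displacement of any atom since the last build. With dr_threshold = 0 the
    test d_max < 0 never holds, so every update rebuilds.
    """
    dr = np.asarray(position) - np.asarray(reference_position)
    if fractional_coordinates:
        dr = dr @ np.asarray(box).T   # fractional -> real-space displacement
    d_max = np.max(np.linalg.norm(dr, axis=-1))
    return not bool(d_max < dr_threshold / 2)


box = 10.0 * np.eye(3)
ref = np.zeros((2, 3))
pos = ref.copy()
pos[0, 0] = 0.01   # fractional step of 0.01 = 0.1 in real space
print(needs_rebuild(pos, ref, box, dr_threshold=0.5))   # 0.1 < 0.25: keep list
```

The predicate shows why dr_threshold = 0 means "always rebuild": no displacement is strictly negative.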

Returns:

  • allocate(position, box=None): Creates a new neighbor list from positions of shape [N, dim]. Pass box=new_box to rebuild for a different box geometry (e.g., during cell optimization).

  • update(position, neighbors): Updates an existing neighbor list.

Return type:

NeighborListMultiImageFns

jax_md.custom_partition.neighbor_list_multi_image_mask(neighbors)

Compute a boolean mask for valid edges in a neighbor list.

This is equivalent to jax_md.partition.neighbor_list_mask. An edge is valid if its receiver index is less than N (invalid edges are padded with index N).

Parameters:

neighbors (NeighborListMultiImage) – A NeighborListMultiImage (Sparse or OrderedSparse format).

Return type:

Array

Returns:

Boolean mask. Shape [capacity]. True indicates a valid edge.
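The validity rule above (receiver index < N) is easy to reproduce by hand. The NumPy sketch below is illustrative only: `multi_image_mask` is a hypothetical stand-in for the library function, and the per-edge energies are made-up values placed in a padded buffer to show how the mask keeps padding out of a reduction.

```python
import numpy as np


def multi_image_mask(receivers, n_atoms):
    """Valid-edge mask: padded slots carry receiver index N (= n_atoms)."""
    return receivers < n_atoms


# N = 3 atoms; the last two slots of this capacity-5 buffer are padding.
receivers = np.array([0, 1, 2, 3, 3])
senders = np.array([1, 0, 0, 3, 3])
edge_energy = np.array([0.5, 0.5, 0.2, 9.9, 9.9])  # junk values in padded slots

mask = multi_image_mask(receivers, n_atoms=3)
n_edges = int(mask.sum())                                 # valid edges only
total = float(np.sum(np.where(mask, edge_energy, 0.0)))   # padding contributes 0
```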

Data Structures

class jax_md.custom_partition.NeighborListMultiImage(idx, shifts, reference_position, reference_box, box, search_shifts, format, max_occupancy, update_fn, did_buffer_overflow=False)

A struct containing the state of a multi-image neighbor list.

This data structure is compatible with jax_md.partition.to_jraph and jax_md.partition.neighbor_list_mask. It stores edges between atoms, including all periodic images within the cutoff (not just the nearest).

Supports two storage formats:

Sparse/OrderedSparse:
  • idx: Tuple (receivers, senders), each shape (capacity,)

  • shifts: Shape (capacity, dim)

Dense:
  • idx: Shape (N, max_neighbors), neighbor indices for each atom

  • shifts: Shape (N, max_neighbors, dim), shift for each neighbor
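The two layouts hold the same information, so a Dense list can be flattened into Sparse form. This NumPy sketch uses a hypothetical `dense_to_sparse` helper and assumes that row i of the Dense idx lists the neighbors of atom i (so atom i is the receiver); padded slots keep index N in both receivers and senders, and shapes stay fixed at N × max_neighbors edges.

```python
import numpy as np


def dense_to_sparse(idx, shifts):
    """Flatten a Dense list (idx: [N, K], shifts: [N, K, dim]) into Sparse form.

    Assumes row i of `idx` lists the neighbors of atom i, so atom i is the
    receiver. Padded slots (sender index N) get receiver index N as well.
    """
    n_atoms, max_nbrs = idx.shape
    senders = idx.reshape(-1)
    receivers = np.repeat(np.arange(n_atoms), max_nbrs)
    receivers = np.where(senders < n_atoms, receivers, n_atoms)
    flat_shifts = shifts.reshape(n_atoms * max_nbrs, shifts.shape[-1])
    return (receivers, senders), flat_shifts


# Two atoms, max_neighbors = 2: atom 0 sees two images of atom 1,
# atom 1 has one neighbor and one padded slot (index N = 2).
idx = np.array([[1, 1], [0, 2]])
shifts = np.array([[[0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0]]])
(receivers, senders), flat_shifts = dense_to_sparse(idx, shifts)
```

Because the same atom can appear several times in one row with different shifts, the shift array, not the index alone, identifies an edge.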

idx

For Sparse: tuple (receivers, senders). For Dense: array (N, max_neighbors). Invalid entries are padded with index N (number of atoms).

shifts

Integer shift vectors. Sparse: (capacity, dim). Dense: (N, max_neighbors, dim). The real-space shift is shifts @ box.T.

reference_position

Positions when the list was built, shape (N, dim).

reference_box

Box when the list was built, shape (dim, dim).

box

An affine transformation; see jax_md.space.periodic_general.

search_shifts

Integer image stencil used to build this allocation.

format

NeighborListFormat.Sparse, OrderedSparse, or Dense.

max_occupancy

For Sparse: total capacity. For Dense: max_neighbors.

update_fn

Function to update the neighbor list.

did_buffer_overflow

True if more edges/neighbors were found than capacity allows.
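Since shifts are stored as integers, real-space edge vectors are recovered with shifts @ box.T. The sketch below (NumPy, hypothetical `edge_vectors` helper) assumes real-space positions, takes i as the receiver and j as the sender with the shift applied to the sender, and clamps padded indices before gathering so that index N never reads out of bounds; padded edges come back as zero vectors.

```python
import numpy as np


def edge_vectors(position, receivers, senders, shifts, box):
    """Real-space edge vectors r_j + T s - r_i for a Sparse list.

    Assumptions: real-space positions; i = receiver, j = sender; the shift
    is applied to the sender. Padded edges (index N) return zeros.
    """
    n_atoms = position.shape[0]
    mask = receivers < n_atoms
    ri = position[np.minimum(receivers, n_atoms - 1)]   # clamp padding index
    rj = position[np.minimum(senders, n_atoms - 1)]
    dr = rj + shifts @ box.T - ri                       # shifts @ box.T = T s
    return np.where(mask[:, None], dr, 0.0)


# One atom in a cubic box of side 2: two image edges plus one padded slot.
position = np.zeros((1, 3))
box = 2.0 * np.eye(3)
receivers = np.array([0, 0, 1])
senders = np.array([0, 0, 1])
shifts = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 0]])
vectors = edge_vectors(position, receivers, senders, shifts, box)
```

For a sheared box the same formula applies unchanged, because each row of shifts @ box.T is the integer combination of the lattice-vector columns.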

box: Array
did_buffer_overflow: Array | bool = False
format: NeighborListFormat
idx: Tuple[Array, Array] | Array
property max_neighbors: int

Maximum neighbors per atom (Dense format only).

max_occupancy: int
property n_edges: int

Number of valid edges (excluding padding).

property n_node: int

Number of atoms (for to_jraph compatibility).

property receivers: Array

Receiver atom indices (Sparse format only).

reference_box: Array
reference_position: Array
search_shifts: Array
property senders: Array

Sender atom indices (Sparse format only).

set(**kwargs)
shifts: Array
update(position, **kwargs)

Update neighbor list with new positions.

Return type:

NeighborListMultiImage

update_fn: Callable[[...], NeighborListMultiImage]

class jax_md.custom_partition.NeighborListMultiImageFns(allocate, update)

A struct containing functions to allocate and update neighbor lists.

This mirrors the jax_md.partition.NeighborListFns interface.

allocate

A function to allocate a new neighbor list. This function cannot be JIT-compiled, since it uses the values of the positions to infer array shapes. Signature: (position: Array[N, dim], **kwargs) -> NeighborListMultiImage

update

A function to update a neighbor list given a new set of positions and a previously allocated neighbor list. Signature: (position: Array[N, dim], neighbors: NeighborListMultiImage, **kwargs) -> NeighborListMultiImage
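The allocate/update split comes from fixed array shapes: allocate picks a capacity from a concrete edge count (so it cannot be traced under jit), while update must fit whatever it finds into that fixed buffer and report overflow. The NumPy sketch below illustrates the idea with hypothetical `allocate_capacity` and `pack_edges` helpers; rounding the capacity up and padding with index N are assumptions, and a real update would use jnp operations so it can run under jit.

```python
import math
import numpy as np


def allocate_capacity(n_found, capacity_multiplier=1.25):
    """Buffer size chosen at allocation time (needs a concrete count, so no jit).

    Rounding up is an assumption of this sketch.
    """
    return math.ceil(n_found * capacity_multiplier)


def pack_edges(receivers, senders, shifts, capacity, n_atoms):
    """Copy edges into fixed-size buffers, padding with index N.

    Returns ((receivers, senders), shifts, overflow). On overflow the extra
    edges are dropped, and the caller is expected to reallocate.
    """
    n_found = len(receivers)
    keep = min(n_found, capacity)
    out_recv = np.full(capacity, n_atoms)
    out_send = np.full(capacity, n_atoms)
    out_shifts = np.zeros((capacity, shifts.shape[1]), dtype=shifts.dtype)
    out_recv[:keep] = receivers[:keep]
    out_send[:keep] = senders[:keep]
    out_shifts[:keep] = shifts[:keep]
    return (out_recv, out_send), out_shifts, n_found > capacity


recv = np.array([0, 0, 1, 1])
send = np.array([1, 1, 0, 0])
sh = np.array([[0, 0, 0], [1, 0, 0], [0, 0, 0], [-1, 0, 0]])
capacity = allocate_capacity(len(recv))          # 4 edges * 1.25 -> 5 slots
(idx_r, idx_s), packed, overflow = pack_edges(recv, send, sh, capacity, n_atoms=2)
```

If a later update finds more edges than capacity, the overflow flag is set and the usage loop reallocates, which matches the did_buffer_overflow check in the usage example.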

Capacity Estimation

jax_md.custom_partition.estimate_max_neighbors(r_cutoff, atomic_density=0.1, safety_factor=2.0, dim=3)

Estimate maximum neighbors per atom from atomic density.

Quick estimation when you don’t have box/n_atoms information. Uses:

\[N_\text{neighbors} = \text{safety\_factor} \cdot \rho \cdot V_\text{sphere}\]

where \(\rho\) is the atomic density and \(V_\text{sphere}\) is the volume of a dim-dimensional sphere (ball) of radius r_cutoff, e.g. \(\tfrac{4}{3}\pi r_\text{cut}^3\) for dim = 3.
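The formula above can be evaluated directly. This pure-Python sketch uses the standard d-ball volume \(\pi^{d/2} r^d / \Gamma(d/2 + 1)\); `estimate_max_neighbors_sketch` is a hypothetical name, and rounding the result up to an integer is an assumption about how the library converts to int.

```python
import math


def ball_volume(r, dim):
    """Volume of a dim-dimensional ball of radius r: pi^(d/2) r^d / Gamma(d/2 + 1)."""
    return math.pi ** (dim / 2) * r ** dim / math.gamma(dim / 2 + 1)


def estimate_max_neighbors_sketch(r_cutoff, atomic_density=0.1,
                                  safety_factor=2.0, dim=3):
    """safety_factor * rho * V_sphere, rounded up (rounding is an assumption)."""
    return math.ceil(safety_factor * atomic_density * ball_volume(r_cutoff, dim))


# r_cut = 2.5, rho = 0.1: V = 4/3 * pi * 2.5**3 ~ 65.45, so 2 * 0.1 * 65.45 ~ 13.09.
estimate = estimate_max_neighbors_sketch(2.5)
```

With these inputs the sketch returns 14; a library that truncates instead of rounding up would return 13.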

Parameters:
  • r_cutoff (float) – Interaction cutoff distance (scalar).

  • atomic_density (float) – Atomic density in atoms per unit volume. Default: 0.1. Typical values with lengths in Å: ~0.1 for liquids such as water and ~0.15 for dense solids. Gases are far more dilute (about 3×10⁻⁵ Å⁻³ at ambient conditions).

  • safety_factor (float) – Multiplier for safety margin. Default: 2.0.

  • dim (int) – Spatial dimension. Default: 3.

Return type:

int

Returns:

Estimated maximum neighbors per atom.

Example

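import jax.numpy as jnp
from jax_md.custom_partition import (
    neighbor_list_multi_image,
    estimate_max_neighbors,
)

r_cutoff = 2.5
box = 4.0 * jnp.eye(3)  # example cubic box; columns are lattice vectors

max_nbrs = estimate_max_neighbors(r_cutoff=r_cutoff, atomic_density=0.1)
# max_neighbors is not a named parameter above; it is passed through **kwargs.
neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff, max_neighbors=max_nbrs)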

See also

estimate_max_neighbors_from_box: More accurate estimate when the box and n_atoms are known; accounts for multiple periodic images.

jax_md.custom_partition.estimate_max_neighbors_from_box(box, r_cutoff, n_atoms, safety_factor=2.0, pbc=None)

Estimate maximum neighbors per atom from box and atom count.

Accurate estimation for multi-image neighbor lists, supporting both fully periodic and mixed boundary conditions.

\[N_\text{neighbors} = \text{safety\_factor} \cdot \rho \cdot V_\text{eff}\]

where \(\rho = N / V_\text{box}\) is the number density and \(V_\text{eff}\) is the effective cutoff volume (sphere for fully periodic, capped for non-periodic directions).

For non-periodic directions, the cutoff is capped at the box length since atoms beyond the box boundary don’t exist. This reduces the effective cutoff volume and thus the expected neighbor count.
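For the fully periodic case the estimate reduces to \(\rho = N / |\det \mathbf{T}|\) times the cutoff-sphere volume. The sketch below covers only that case; the capped volume used for non-periodic directions is not reproduced, `estimate_from_box_periodic` is a hypothetical name, and rounding up is assumed.

```python
import math
import numpy as np


def ball_volume(r, dim):
    """Volume of a dim-dimensional ball of radius r."""
    return math.pi ** (dim / 2) * r ** dim / math.gamma(dim / 2 + 1)


def estimate_from_box_periodic(box, r_cutoff, n_atoms, safety_factor=2.0):
    """Fully periodic case only: safety_factor * (N / V_box) * V_sphere."""
    box = np.asarray(box, dtype=float)
    dim = box.shape[0]
    if dim not in (1, 2, 3):
        raise ValueError(f"Unsupported spatial dimension: {dim}")
    rho = n_atoms / abs(np.linalg.det(box))   # number density N / V_box
    return math.ceil(safety_factor * rho * ball_volume(r_cutoff, dim))


# A single atom in a unit cube with r_cut = 1.5.
estimate = estimate_from_box_periodic(np.eye(3), 1.5, n_atoms=1)
```

Here the estimate is 29 even though N = 1, because every neighbor is a periodic image of the atom itself; this is why a multi-image list can need more neighbors per atom than there are atoms.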

Parameters:
  • box (Array | ndarray) – Affine transformation (see jax_md.space.periodic_general). Shape [dim, dim]. Columns are lattice vectors.

  • r_cutoff (jax.typing.ArrayLike) – Interaction cutoff distance (scalar).

  • n_atoms (int) – Number of atoms in the system.

  • safety_factor (float) – Multiplier for safety margin. Default: 2.0.

  • pbc (Array | None) – Boolean array indicating periodic directions. Shape [dim]. Default: all True. For non-periodic directions, the cutoff is capped at the box length in that direction.

Return type:

int

Returns:

Estimated maximum neighbors per atom.

Raises:

ValueError – If spatial dimension is not 1, 2, or 3.

Example

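import jax.numpy as jnp
from jax_md.custom_partition import (
    neighbor_list_multi_image,
    estimate_max_neighbors_from_box,
)

N = 64                   # number of atoms
box = 5.0 * jnp.eye(3)   # columns are lattice vectors
r_cutoff = 3.0           # exceeds L/2 = 2.5

# Fully periodic
max_nbrs = estimate_max_neighbors_from_box(box, r_cutoff, n_atoms=N)

# Mixed: periodic in x, y but not z
pbc = jnp.array([True, True, False])
max_nbrs = estimate_max_neighbors_from_box(box, r_cutoff, n_atoms=N, pbc=pbc)
# max_neighbors is not a named parameter of neighbor_list_multi_image; it is passed through **kwargs.
neighbor_fn = neighbor_list_multi_image(
    None, box, r_cutoff, max_neighbors=max_nbrs, pbc=pbc
)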

See also

estimate_max_neighbors: Quick estimation when box is not available.

Graph Neural Network Support

jax_md.custom_partition.graph_featurizer(displacement_fn=None)

Create graph featurizer for multi-image neighbor lists.

Converts a NeighborListMultiImage to a jraph.GraphsTuple with displacement vectors as edge features. Uses explicit lattice shifts instead of MIC.

This has the same signature as jax_md._nn.util.neighbor_list_featurizer, so it can replace that featurizer directly when working with multi-image neighbor lists.

Parameters:

displacement_fn – Displacement function from space.free(). Must use free boundary conditions since periodicity is handled via explicit shifts. If None, defaults to space.free()[0].

Returns:

featurize(atoms, positions, neighbor, **kwargs) -> GraphsTuple

Raises:

ValueError – If displacement_fn is not from space.free().
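The requirement for a free-space displacement function can be seen in a few lines. Two distinct periodic images of the same atom have different edge vectors, but a minimum-image displacement folds both back onto the original atom. The functions below are local stand-ins written for this illustration (a cubic box is assumed for the minimum image version), not jax_md's space functions.

```python
import numpy as np


def free_displacement(ra, rb):
    """Plain difference, as a free-space displacement computes it."""
    return ra - rb


def mic_displacement(ra, rb, side):
    """Minimum image convention in a cubic box of edge length `side`."""
    d = ra - rb
    return d - side * np.round(d / side)


side = 1.0
r = np.array([0.3, 0.3, 0.3])
# Two different periodic images of the same atom, shifts (1, 0, 0) and (2, 0, 0).
images = [r + side * np.array([1.0, 0.0, 0.0]), r + side * np.array([2.0, 0.0, 0.0])]

free_vecs = [free_displacement(img, r) for img in images]       # distinct offsets kept
mic_vecs = [mic_displacement(img, r, side) for img in images]   # both fold back to ~0
```

Because the shift has already placed the sender at the correct image, any further periodic wrapping would discard exactly the information the multi-image list exists to keep.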

Example

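>>> import jax
>>> import jax.numpy as jnp
>>> from jax_md import space
>>> from jax_md.custom_partition import neighbor_list_multi_image, graph_featurizer
>>> box = 3.0 * jnp.eye(3)  # small cubic box
>>> r_cutoff = 2.5  # exceeds L/2
>>> positions = jax.random.uniform(jax.random.PRNGKey(0), (8, 3))  # fractional coordinates
>>> atoms = jnp.zeros(8, dtype=jnp.int32)  # per-atom node features, e.g. species labels
>>> displacement_fn, _ = space.free()
>>> neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff)
>>> featurizer = graph_featurizer(displacement_fn)
>>> nbrs = neighbor_fn.allocate(positions)
>>> graph = featurizer(atoms, positions, nbrs)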