# Source code for jax_md.custom_partition

import math
from functools import partial
from typing import Callable, Tuple, Union, Optional

import jax
import jax.numpy as jnp

from jax_md import dataclasses, partition, space
from jax_md.partition import NeighborListFormat

# Shorthand aliases used in annotations and dtype arguments below.
f32 = jnp.float32  # single-precision float dtype
i32 = jnp.int32  # 32-bit integer dtype (used e.g. for neighbor indices)
Array = jnp.ndarray  # generic JAX array annotation


@dataclasses.dataclass
class NeighborListMultiImage:
  """Neighbor list whose edges may point at any periodic image of an atom.

  Unlike a minimum-image neighbor list, each periodic image of an atom that
  lies within the cutoff gets its own entry, tagged with the integer lattice
  shift that produced it. The structure is meant to be compatible with
  ``jax_md.partition.to_jraph`` and ``jax_md.partition.neighbor_list_mask``.

  Two storage layouts are supported, selected by ``format``:

  **Sparse / OrderedSparse**

  - ``idx``: tuple ``(receivers, senders)``, each of shape ``(capacity,)``.
  - ``shifts``: shape ``(capacity, dim)``.

  **Dense**

  - ``idx``: shape ``(N, max_neighbors)``; row ``i`` lists neighbors of ``i``.
  - ``shifts``: shape ``(N, max_neighbors, dim)``, one shift per neighbor.

  Attributes:
    idx: Sparse: ``(receivers, senders)``. Dense: ``(N, max_neighbors)``
      array. Unused slots hold the sentinel index ``N`` (number of atoms).
    shifts: Integer lattice shifts, one per entry. The real-space offset is
      ``shifts @ box.T``.
    reference_position: Positions at the last rebuild, shape ``(N, dim)``.
    box: Affine cell transformation; see ``jax_md.space.periodic_general``.
    format: ``NeighborListFormat.Sparse``, ``OrderedSparse`` or ``Dense``.
    max_occupancy: Sparse: total edge capacity. Dense: neighbor slots per
      atom.
    update_fn: Callable used by :meth:`update` to rebuild the list.
    did_buffer_overflow: True when more edges/neighbors were found than the
      allocated capacity could hold.
  """
  # Sparse/OrderedSparse: (receivers, senders), each Array[capacity].
  # Dense: Array[N, max_neighbors].
  idx: Union[Tuple[Array, Array], Array]
  # Sparse/OrderedSparse: Array[capacity, dim].
  # Dense: Array[N, max_neighbors, dim]. Real-space shift = shifts @ box.T.
  shifts: Array
  reference_position: Array  # [N, dim]
  box: Array  # [dim, dim]

  # Static (non-traced) metadata.
  format: NeighborListFormat = dataclasses.static_field()
  max_occupancy: int = dataclasses.static_field()
  update_fn: Callable[..., 'NeighborListMultiImage'] = (
      dataclasses.static_field()
  )

  did_buffer_overflow: bool = False

  def update(self, position: Array, **kwargs) -> 'NeighborListMultiImage':
    """Rebuild this list for ``position`` by delegating to ``update_fn``."""
    return self.update_fn(position, self, **kwargs)

  def _sparse_component(self, name: str, which: int) -> Array:
    """Return ``idx[which]``; reject Dense lists, which have no edge tuple."""
    if self.format is NeighborListFormat.Dense:
      raise ValueError(
          f'{name} property not available for Dense format. Use idx directly.'
      )
    return self.idx[which]

  @property
  def senders(self) -> Array:
    """Sender (neighbor) indices; Sparse/OrderedSparse formats only."""
    return self._sparse_component('senders', 1)

  @property
  def receivers(self) -> Array:
    """Receiver (center) indices; Sparse/OrderedSparse formats only."""
    return self._sparse_component('receivers', 0)

  @property
  def n_edges(self) -> int:
    """Count of non-padding entries (sentinel ``N`` entries are excluded)."""
    num_atoms = len(self.reference_position)
    if self.format is NeighborListFormat.Dense:
      index_array = self.idx
    else:
      index_array = self.idx[0]
    return int(jnp.sum(index_array < num_atoms))

  @property
  def max_neighbors(self) -> int:
    """Neighbor slots per atom; Dense format only."""
    if self.format is NeighborListFormat.Dense:
      return self.idx.shape[1]
    raise ValueError(
        'max_neighbors property only available for Dense format.'
    )

  @property
  def n_node(self) -> int:
    """Atom count, i.e. ``len(reference_position)``."""
    return len(self.reference_position)
# Call signatures of the two functions bundled by NeighborListMultiImageFns:
#   allocate: (position: Array[N, dim], **kwargs) -> NeighborListMultiImage
#   update:   (position: Array[N, dim], neighbors: NeighborListMultiImage,
#              **kwargs) -> NeighborListMultiImage
AllocateFn = Callable[..., NeighborListMultiImage]
UpdateFn = Callable[
    [Array, NeighborListMultiImage],
    NeighborListMultiImage,
]
@dataclasses.dataclass
class NeighborListMultiImageFns:
  """Pair of functions that create and refresh multi-image neighbor lists.

  Mirrors the ``jax_md.partition.NeighborListFns`` interface.

  Attributes:
    allocate: Builds a fresh neighbor list. Not compilable, because buffer
      sizes are inferred from the concrete position values.
      Signature: ``(position: Array[N, dim], **kwargs) -> NeighborListMultiImage``
    update: Rebuilds an existing neighbor list for new positions, reusing
      its capacity. Signature:
      ``(position: Array[N, dim], neighbors: NeighborListMultiImage, **kwargs)
      -> NeighborListMultiImage``
  """
  allocate: AllocateFn = dataclasses.static_field()
  update: UpdateFn = dataclasses.static_field()

  def __iter__(self):
    """Yield ``allocate`` then ``update``: ``allocate_fn, update_fn = fns``."""
    yield self.allocate
    yield self.update
def _compute_shift_ranges(
    box: Array,  # [dim, dim]
    r_cutoff: float,
    pbc: Array,  # [dim]
) -> Array:  # [num_shifts, dim]
  r"""Compute integer shift vectors for multi-image neighbor search.

  For each lattice direction, determines how many periodic images are needed
  to capture all neighbors within :math:`r_\text{cut}`. Uses the reciprocal
  lattice to compute perpendicular box heights:

  .. math::

    h_i = \frac{1}{\|\mathbf{b}_i\|}

  where :math:`\mathbf{b}_i` is the :math:`i`-th column of the inverse box
  transpose (i.e., the :math:`i`-th reciprocal lattice vector). The number of
  shifts along direction :math:`i` is
  :math:`n_i = \lceil r_\text{cut} / h_i \rceil`, and the total number of
  shift vectors is :math:`\prod_i (2 n_i + 1)` over periodic directions.

  This bound is sufficient when, along every periodic direction, the
  fractional separation of any two atoms lies in :math:`(-1, 1)`, i.e. when
  positions are wrapped into the cell.

  The real-space shift for integer shifts stacked as rows is
  ``shifts @ box.T``. This is the same algorithm used by ASE and matscipy
  for neighbor list construction with periodic boundary conditions.

  Must be called with a concrete (non-traced) ``box``, because the number of
  shifts determines array shapes.

  Args:
    box: Affine transformation (see ``periodic_general``). Shape
      ``[dim, dim]``.
    r_cutoff: Interaction cutoff distance (scalar).
    pbc: Boolean array indicating which directions are periodic. Shape
      ``[dim]``. Non-periodic directions get zero shifts.

  Returns:
    Integer shift vectors spanning the required range. Shape
    ``[num_shifts, dim]``. Each row is a shift vector
    :math:`(n_1, n_2, \ldots, n_d)`.
  """
  # Reciprocal lattice vectors (columns of inv_box.T).
  inv_box_T = jnp.linalg.inv(box).T  # [dim, dim]
  # Perpendicular heights of the box.
  heights = 1.0 / jnp.linalg.norm(inv_box_T, axis=0)  # [dim]
  # Number of shifts needed per direction; zero for non-periodic directions.
  n_max = jnp.ceil(r_cutoff / heights).astype(i32)  # [dim]
  n_max = jnp.where(pbc, n_max, 0)  # [dim]

  # Cartesian product of the per-direction ranges [-n_i, n_i].
  dim = box.shape[0]
  n_max_int = [int(n_max[i]) for i in range(dim)]
  ranges = [jnp.arange(-n, n + 1) for n in n_max_int]  # each [2*n+1]
  grids = jnp.meshgrid(*ranges, indexing='ij')
  return jnp.stack([g.ravel() for g in grids], axis=-1)  # [num_shifts, dim]


def _compute_distances_sq(
    position: Array,  # [N, dim]
    box: Array,  # [dim, dim]
    shifts_real: Array,  # [num_shifts, dim]
    fractional_coordinates: bool,
) -> Array:  # [num_shifts, N, N]
  r"""Compute squared distances for all (shift, i, j) combinations.

  For each shift :math:`\mathbf{s}` and atom pair :math:`(i, j)`:

  .. math::

    d_{ij}^{\mathbf{s}2} = \|\mathbf{r}_j + \mathbf{s}_\text{real}
    - \mathbf{r}_i\|^2

  Args:
    position: Atom positions. Shape ``[N, dim]``.
    box: Box matrix. Shape ``[dim, dim]``.
    shifts_real: Real-space shifts. Shape ``[num_shifts, dim]``.
    fractional_coordinates: If True, positions are fractional.

  Returns:
    Squared distances. Shape ``[num_shifts, N, N]``; entry ``[s, i, j]``
    corresponds to ``r_j + shifts_real[s] - r_i``.
  """
  if fractional_coordinates:
    position_real = position @ box.T  # [N, dim]
  else:
    position_real = position

  D_all = (
      position_real[None, None, :, :]  # [1, 1, N, dim]   (r_j)
      + shifts_real[:, None, None, :]  # [num_shifts, 1, 1, dim]
      - position_real[None, :, None, :]  # [1, N, 1, dim]  (r_i)
  )  # [num_shifts, N, N, dim]
  return jnp.sum(D_all**2, axis=-1)  # [num_shifts, N, N]


def _compute_pairwise_mask(
    position: Array,  # [N, dim]
    box: Array,  # [dim, dim]
    shifts_real: Array,  # [num_shifts, dim]
    zero_shift_idx: int,
    r_cutoff: float,
    fractional_coordinates: bool,
) -> Array:  # [num_shifts, N, N]
  r"""Compute boolean mask for pairs within cutoff across all shifts.

  For each shift vector :math:`\mathbf{s}` and atom pair :math:`(i, j)`,
  computes whether

  .. math::

    \|\mathbf{r}_j + \mathbf{s}_\text{real} - \mathbf{r}_i\| < r_\text{cut}

  Self-interactions (:math:`i = j` with zero shift) are excluded. Uses
  ``vmap`` over shifts so distances and the comparison are fused per shift.

  Args:
    position: Atom positions. Shape ``[N, dim]``.
    box: Affine transformation (see ``periodic_general``). Shape
      ``[dim, dim]``.
    shifts_real: Real-space shift vectors. Shape ``[num_shifts, dim]``.
    zero_shift_idx: Index of the zero shift in ``shifts_real``.
    r_cutoff: Cutoff distance (scalar).
    fractional_coordinates: If True, positions are fractional.

  Returns:
    Boolean mask. Shape ``[num_shifts, N, N]``. Entry ``[s, i, j]`` is True
    if atom ``j`` with shift ``s`` is a neighbor of atom ``i``.
  """
  N = position.shape[0]
  num_shifts = shifts_real.shape[0]
  cutoff_sq = r_cutoff**2

  if fractional_coordinates:
    position_real = position @ box.T  # [N, dim]
  else:
    position_real = position

  def mask_for_shift(shift):  # shift: [dim]
    # D[i, j] = r_j + shift - r_i
    D = position_real[None, :, :] + shift - position_real[:, None, :]
    dist_sq = jnp.sum(D**2, axis=-1)  # [N, N]
    return dist_sq < cutoff_sq

  within_cutoff = jax.vmap(mask_for_shift)(shifts_real)  # [num_shifts, N, N]

  # Exclude self-interactions (i == j with zero shift).
  self_mask = jnp.eye(N, dtype=bool)  # [N, N]
  zero_shift_mask = jnp.arange(num_shifts) == zero_shift_idx  # [num_shifts]
  self_interaction = zero_shift_mask[:, None, None] & self_mask[None, :, :]
  return within_cutoff & ~self_interaction  # [num_shifts, N, N]


def _scatter_to_sparse(
    valid_mask: Array,  # [num_shifts, N, N]
    shifts: Array,  # [num_shifts, dim]
    capacity: int,
    N: int,
) -> Tuple[Array, Array, Array, Array]:
  r"""Scatter valid pairs into fixed-size sparse arrays.

  Valid entries are compacted in row-major ``(shift, i, j)`` order using a
  cumulative sum; entries beyond ``capacity`` are dropped (and reported via
  ``n_valid``).

  Args:
    valid_mask: Boolean mask indicating valid pairs. Shape
      ``[num_shifts, N, N]``.
    shifts: Integer shift vectors. Shape ``[num_shifts, dim]``.
    capacity: Maximum number of edges to store.
    N: Number of atoms.

  Returns:
    Tuple ``(senders, receivers, edge_shifts, n_valid)``:

    - ``senders``: Shape ``[capacity]``. Padded with ``N``.
    - ``receivers``: Shape ``[capacity]``. Padded with ``N``.
    - ``edge_shifts``: Shape ``[capacity, dim]``. Padded with zeros.
    - ``n_valid``: Total number of valid pairs (scalar), which may exceed
      ``capacity``.
  """
  num_shifts = shifts.shape[0]
  dim = shifts.shape[1]

  valid_flat = valid_mask.ravel()  # [num_shifts * N * N]

  # Index grids matching the row-major flattening of valid_mask.
  s_grid, i_grid, j_grid = jnp.meshgrid(
      jnp.arange(num_shifts), jnp.arange(N), jnp.arange(N), indexing='ij'
  )  # each [num_shifts, N, N]
  s_flat, i_flat, j_flat = s_grid.ravel(), i_grid.ravel(), j_grid.ravel()

  # Destination slot of each valid entry.
  n_valid = jnp.sum(valid_flat)
  cumsum = jnp.cumsum(valid_flat) - 1  # [num_shifts * N * N]

  # Pre-allocate with padding index = N.
  senders = N * jnp.ones(capacity, dtype=i32)  # [capacity]
  receivers = N * jnp.ones(capacity, dtype=i32)  # [capacity]
  edge_shifts = jnp.zeros((capacity, dim), dtype=shifts.dtype)

  # Invalid entries (and valid ones past capacity) target an out-of-range
  # slot and are discarded by mode='drop'.
  write_idx = jnp.where(valid_flat, cumsum, capacity)
  write_mask = valid_flat & (cumsum < capacity)

  senders = senders.at[write_idx].set(
      jnp.where(write_mask, j_flat, N), mode='drop'
  )
  receivers = receivers.at[write_idx].set(
      jnp.where(write_mask, i_flat, N), mode='drop'
  )
  shift_vals = shifts[s_flat]  # [num_shifts * N * N, dim]
  shift_vals_masked = jnp.where(write_mask[:, None], shift_vals, 0)
  edge_shifts = edge_shifts.at[write_idx].set(shift_vals_masked, mode='drop')

  return senders, receivers, edge_shifts, n_valid


def _build_neighbor_list_sparse(
    position: Array,  # [N, dim]
    box: Array,  # [dim, dim]
    shifts: Array,  # [num_shifts, dim]
    shifts_real: Array,  # [num_shifts, dim]
    zero_shift_idx: int,
    r_cutoff: float,
    capacity: int,
    fractional_coordinates: bool,
    update_fn: Callable,
) -> NeighborListMultiImage:
  r"""Build neighbor list in Sparse format.

  Stores **both directions** for each pair: if :math:`(i, j)` is a neighbor,
  both :math:`i \to j` and :math:`j \to i` are stored. Required for GNNs and
  asymmetric potentials.

  Args:
    position: Atom positions. Shape ``[N, dim]``.
    box: Box matrix. Shape ``[dim, dim]``.
    shifts: Integer shift vectors. Shape ``[num_shifts, dim]``.
    shifts_real: Real-space shifts (``shifts @ box.T``). Shape
      ``[num_shifts, dim]``.
    zero_shift_idx: Index of the zero shift vector.
    r_cutoff: Cutoff distance.
    capacity: Maximum edges to store.
    fractional_coordinates: If True, positions are fractional.
    update_fn: Function to update the neighbor list.

  Returns:
    NeighborListMultiImage with ``format=Sparse``:

    - ``idx``: ``(receivers, senders)`` each shape ``[capacity]``.
    - ``shifts``: Shape ``[capacity, dim]``.
  """
  within_cutoff = _compute_pairwise_mask(
      position, box, shifts_real, zero_shift_idx, r_cutoff,
      fractional_coordinates,
  )  # [num_shifts, N, N]

  N = position.shape[0]
  senders, receivers, edge_shifts, n_valid = _scatter_to_sparse(
      within_cutoff, shifts, capacity, N
  )

  return NeighborListMultiImage(
      idx=(receivers, senders),
      shifts=edge_shifts,
      reference_position=position,
      box=box,
      format=NeighborListFormat.Sparse,
      did_buffer_overflow=(n_valid > capacity),
      max_occupancy=capacity,
      update_fn=update_fn,
  )


def _build_neighbor_list_orderedsparse(
    position: Array,  # [N, dim]
    box: Array,  # [dim, dim]
    shifts: Array,  # [num_shifts, dim]
    shifts_real: Array,  # [num_shifts, dim]
    zero_shift_idx: int,
    r_cutoff: float,
    capacity: int,
    fractional_coordinates: bool,
    update_fn: Callable,
) -> NeighborListMultiImage:
  r"""Build neighbor list in OrderedSparse format.

  Stores **one direction** per pair to avoid double-counting. Uses about 2x
  less memory than Sparse format.

  **Ordering rules:**

  - **Zero shift** (:math:`\mathbf{s} = \mathbf{0}`): Store only
    :math:`i < j`.
  - **Non-zero shift**: A shift is "canonical" if its first non-zero
    component is positive. Only canonical shifts are stored (for all
    :math:`i, j`, including :math:`i = j` self-images). The mirror entry
    :math:`(j, i, -\mathbf{s})` is therefore never stored.

  Args:
    position: Atom positions. Shape ``[N, dim]``.
    box: Box matrix. Shape ``[dim, dim]``.
    shifts: Integer shift vectors. Shape ``[num_shifts, dim]``.
    shifts_real: Real-space shifts (``shifts @ box.T``). Shape
      ``[num_shifts, dim]``.
    zero_shift_idx: Index of the zero shift vector.
    r_cutoff: Cutoff distance.
    capacity: Maximum edges to store.
    fractional_coordinates: If True, positions are fractional.
    update_fn: Function to update the neighbor list.

  Returns:
    NeighborListMultiImage with ``format=OrderedSparse``:

    - ``idx``: ``(receivers, senders)`` each shape ``[capacity]``.
    - ``shifts``: Shape ``[capacity, dim]``.
  """
  N = position.shape[0]
  num_shifts = shifts.shape[0]

  within_cutoff = _compute_pairwise_mask(
      position, box, shifts_real, zero_shift_idx, r_cutoff,
      fractional_coordinates,
  )  # [num_shifts, N, N]

  i_idx = jnp.arange(N)[None, :, None]  # [1, N, 1]
  j_idx = jnp.arange(N)[None, None, :]  # [1, 1, N]
  zero_shift_mask = jnp.arange(num_shifts) == zero_shift_idx  # [num_shifts]

  def is_shift_canonical(s):
    """True if the first non-zero component of ``s`` is positive (or s == 0)."""
    nonzero_mask = s != 0
    first_nonzero_idx = jnp.argmax(nonzero_mask)
    first_val = s[first_nonzero_idx]
    is_zero = jnp.all(s == 0)
    return jnp.where(is_zero, True, first_val > 0)

  shift_is_canonical = jax.vmap(is_shift_canonical)(shifts)  # [num_shifts]

  # Zero shift -> keep i < j; non-zero shift -> keep canonical shifts only.
  keep_mask = jnp.where(
      zero_shift_mask[:, None, None],
      i_idx < j_idx,
      shift_is_canonical[:, None, None],
  )  # [num_shifts, N, N]
  within_cutoff = within_cutoff & keep_mask

  senders, receivers, edge_shifts, n_valid = _scatter_to_sparse(
      within_cutoff, shifts, capacity, N
  )

  return NeighborListMultiImage(
      idx=(receivers, senders),
      shifts=edge_shifts,
      reference_position=position,
      box=box,
      format=NeighborListFormat.OrderedSparse,
      did_buffer_overflow=(n_valid > capacity),
      max_occupancy=capacity,
      update_fn=update_fn,
  )


def _build_neighbor_list_dense(
    position: Array,  # [N, dim]
    box: Array,  # [dim, dim]
    shifts: Array,  # [num_shifts, dim]
    shifts_real: Array,  # [num_shifts, dim]
    zero_shift_idx: int,
    r_cutoff: float,
    max_neighbors: int,
    fractional_coordinates: bool,
    update_fn: Callable,
) -> NeighborListMultiImage:
  r"""Build neighbor list in Dense format.

  Dense format stores neighbors per atom as a ``[N, max_neighbors]`` array,
  enabling efficient three-body potential computation via vectorized
  operations.

  For each atom :math:`i`, keeps the closest ``max_neighbors`` candidates
  :math:`j` (with shifts :math:`\mathbf{s}`) satisfying:

  .. math::

    \|\mathbf{r}_j + \mathbf{s}_\text{real} - \mathbf{r}_i\| < r_\text{cut}

  Candidates are ranked with a stable ``argsort`` on squared distance. When
  there are fewer candidate ``(shift, j)`` slots than ``max_neighbors``
  (``num_shifts * N < max_neighbors``), the output is padded so the returned
  arrays always have exactly ``max_neighbors`` columns.

  Args:
    position: Atom positions. Shape ``[N, dim]``.
    box: Affine transformation (see ``periodic_general``). Shape
      ``[dim, dim]``.
    shifts: Integer shift vectors. Shape ``[num_shifts, dim]``.
    shifts_real: Real-space shifts (``shifts @ box.T``). Shape
      ``[num_shifts, dim]``.
    zero_shift_idx: Index of the zero shift vector.
    r_cutoff: Cutoff distance (scalar).
    max_neighbors: Maximum neighbors per atom.
    fractional_coordinates: If True, positions are fractional.
    update_fn: Function to update the neighbor list.

  Returns:
    NeighborListMultiImage with ``format=Dense``:

    - ``idx``: Neighbor indices. Shape ``[N, max_neighbors]``. Padded with
      ``N``.
    - ``shifts``: Shift vectors. Shape ``[N, max_neighbors, dim]``. Padded
      with zeros.
  """
  N = position.shape[0]
  num_shifts = shifts.shape[0]

  dist_sq = _compute_distances_sq(
      position, box, shifts_real, fractional_coordinates
  )  # [num_shifts, N, N]

  # Valid neighbors: within cutoff, excluding self (i == j with zero shift).
  within_cutoff = dist_sq < r_cutoff**2  # [num_shifts, N, N]
  self_mask = jnp.eye(N, dtype=bool)  # [N, N]
  zero_shift_mask = jnp.arange(num_shifts) == zero_shift_idx  # [num_shifts]
  self_interaction = zero_shift_mask[:, None, None] & self_mask[None, :, :]
  valid = within_cutoff & ~self_interaction  # [num_shifts, N, N]

  # Per-atom view: [num_shifts, N, N] -> [N, num_shifts, N] -> [N, S * N].
  # Flat candidate index k encodes (shift_idx, j) as k = shift_idx * N + j.
  num_candidates = num_shifts * N
  valid_per_atom = valid.transpose(1, 0, 2).reshape(N, num_candidates)
  dist_sq_per_atom = dist_sq.transpose(1, 0, 2).reshape(N, num_candidates)

  # Invalid candidates sort last.
  dist_for_sort = jnp.where(valid_per_atom, dist_sq_per_atom, jnp.inf)

  # Cannot select more candidates than exist; pad the remainder below.
  k = min(max_neighbors, num_candidates)
  top_k_flat_idx = jnp.argsort(dist_for_sort, axis=-1)[:, :k]  # [N, k]

  # Decode flat index -> (shift_idx, j).
  neighbor_shift_idx = top_k_flat_idx // N  # [N, k]
  neighbor_j = top_k_flat_idx % N  # [N, k]
  neighbor_shifts = shifts[neighbor_shift_idx]  # [N, k, dim]

  # Which selected slots are genuine neighbors (not inf-distance filler).
  gathered_valid = jnp.take_along_axis(
      valid_per_atom, top_k_flat_idx, axis=-1
  )  # [N, k]

  neighbor_idx = jnp.where(gathered_valid, neighbor_j, N)  # [N, k]
  neighbor_shifts = jnp.where(
      gathered_valid[:, :, None], neighbor_shifts, 0
  )  # [N, k, dim]

  # Guarantee the documented [N, max_neighbors] shape.
  pad = max_neighbors - k
  if pad > 0:
    neighbor_idx = jnp.pad(
        neighbor_idx, ((0, 0), (0, pad)), constant_values=N
    )
    neighbor_shifts = jnp.pad(neighbor_shifts, ((0, 0), (0, pad), (0, 0)))

  # Overflow if any atom has more valid neighbors than slots.
  total_valid_per_atom = jnp.sum(valid_per_atom, axis=-1)  # [N]
  did_overflow = jnp.any(total_valid_per_atom > max_neighbors)

  return NeighborListMultiImage(
      idx=neighbor_idx,
      shifts=neighbor_shifts,
      reference_position=position,
      box=box,
      format=NeighborListFormat.Dense,
      did_buffer_overflow=did_overflow,
      max_occupancy=max_neighbors,
      update_fn=update_fn,
  )
def estimate_max_neighbors(
    r_cutoff: float,
    atomic_density: float = 0.1,
    safety_factor: float = 2.0,
    dim: int = 3,
) -> int:
  r"""Estimate maximum neighbors per atom from atomic density.

  Quick estimation when you don't have box/n_atoms information. Uses:

  .. math::

    N_\text{neighbors} = \text{safety\_factor} \cdot \rho \cdot V_\text{sphere}

  where :math:`\rho` is the atomic density and :math:`V_\text{sphere}` is the
  volume of a ``dim``-dimensional ball with radius ``r_cutoff``. The result
  is computed on the host in double precision.

  Args:
    r_cutoff: Interaction cutoff distance (scalar).
    atomic_density: Atomic density in atoms per unit volume. Default: 0.1.
      Common values: ~0.03 for gases, ~0.1 for liquids, ~0.15 for solids.
    safety_factor: Multiplier for safety margin. Default: 2.0.
    dim: Spatial dimension; must be 1, 2, or 3. Default: 3.

  Returns:
    Estimated maximum neighbors per atom (at least 1), or 0 when
    ``r_cutoff <= 0``.

  Raises:
    ValueError: If ``dim`` is not 1, 2, or 3 (checked only when
      ``r_cutoff > 0``).

  Example:
    .. code-block:: python

      from jax_md.custom_partition import estimate_max_neighbors

      max_nbrs = estimate_max_neighbors(r_cutoff=2.5, atomic_density=0.1)
      # max_nbrs == 14

    ``neighbor_list_multi_image`` sizes its buffers at ``allocate`` time and
    ignores unknown keyword arguments, so it does not take this value
    directly. Use the estimate for your own buffer or memory budgeting.

  See Also:
    ``estimate_max_neighbors_from_box``: More accurate estimation when box
    and n_atoms are known, including non-periodic directions.
  """
  if r_cutoff <= 0:
    return 0
  # Reject dim <= 0 as well as dim > 3, instead of silently treating it as 1D.
  if dim not in (1, 2, 3):
    raise ValueError(f'dim must be 1, 2, or 3, got {dim}')

  # Volume of the dim-dimensional ball of radius r_cutoff.
  if dim == 3:
    sphere_volume = 4.0 / 3.0 * math.pi * r_cutoff**3
  elif dim == 2:
    sphere_volume = math.pi * r_cutoff**2
  else:  # dim == 1
    sphere_volume = 2.0 * r_cutoff

  expected_neighbors = safety_factor * atomic_density * sphere_volume
  # math.ceil avoids a float32 round-trip through the device.
  return max(math.ceil(expected_neighbors), 1)
def estimate_max_neighbors_from_box(
    box: Array,  # [dim, dim]
    r_cutoff: float,
    n_atoms: int,
    safety_factor: float = 2.0,
    pbc: Array | None = None,  # [dim]
) -> int:
  r"""Estimate maximum neighbors per atom from box and atom count.

  Estimation for multi-image neighbor lists, supporting both fully periodic
  and mixed boundary conditions:

  .. math::

    N_\text{neighbors} = \text{safety\_factor} \cdot \rho \cdot
    V_\text{ball}(r_\text{cut}) \cdot
    \prod_{k \,\text{non-periodic}} \min\!\left(1, \frac{L_k}{r_\text{cut}}\right)

  where :math:`\rho = N / V_\text{box}` is the number density and
  :math:`L_k` is the length of lattice vector :math:`k`.

  Along periodic directions the cutoff ball is reachable in full (the
  multi-image list enumerates all images). Along a non-periodic direction
  shorter than the cutoff, atoms only exist within a slab of thickness
  about :math:`L_k`. This truncates the ball roughly in proportion to
  :math:`L_k / r_\text{cut}`. For a thin slab the true mean count is
  :math:`\pi r_\text{cut}^2 L` in 3D, i.e. 3/4 of the factor used here, so
  the estimate stays conservative. Only the truncated directions are
  scaled, rather than shrinking the ball isotropically to the smallest
  capped radius, which would underestimate slabs and wires badly.

  Args:
    box: Affine transformation (see ``jax_md.space.periodic_general``).
      Shape ``[dim, dim]``. Columns are lattice vectors.
    r_cutoff: Interaction cutoff distance (scalar).
    n_atoms: Number of atoms in the system.
    safety_factor: Multiplier for safety margin. Default: 2.0.
    pbc: Boolean array indicating periodic directions. Shape ``[dim]``.
      Default: all True.

  Returns:
    Estimated maximum neighbors per atom (at least 1), or 0 when
    ``r_cutoff <= 0``.

  Raises:
    ValueError: If spatial dimension is not 1, 2, or 3.

  Example:
    .. code-block:: python

      from jax_md.custom_partition import estimate_max_neighbors_from_box

      # Fully periodic
      max_nbrs = estimate_max_neighbors_from_box(box, r_cutoff, n_atoms=N)

      # Mixed: periodic in x, y but not z
      max_nbrs = estimate_max_neighbors_from_box(
          box, r_cutoff, n_atoms=N, pbc=[True, True, False]
      )

    ``neighbor_list_multi_image`` sizes its buffers at ``allocate`` time and
    ignores unknown keyword arguments, so it does not take this value
    directly.

  See Also:
    ``estimate_max_neighbors``: Quick estimation when box is not available.
  """
  if r_cutoff <= 0:
    return 0

  box = jnp.asarray(box)
  dim = box.shape[0]
  if dim not in (1, 2, 3):
    raise ValueError(f'dim must be 1, 2, or 3, got {dim}')

  if pbc is None:
    pbc = jnp.ones(dim, dtype=bool)
  pbc = jnp.asarray(pbc)

  # Number density from the cell volume.
  box_volume = float(jnp.abs(jnp.linalg.det(box)))
  density = n_atoms / box_volume

  # Length of each lattice vector (columns of ``box``).
  box_lengths = jnp.linalg.norm(box, axis=0)  # [dim]
  # Per-direction reachable fraction: 1 if periodic, else min(1, L_k / rc).
  axis_fraction = jnp.where(
      pbc, 1.0, jnp.minimum(1.0, box_lengths / r_cutoff)
  )  # [dim]
  volume_fraction = float(jnp.prod(axis_fraction))

  # Volume of the dim-dimensional ball of radius r_cutoff.
  if dim == 3:
    sphere_volume = 4.0 / 3.0 * math.pi * r_cutoff**3
  elif dim == 2:
    sphere_volume = math.pi * r_cutoff**2
  else:  # dim == 1
    sphere_volume = 2.0 * r_cutoff

  # Extra periodic images hold the same physical density, so the expected
  # count is density * reachable volume even when r_cutoff > L / 2.
  expected_neighbors = (
      safety_factor * density * sphere_volume * volume_fraction
  )
  return max(math.ceil(expected_neighbors), 1)
def neighbor_list_multi_image(
    displacement_or_metric,  # Ignored, for API compatibility
    box: Array,  # [dim, dim]
    r_cutoff: float,
    dr_threshold: float = 0.0,
    capacity_multiplier: float = 1.25,
    pbc: Array | None = None,  # [dim]
    fractional_coordinates: bool = True,
    ordered: bool = False,
    format: NeighborListFormat = NeighborListFormat.Sparse,
    **kwargs,
) -> NeighborListMultiImageFns:
  r"""Returns functions to build neighbor lists for small periodic boxes.

  This function mirrors the API of ``jax_md.partition.neighbor_list`` but
  correctly handles small boxes where :math:`r_\text{cut} > L/2` by
  explicitly enumerating periodic images. Works for any spatial dimension.
  For ``dim > 3`` the initial capacity probe uses a shift-count heuristic,
  because no density-based estimate is available.

  **Algorithm:** For each lattice direction :math:`i`, computes the number of
  shifts needed:

  .. math::

    n_i = \lceil r_\text{cut} / h_i \rceil

  where :math:`h_i` is the perpendicular height of the box along direction
  :math:`i`. Then enumerates all integer shift vectors
  :math:`\mathbf{s} \in [-n_1, n_1] \times \ldots \times [-n_d, n_d]` and
  finds pairs :math:`(i, j)` with:

  .. math::

    \|\mathbf{r}_j + \mathbf{s}_\text{real} - \mathbf{r}_i\| < r_\text{cut}

  where :math:`\mathbf{s}_\text{real}` is row ``s`` of ``shifts @ box.T``.

  **Box handling:** The integer shift set is computed once from ``box`` at
  construction. ``allocate(..., box=new_box)`` recomputes it for
  ``new_box``. ``update`` (and ``neighbors.update``) rebuilds with ``box=``
  if given, and otherwise with ``neighbors.box``, the box the list was last
  built with. ``update`` always uses the constructor's integer shift set,
  because shapes cannot change under ``jit``. If the cell shrinks or shears
  enough to need more images, call ``allocate`` again with the new box.

  **Usage:**

  .. code-block:: python

    from jax_md.custom_partition import neighbor_list_multi_image

    neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff)
    nbrs = neighbor_fn.allocate(R)

    for _ in range(steps):
      nbrs = nbrs.update(state.position)
      if nbrs.did_buffer_overflow:
        nbrs = neighbor_fn.allocate(state.position)
      state = apply_fn(state, nbrs)

  Args:
    displacement_or_metric: Ignored. Accepted for API compatibility with
      ``partition.neighbor_list``. Multi-image computes displacements using
      explicit lattice shifts.
    box: Affine transformation (see ``jax_md.space.periodic_general``).
      Shape ``[dim, dim]``. Columns are lattice vectors.
    r_cutoff: Interaction cutoff distance (scalar).
    dr_threshold: Skin added to the search cutoff. ``update`` skips the
      rebuild while the maximum displacement since the last build satisfies
      :math:`d_\text{max} < d_\text{thresh}/2`. Set to 0 to always rebuild.
    capacity_multiplier: Safety factor for neighbor list capacity.
    pbc: Boolean array indicating periodic directions. Shape ``[dim]``.
      Default: all True.
    fractional_coordinates: If True, positions are in fractional
      coordinates.
    ordered: If True, use OrderedSparse format (one direction per pair).
      Uses about 2x less memory. Ignored for Dense format.
    format: Neighbor list format:

      - ``Sparse``: Edge list ``(receivers, senders)``. Shape ``[capacity]``.
      - ``OrderedSparse``: Like Sparse but each pair is stored once. The
        zero shift keeps :math:`i < j`. Non-zero shifts keep only shifts
        whose first non-zero component is positive, for any :math:`i, j`.
      - ``Dense``: Per-atom neighbors. Shape ``[N, max_neighbors]``.

    **kwargs: Additional arguments (ignored, for API compatibility).

  Returns:
    ``NeighborListMultiImageFns`` with:

    - ``allocate(position, extra_capacity=0, box=None)``: Create a new
      neighbor list from positions ``[N, dim]``. Pass ``box=new_box`` to
      rebuild for a different box geometry (e.g. during cell optimization).
    - ``update(position, neighbors, box=None)``: Update an existing neighbor
      list, reusing its capacity.
  """
  del displacement_or_metric  # Unused - multi-image uses explicit shifts

  default_box = jnp.asarray(box)  # [dim, dim]
  dim = default_box.shape[0]
  use_dense = format is NeighborListFormat.Dense
  use_ordered = ordered or (format is NeighborListFormat.OrderedSparse)

  if pbc is None:
    pbc = jnp.ones(dim, dtype=bool)
  pbc = jnp.asarray(pbc)  # [dim]

  # Add dr_threshold as a "skin" buffer to the cutoff for neighbor search.
  search_cutoff = r_cutoff + dr_threshold

  def _shift_set(cell):
    """Return ``(shifts, zero_shift_idx, hashable_key)`` for a concrete cell."""
    cell_shifts = _compute_shift_ranges(cell, search_cutoff, pbc)  # [S, dim]
    zero_idx = int(jnp.argmin(jnp.sum(cell_shifts**2, axis=1)))
    key = tuple(tuple(row) for row in cell_shifts.tolist())
    return cell_shifts, zero_idx, key

  # Shift set of the constructor box; used by every ``update`` call.
  default_shift_set = _shift_set(default_box)

  # Displacement threshold for skipping rebuild.
  threshold_sq = (dr_threshold / 2.0) ** 2
  # Choose update strategy at function creation time (not trace time).
  use_threshold = dr_threshold > 0

  # Placeholder for the circular reference stored in each list's update_fn.
  def update_fn_placeholder(position, neighbors, **kwargs):
    raise NotImplementedError()

  update_fn_ref = [update_fn_placeholder]

  # Compiled build functions keyed by (capacity, integer shift set). The
  # compiled function depends only on these (the box is a traced argument),
  # so boxes that share a shift set reuse one compilation.
  build_fn_cache = {}

  # Select format-specific build function once at construction time.
  if use_dense:
    build_nl_fn = _build_neighbor_list_dense
  elif use_ordered:
    build_nl_fn = _build_neighbor_list_orderedsparse
  else:
    build_nl_fn = _build_neighbor_list_sparse

  def _initial_probe_capacity(N: int, cell, num_cell_shifts: int) -> int:
    """Geometry-based probe capacity estimate for ``N`` atoms in ``cell``."""
    if dim <= 3:
      est = estimate_max_neighbors_from_box(
          cell, search_cutoff, N, safety_factor=5.0, pbc=pbc
      )
    else:
      # No density-based estimate beyond 3D; rely on the doubling probe.
      est = 0
    npa = max(int(est * capacity_multiplier), num_cell_shifts)
    if use_dense:
      return npa
    cap = N * npa
    if use_ordered:
      cap = cap // 2 + N
    return cap

  def make_build_fn(capacity, nl_shifts, zero_idx):
    """Create a jitted build function for a capacity and shift set."""

    @jax.jit
    def build_fn(pos, cell):
      shifts_real = nl_shifts @ cell.T
      return build_nl_fn(
          pos,
          cell,
          nl_shifts,
          shifts_real,
          zero_idx,
          search_cutoff,
          capacity,
          fractional_coordinates,
          update_fn_ref[0],
      )

    return build_fn

  def get_build_fn(capacity: int, shift_set=None):
    """Get or create the build function for ``capacity`` and ``shift_set``.

    ``shift_set`` defaults to the constructor box's shift set.
    """
    if shift_set is None:
      shift_set = default_shift_set
    nl_shifts, zero_idx, shift_key = shift_set
    cache_key = (capacity, shift_key)
    build_fn = build_fn_cache.get(cache_key)
    if build_fn is None:
      build_fn = make_build_fn(capacity, nl_shifts, zero_idx)
      build_fn_cache[cache_key] = build_fn
    return build_fn

  @jax.jit
  def check_needs_rebuild(
      position: Array,  # [N, dim]
      reference_position: Array,  # [N, dim]
      cell: Array,  # [dim, dim]
  ) -> Array:  # scalar bool
    """Check if the maximum displacement (measured in ``cell``) is too large."""
    if fractional_coordinates:
      pos_new = position @ cell.T  # [N, dim]
      pos_old = reference_position @ cell.T
    else:
      pos_new = position
      pos_old = reference_position
    max_disp_sq = jnp.max(jnp.sum((pos_new - pos_old) ** 2, axis=-1))
    return max_disp_sq >= threshold_sq

  def allocate_fn(
      position: Array,
      extra_capacity: int = 0,
      box: Array | None = None,
      **kwargs,
  ) -> NeighborListMultiImage:
    """Allocate a new neighbor list from positions [N, dim].

    Not jittable: capacity is inferred from concrete positions.

    Args:
      position: Atom positions. Shape ``[N, dim]``.
      extra_capacity: Additional capacity to add (multiplied by N for
        Sparse). Use this to recover from buffer overflow.
      box: Override box for this allocation. If provided, shift vectors are
        recomputed for the new box geometry. This is needed when the cell
        changes during cell optimization.

    Returns:
      New neighbor list.
    """
    position = jnp.asarray(position)
    N = position.shape[0]
    extra = extra_capacity if use_dense else N * extra_capacity

    if box is None:
      current_box = default_box
      shift_set = default_shift_set
    else:
      current_box = jnp.asarray(box)
      shift_set = _shift_set(current_box)

    # Probe with a geometry-based estimate; retry with 2x on overflow.
    cap = _initial_probe_capacity(N, current_box, shift_set[0].shape[0])
    cap += extra
    while True:
      probe = get_build_fn(cap, shift_set)(position, current_box)
      if not probe.did_buffer_overflow:
        break
      # max(..., 1) guarantees progress even from a zero capacity.
      cap = max(cap * 2, 1)

    if use_dense:
      # initial=0 keeps this well-defined for N == 0.
      actual_occ = int(jnp.max(jnp.sum(probe.idx < N, axis=1), initial=0))
    else:
      actual_occ = int(jnp.sum(probe.idx[0] < N))

    max_occupancy = max(int(actual_occ * capacity_multiplier) + extra, 1)
    if max_occupancy == cap:
      return probe
    return get_build_fn(max_occupancy, shift_set)(position, current_box)

  def neighbor_list_fn(
      position: Array,  # [N, dim]
      neighbors: NeighborListMultiImage | None = None,
      **kwargs,
  ) -> NeighborListMultiImage:
    """Build or update neighbor list.

    Args:
      position: Atom positions. Shape ``[N, dim]``.
      neighbors: Existing neighbor list, or None to allocate new.
      **kwargs: Accepts ``box=`` to rebuild with a different box geometry.
        Without it, the list is rebuilt with ``neighbors.box``.

    Returns:
      Updated neighbor list.
    """
    position = jnp.asarray(position)
    if neighbors is None:
      # First call: allocate with capacity computed from position.
      return allocate_fn(position, **kwargs)

    position = position.astype(neighbors.reference_position.dtype)

    # Reuse existing capacity and the precomputed integer shifts. Real-space
    # shifts (shifts @ box.T) are recomputed inside build_fn from the traced
    # box argument, so this is safe inside JIT.
    new_box = kwargs.get('box')
    box_changed = new_box is not None
    if box_changed:
      current_box = jnp.asarray(new_box)
    else:
      # Keep the geometry the list was built with (which may differ from the
      # constructor box if it was allocated with box=).
      current_box = jnp.asarray(neighbors.box)
    build_fn = get_build_fn(neighbors.max_occupancy)

    # If box= was provided the geometry has changed; always rebuild.
    # The displacement threshold only applies when the box is unchanged.
    if use_threshold and not box_changed:
      return jax.lax.cond(
          check_needs_rebuild(
              position, neighbors.reference_position, current_box
          ),
          lambda pos: build_fn(pos, current_box),
          lambda pos: neighbors,
          position,
      )
    return build_fn(position, current_box)

  # Close the circular reference.
  update_fn_ref[0] = neighbor_list_fn

  return NeighborListMultiImageFns(allocate_fn, neighbor_list_fn)
def neighbor_list_multi_image_mask(
    neighbors: NeighborListMultiImage,
) -> Array:  # Sparse/OrderedSparse: [capacity]; Dense: [N, max_neighbors]
  r"""Compute a boolean mask for valid entries in a multi-image neighbor list.

  Mirrors ``jax_md.partition.neighbor_list_mask``. Padded (invalid) entries
  carry the index ``N`` (the number of atoms), so an entry is valid iff its
  index is less than ``N``.

  Args:
    neighbors: A ``NeighborListMultiImage`` in Sparse, OrderedSparse, or Dense
      format.

  Returns:
    Boolean mask, True for valid entries:

    - Sparse/OrderedSparse: shape ``[capacity]``, one entry per edge, tested
      on the receiver index ``idx[0]``.
    - Dense: shape ``[N, max_neighbors]``, tested on every neighbor slot.
  """
  N = len(neighbors.reference_position)
  if neighbors.format == NeighborListFormat.Dense:
    # Dense ``idx`` is a single [N, max_neighbors] array. Indexing it with [0]
    # would pick atom 0's neighbor row instead of a receiver array.
    return neighbors.idx < N  # [N, max_neighbors]
  # Sparse/OrderedSparse ``idx`` is the tuple (receivers, senders).
  return neighbors.idx[0] < N  # [capacity]
def graph_featurizer(displacement_fn=None):
  """Create a graph featurizer for multi-image neighbor lists.

  The returned function converts a ``NeighborListMultiImage`` into a
  ``jraph.GraphsTuple`` whose edge features are displacement vectors.
  Periodic images come from the explicit integer lattice shifts stored in the
  neighbor list, not from the minimum-image convention (MIC).

  The featurizer has the same call signature as
  ``jax_md._nn.util.neighbor_list_featurizer``. It can therefore be used as a
  drop-in replacement for it with multi-image neighbor lists.

  Args:
    displacement_fn: Displacement function, normally the first element of
      ``space.free()``. Periodicity is already encoded in the neighbor-list
      shifts, so this function must not apply a minimum-image wrap. A
      periodic displacement would map each shifted image back onto the
      nearest image. This requirement is NOT checked. If None, defaults to
      ``space.free()[0]``.

  Returns:
    ``featurize(atoms, positions, neighbor, **kwargs) -> GraphsTuple``.

  Example:
    >>> from jax_md import space
    >>> from jax_md.custom_partition import neighbor_list_multi_image, graph_featurizer
    >>> displacement_fn, _ = space.free()
    >>> neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff)
    >>> featurizer = graph_featurizer(displacement_fn)
    >>> nbrs = neighbor_fn.allocate(positions)
    >>> graph = featurizer(atoms, positions, nbrs)
  """
  if displacement_fn is None:
    displacement_fn, _ = space.free()

  def featurize(
      atoms, positions: Array, neighbor: NeighborListMultiImage, **kwargs
  ):
    """Convert a multi-image neighbor list to a graph with displacement edges.

    Args:
      atoms: Node features passed to ``partition.to_jraph`` as ``nodes``
        (e.g. ``{'species': ..., 'positions': ...}``).
      positions: Atom positions in fractional coordinates, shape ``[N, dim]``.
      neighbor: Multi-image neighbor list (Sparse or OrderedSparse format).
      **kwargs: Optional keyword arguments.

        - ``box``: Box used for the fractional -> real transform instead of
          ``neighbor.box`` (e.g. for NPT).
        - All remaining keywords (e.g. ``perturbation`` for stress via
          ``quantity.stress``) are forwarded to ``displacement_fn``.

    Returns:
      The graph from ``partition.to_jraph`` with ``edges`` replaced by
      displacement vectors ``(R_j + shift) - R_i`` of shape
      ``[capacity, dim]``. Padded edges are set to 1.0 in every component.
    """
    graph = partition.to_jraph(neighbor, nodes=atoms)
    mask = neighbor_list_multi_image_mask(neighbor)  # [capacity]

    # A caller-supplied box overrides the one stored at build time. Popping
    # it keeps it out of the kwargs forwarded to displacement_fn below.
    box = kwargs.pop('box', neighbor.box)

    # Convert fractional positions and integer shifts to real space.
    pos_real = space.transform(box, positions)  # [N, dim]
    Ra = pos_real[neighbor.receivers]  # [capacity, dim]
    Rb = pos_real[neighbor.senders]  # [capacity, dim]
    shifts_real = space.transform(box, neighbor.shifts)  # [capacity, dim]

    # dR = (R_j + shift) - R_i. displacement_fn receives the forwarded
    # kwargs, which is how a strain ``perturbation`` reaches it.
    Rb_shifted = Rb + shifts_real
    d = jax.vmap(partial(displacement_fn, **kwargs))
    dR = d(Ra, Rb_shifted)

    # Padded edges get a displacement of 1 rather than 0.
    # NOTE(review): presumably to avoid zero-length edges (and NaN gradients
    # of norms) downstream; confirm against neighbor_list_featurizer.
    dR = jnp.where(mask[:, None], dR, 1.0)
    return graph._replace(edges=dR)

  return featurize
def _compute_displacements(
    position: Array,  # [N, dim]
    neighbors: NeighborListMultiImage,
    fractional_coordinates: bool = True,
) -> Array:  # [capacity, dim]
  r"""Return the Cartesian displacement vector of every edge in ``neighbors``.

  An edge with receiver :math:`i`, sender :math:`j` and integer lattice shift
  :math:`\mathbf{s}` contributes

  .. math::

    \mathbf{d}_{ij}^{\mathbf{s}} = \mathbf{r}_j + \mathbf{s}_{\mathrm{real}}
        - \mathbf{r}_i,

  where :math:`\mathbf{s}_{\mathrm{real}}` is ``shifts @ box.T``, computed
  with the box stored on the neighbor list. Coordinates are converted with
  ``space.transform``. As a result, gradients taken w.r.t. fractional inputs
  come out as real-space gradients (see ``space.transform_jvp``).

  Args:
    position: Atom positions, shape ``[N, dim]``.
    neighbors: A ``NeighborListMultiImage`` in Sparse or OrderedSparse format.
    fractional_coordinates: Whether ``position`` is given in fractional
      coordinates. If False it is used as real-space coordinates directly.

  Returns:
    Cartesian displacement vectors, shape ``[capacity, dim]``. Padded
    (invalid) edges are zeroed. Use the neighbor-list mask to select the
    valid rows.
  """
  cell = neighbors.box  # [dim, dim]
  last_atom = position.shape[0] - 1
  is_valid = neighbor_list_multi_image_mask(neighbors)  # [capacity]

  if not fractional_coordinates:
    coords = position
  else:
    coords = space.transform(cell, position)  # [N, dim]

  # Padded edges carry index N. Clamp them into range so the gathers stay in
  # bounds; their values are discarded by the mask at the end.
  recv = jnp.clip(neighbors.receivers, 0, last_atom)  # [capacity]
  send = jnp.clip(neighbors.senders, 0, last_atom)  # [capacity]
  lattice_offset = space.transform(cell, neighbors.shifts)  # [capacity, dim]

  # Same evaluation order as r_j + shift - r_i.
  dR = coords[send] + lattice_offset - coords[recv]  # [capacity, dim]
  return jnp.where(is_valid[:, None], dR, 0.0)