Multi-Image Structure Maps
Structure-mapped functions for multi-image neighbor lists.
This module provides analogues to jax_md.smap functions that work with
NeighborListMultiImage to correctly handle small periodic boxes where
\(r_\text{cut} > L/2\).
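The key difference from standard smap functions is that these use explicit
lattice shifts stored in the neighbor list rather than relying on the minimum
image convention. The minimum image convention keeps only the nearest periodic
image of each neighbor, so when \(r_\text{cut} > L/2\) it misses additional
images of the same neighbor that also lie within the cutoff.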
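To see why explicit shifts are needed, the following NumPy sketch enumerates every periodic image of a neighbor that falls inside the cutoff and compares the result with a minimum image distance. The helpers `images_within_cutoff` and `minimum_image_distance` are hypothetical illustrations, not part of jax_md. The sketch assumes Cartesian positions and that the rows of the box matrix are lattice vectors.

```python
import itertools
import numpy as np

def images_within_cutoff(r_i, r_j, box, r_cut, n_max=2):
    """Hypothetical helper: (shift, distance) for each image of j within r_cut of i.

    Assumes Cartesian positions and that the rows of `box` are lattice
    vectors, so an integer shift s displaces j by s @ box. Only shifts with
    components in [-n_max, n_max] are searched.
    """
    hits = []
    for s in itertools.product(range(-n_max, n_max + 1), repeat=len(r_i)):
        s = np.array(s)
        d = np.linalg.norm(r_j + s @ box - r_i)
        if d < r_cut:
            hits.append((tuple(int(c) for c in s), float(d)))
    return hits

def minimum_image_distance(r_i, r_j, L):
    """Hypothetical helper: minimum image distance in a cubic box of side L."""
    dr = r_j - r_i
    dr = dr - L * np.round(dr / L)  # wrap each component into [-L/2, L/2]
    return float(np.linalg.norm(dr))

L, r_cut = 1.0, 0.8            # r_cut > L/2
box = np.eye(3) * L
r_i = np.array([0.0, 0.0, 0.0])
r_j = np.array([0.3, 0.0, 0.0])

hits = images_within_cutoff(r_i, r_j, box, r_cut)
print(hits)                                 # two images: shifts (-1, 0, 0) and (0, 0, 0)
print(minimum_image_distance(r_i, r_j, L))  # only the nearest image, about 0.3
```

Both the image at distance 0.3 and the one at 0.7 interact, but the minimum image convention can represent only the first.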
Pair Functions
- jax_md.custom_smap.pair_neighbor_list_multi_image(pair_fn, displacement_fn=None, species=None, reduce_axis=None, ignore_unused_parameters=False, fractional_coordinates=True, **static_kwargs)
Creates a function for pair potentials using multi-image neighbors.
This function is analogous to jax_md.smap.pair_neighbor_list but works with NeighborListMultiImage to correctly handle small periodic boxes where \(r_\text{cut} > L/2\).
Displacement Function:
Always uses space.free() for displacement computation, since periodicity is handled via explicit lattice shifts. The displacement_fn parameter is ignored (it exists for API compatibility with smap.pair_neighbor_list). Supports a perturbation kwarg for stress calculation via quantity.stress.
For each edge \((i, j)\) with shift \(\mathbf{s}\), computes:
\[E_{ij}^{\mathbf{s}} = f\left(\|\mathbf{r}_j + \mathbf{s} \cdot \mathbf{T} - \mathbf{r}_i\|\right)\]
where \(f\) is the pair function and \(\mathbf{T}\) is the box matrix.
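The per-edge formula can be sketched in NumPy as follows. `lj_12_6` and `multi_image_pair_energy` are hypothetical names; the sketch assumes Cartesian positions, rows of the box matrix as lattice vectors, and a Sparse edge list (both directions stored, hence the factor 1/2). The edge list is hand-built for illustration and is not a complete neighbor list for any particular cutoff.

```python
import numpy as np

def lj_12_6(dr, sigma=1.0, epsilon=1.0):
    """Plain 12-6 Lennard-Jones on an array of distances (no cutoff or smoothing)."""
    x6 = (sigma / dr) ** 6
    return 4.0 * epsilon * (x6 ** 2 - x6)

def multi_image_pair_energy(pair_fn, R, idx_i, idx_j, shifts, box, **kwargs):
    """Hypothetical helper: sum of f(|r_j + s @ T - r_i|) over a Sparse edge list."""
    dR = R[idx_j] + shifts @ box - R[idx_i]  # shifted displacement per edge
    dr = np.linalg.norm(dR, axis=-1)          # scalar distances, shape [capacity]
    return 0.5 * np.sum(pair_fn(dr, **kwargs))  # both directions stored

a = 2.0 ** (1.0 / 6.0)        # LJ minimum: f(a) = -epsilon
box = np.eye(3) * (2.0 * a)   # cubic box with side L = 2a
R = np.array([[0.0, 0.0, 0.0], [a, 0.0, 0.0]])

# Atom 1 is seen by atom 0 through two images (shifts 0 and -1), and vice versa.
idx_i = np.array([0, 0, 1, 1])
idx_j = np.array([1, 1, 0, 0])
shifts = np.array([[0, 0, 0], [-1, 0, 0], [0, 0, 0], [1, 0, 0]])

E = multi_image_pair_energy(lj_12_6, R, idx_i, idx_j, shifts, box)
print(E)  # approximately -2.0: two distinct images, each contributing -epsilon
```

A minimum image treatment of the same two atoms would count only one image and report about -1.0.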
The pair function should have the signature:
pair_fn(dr, **kwargs) -> energy
where dr is an array of scalar pairwise distances of shape [capacity]. This matches the signature expected by smap.pair_neighbor_list, so the same pair function (e.g., energy.lennard_jones) works with both.
Gradient handling:
Uses space.transform for coordinate transformations, which has a custom JVP that keeps gradients in the same coordinate system as the inputs. When using fractional coordinates, jax.grad(energy_fn) returns the energy gradient (the negative of the force) in fractional coordinates, which is compatible with fractional-coordinate dynamics.
Format handling:
Sparse: Both \(i \to j\) and \(j \to i\) are stored, so energies are divided by 2. Supports per-particle energies.
OrderedSparse: Only one direction per pair is stored, so no division is needed. Does not support per-particle energies (raises ValueError).
Dense: Stores neighbors per atom as [N, max_neighbors]. Both directions are stored, so energies are divided by 2. Supports per-particle energies via a sum over axis 1.
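The halving rules for the three formats can be checked with a small NumPy sketch. The pair energies are made-up numbers, and scattering half of each Sparse edge onto its first index (and padding unused Dense slots with zero energy) is an assumed mirror of the bookkeeping, not this module's code.

```python
import numpy as np

N = 3
# Hypothetical pair energies for pairs (0, 1) and (1, 2)
ordered_i = np.array([0, 1])
ordered_j = np.array([1, 2])
ordered_e = np.array([-1.0, -0.5])

# OrderedSparse: one entry per pair, summed as-is
E_ordered = ordered_e.sum()

# Sparse: both directions stored (0->1, 1->2, 1->0, 2->1), so the sum is halved
sparse_i = np.concatenate([ordered_i, ordered_j])
sparse_e = np.concatenate([ordered_e, ordered_e])
E_sparse = 0.5 * sparse_e.sum()

# Per-atom energies from Sparse: scatter half of each edge onto its first index
E_atom_sparse = np.zeros(N)
np.add.at(E_atom_sparse, sparse_i, 0.5 * sparse_e)

# Dense: [N, max_neighbors] edge energies, unused slots padded with 0.0
dense_e = np.array([
    [-1.0,  0.0],   # atom 0: neighbor 1, one padded slot
    [-1.0, -0.5],   # atom 1: neighbors 0 and 2
    [-0.5,  0.0],   # atom 2: neighbor 1, one padded slot
])
E_atom_dense = 0.5 * dense_e.sum(axis=1)  # per-atom energies via sum over axis 1

print(E_ordered, E_sparse)  # both -1.5
print(E_atom_sparse)        # -0.5, -0.75, -0.25 per atom
print(E_atom_dense)         # same per-atom values as Sparse
```

OrderedSparse has no natural per-atom split without choosing which endpoint receives each pair's energy, which is consistent with it raising ValueError for per-particle energies.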
- Parameters:
pair_fn (Callable[..., Array]) – A function that computes pairwise energies from distances. Examples: energy.lennard_jones, energy.morse, energy.soft_sphere. Signature: (dr: Array[capacity], **kwargs) -> Array[capacity].
displacement_fn – Ignored. Always uses space.free() internally.
species (Array | None) – Optional species array of shape [N]. If provided, kwargs like sigma and epsilon should have shape [max_species, max_species] and will be indexed per pair.
fractional_coordinates (bool) – If True, positions are in fractional coordinates.
reduce_axis (Optional[Tuple[int, ...]]) – Axis over which to reduce the energy. If None (default), sums all pair energies to a scalar. If specified, returns per-atom energies of shape [N]. Note: per-atom energies are not supported with the OrderedSparse format (raises ValueError).
**static_kwargs – Static parameters passed to the pair function (e.g., sigma, epsilon). Can be overridden at call time.
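The pair_fn contract and per-species parameters can be illustrated together. `soft_sphere_2` is a hypothetical pair function, similar in form to energy.soft_sphere with exponent 2. It uses only arithmetic operators, so it behaves the same on NumPy arrays and on jax.numpy arrays. Looking up `sigma[species[i], species[j]]` for each edge is an assumed sketch of "indexed per pair", not the module's internal code.

```python
import numpy as np

def soft_sphere_2(dr, sigma=1.0, epsilon=1.0):
    """Hypothetical pair function with the (dr, **kwargs) -> energy contract.

    Only arithmetic operators are used, so the same code runs on NumPy arrays
    and on jax.numpy arrays under jax.grad or jax.jit.
    """
    overlap = 1.0 - dr / sigma
    return 0.5 * epsilon * overlap ** 2 * (dr < sigma)  # zero beyond sigma

# Species-dependent parameters of shape [max_species, max_species]
species = np.array([0, 0, 1])               # shape [N]
sigma = np.array([[1.0, 1.2], [1.2, 1.5]])

# Hypothetical Sparse edges (both directions) and their distances, shape [capacity]
idx_i = np.array([0, 1, 0, 2])
idx_j = np.array([1, 0, 2, 0])
dr = np.array([0.5, 0.5, 0.9, 0.9])

# Assumed per-pair lookup: sigma[species[i], species[j]] for each edge
sigma_edge = sigma[species[idx_i], species[idx_j]]
e_edge = soft_sphere_2(dr, sigma=sigma_edge, epsilon=1.0)

print(sigma_edge)          # 1.0 for the 0-0 pairs, 1.2 for the 0-1 pairs
print(0.5 * e_edge.sum())  # total energy, halved for the Sparse format
```

Because `sigma_edge` has the same shape as `dr`, the pair function broadcasts elementwise without knowing anything about species.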
- Returns:
energy_fn(R, neighbor, **kwargs) -> Array
Input R: Positions. Shape [N, dim].
Input neighbor: A NeighborListMultiImage.
Output: Total energy (scalar) or per-atom energies (shape [N]).
- Return type:
Example
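```python
import jax
import jax.numpy as jnp
from jax_md import energy
from jax_md.custom_partition import neighbor_list_multi_image
from jax_md.custom_smap import pair_neighbor_list_multi_image

# Placeholder system: a small cubic box with r_cutoff > L/2.
# The accepted box format (here a 3x3 matrix) is an assumption.
N = 8
box = jnp.eye(3) * 3.0
r_cutoff = 2.5
# fractional_coordinates=True is the default, so positions lie in [0, 1)
positions = jax.random.uniform(jax.random.PRNGKey(0), (N, 3))

# Create Lennard-Jones energy function for multi-image neighbors
lj_energy = pair_neighbor_list_multi_image(
    energy.lennard_jones,
    sigma=1.0,
    epsilon=1.0,
)

# Use with multi-image neighbor list
neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff, n_atoms=N)
nbrs = neighbor_fn.allocate(positions)
E = lj_energy(positions, nbrs)

# Compute forces via autodiff (F = -dE/dR)
force_fn = jax.grad(lambda R, nbrs: -lj_energy(R, nbrs))
F = force_fn(positions, nbrs)  # Shape: [N, dim]
```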