Multi-Image Structure Maps

Structure-mapped functions for multi-image neighbor lists.

This module provides analogues of the jax_md.smap functions that work with NeighborListMultiImage to correctly handle small periodic boxes, where the cutoff exceeds half the box length (\(r_\text{cut} > L/2\)).

The key difference from standard smap functions is that these use explicit lattice shifts stored in the neighbor list rather than relying on the minimum image convention.
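To see why the minimum image convention breaks down, the following sketch enumerates lattice shifts in a 1D box with \(L = 3\) and \(r_\text{cut} = 2\). It is plain JAX, independent of jax_md, and `images_within_cutoff` and `minimum_image_distance` are hypothetical helpers written for this illustration. Two images of the same neighbor lie inside the cutoff, but the minimum image convention keeps only the closer one.

```python
import itertools

import jax.numpy as jnp


def images_within_cutoff(r_i, r_j, box, r_cut, max_shift=1):
    """Distances from r_i to every periodic image of r_j closer than r_cut.

    Enumerates integer shifts s in [-max_shift, max_shift]^dim and measures
    |r_j + s @ box - r_i| in real space. max_shift must be large enough to
    reach every image inside the cutoff.
    """
    dim = r_i.shape[0]
    dists = []
    for s in itertools.product(range(-max_shift, max_shift + 1), repeat=dim):
        shift = jnp.array(s, dtype=r_i.dtype) @ box
        d = float(jnp.linalg.norm(r_j + shift - r_i))
        if d < r_cut:
            dists.append(d)
    return sorted(dists)


def minimum_image_distance(r_i, r_j, box_length):
    """Single distance under the minimum image convention (cubic box)."""
    dr = r_j - r_i
    dr = dr - box_length * jnp.round(dr / box_length)
    return float(jnp.linalg.norm(dr))


# 1D box of length L = 3 with r_cut = 2 > L/2.
r_i, r_j = jnp.array([0.0]), jnp.array([1.2])
print(images_within_cutoff(r_i, r_j, jnp.array([[3.0]]), r_cut=2.0))  # two images, ~1.2 and ~1.8
print(minimum_image_distance(r_i, r_j, 3.0))                           # only ~1.2
```

Storing the shift \(s = -1\) as a separate edge is what lets the second image contribute to the energy.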

Pair Functions

jax_md.custom_smap.pair_neighbor_list_multi_image(pair_fn, displacement_fn=None, species=None, reduce_axis=None, ignore_unused_parameters=False, fractional_coordinates=True, **static_kwargs)

Creates a function for pair potentials using multi-image neighbors.

This function is analogous to jax_md.smap.pair_neighbor_list but works with NeighborListMultiImage to correctly handle small periodic boxes where \(r_\text{cut} > L/2\).

Displacement Function:

Always uses space.free() for displacement computation, since periodicity is handled by the explicit lattice shifts. The displacement_fn parameter is ignored and is accepted only for API compatibility with smap.pair_neighbor_list. The perturbation kwarg is supported, so stresses can be computed with quantity.stress.
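The role of a strain perturbation can be illustrated with a small sketch. It assumes a homogeneous strain \(\varepsilon\) applied to every displacement as \((I + \varepsilon)\,\mathbf{d}\) and differentiates the energy at \(\varepsilon = 0\). It does not reproduce jax_md's perturbation handling, nor the sign and volume normalization used by quantity.stress; `harmonic_pair` and `strained_energy` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def harmonic_pair(dr):
    # Illustrative pair energy phi(r) = 0.5 * (r - 1)^2.
    return 0.5 * (dr - 1.0) ** 2


def strained_energy(eps, disp):
    """Total pair energy after applying a homogeneous strain to each displacement.

    disp: [n_edges, dim] displacement vectors (lattice shifts already applied).
    eps:  [dim, dim] strain; each displacement d becomes (I + eps) @ d.
    """
    dim = disp.shape[-1]
    strained = disp @ (jnp.eye(dim) + eps).T
    dr = jnp.linalg.norm(strained, axis=-1)
    return jnp.sum(harmonic_pair(dr))


disp = jnp.array([[1.5, 0.0, 0.0],
                  [0.3, 0.8, -0.2]])

# Derivative of the energy with respect to strain, evaluated at eps = 0.
dE_deps = jax.grad(strained_energy)(jnp.zeros((3, 3)), disp)
```

The autodiff result equals the analytic strain derivative \(\sum_e \phi'(r_e)\, d_{e,a}\, d_{e,b} / r_e\), which is the virial-like quantity a stress calculation is built from.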

For each edge \((i, j)\) with shift \(\mathbf{s}\), computes:

\[E_{ij}^{\mathbf{s}} = f\left(\|\mathbf{r}_j + \mathbf{s} \cdot \mathbf{T} - \mathbf{r}_i\|\right)\]

where \(f\) is the pair function and \(\mathbf{T}\) is the box matrix.
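The formula above can be evaluated directly for an explicit edge list. This is a minimal illustration, not the library's implementation: `multi_image_energy` and `lj` are hypothetical, \(\mathbf{s} \cdot \mathbf{T}\) is read as the row vector \(\mathbf{s}\) times \(\mathbf{T}\) (`shifts @ T`), and each edge is stored in one direction only.

```python
import jax.numpy as jnp


def lj(dr, sigma=1.0, epsilon=1.0):
    # Lennard-Jones pair energy, vectorized over an array of distances.
    sr6 = (sigma / dr) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)


def multi_image_energy(R_frac, senders, receivers, shifts, T, pair_fn):
    """Sum f(|r_j + s @ T - r_i|) over edges (i, j, s).

    R_frac:    [N, dim] fractional positions.
    senders:   [E] indices i; receivers: [E] indices j.
    shifts:    [E, dim] integer lattice shifts s.
    T:         [dim, dim] box matrix (assumed row convention: r = f @ T).
    """
    R = R_frac @ T
    dR = R[receivers] + shifts.astype(T.dtype) @ T - R[senders]
    dr = jnp.linalg.norm(dR, axis=-1)
    return jnp.sum(pair_fn(dr))


# 1D box of length 3: atom 1 has two images inside the cutoff of atom 0.
T = jnp.array([[3.0]])
R_frac = jnp.array([[0.0], [0.4]])
senders = jnp.array([0, 0])
receivers = jnp.array([1, 1])
shifts = jnp.array([[0], [-1]])  # distances 1.2 and 1.8
E = multi_image_energy(R_frac, senders, receivers, shifts, T, lj)
```

Both images of the same pair \((0, 1)\) appear as separate edges, which a minimum-image displacement could not represent.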

The pair function should have the signature:

pair_fn(dr, **kwargs) -> energy

where dr is an array of scalar pairwise distances of shape [capacity]. This matches the signature expected by smap.pair_neighbor_list, so the same pair function (e.g., energy.lennard_jones) works with both.
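Any function with this signature can be passed as pair_fn. The soft-sphere form below is written from scratch as a hypothetical example (`soft_sphere_pair` is not a jax_md function). It operates elementwise on a [capacity] array, so its parameters may be scalars or per-pair arrays that broadcast against dr.

```python
import jax.numpy as jnp


def soft_sphere_pair(dr, sigma=1.0, epsilon=1.0, alpha=2.0):
    """epsilon / alpha * (1 - dr / sigma)^alpha for dr < sigma, else 0.

    dr has shape [capacity]; sigma and epsilon may be scalars or arrays of
    shape [capacity] (they broadcast elementwise).
    """
    # Clamping at 0 zeroes the energy beyond sigma without a jnp.where branch.
    x = jnp.maximum(1.0 - dr / sigma, 0.0)
    return epsilon / alpha * x ** alpha


dr = jnp.array([0.5, 1.5])
e_scalar = soft_sphere_pair(dr)                                # scalar sigma
e_per_pair = soft_sphere_pair(dr, sigma=jnp.array([1.0, 2.0]))  # per-pair sigma
```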

Gradient handling:

Uses space.transform for coordinate transformations, which has a custom JVP that keeps gradients in the same coordinate system as inputs. When using fractional coordinates, jax.grad(energy_fn) returns forces in fractional coordinates (compatible with fractional-coordinate dynamics).
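A pass-through JVP of this kind can be sketched with jax.custom_jvp. `to_real` is a hypothetical stand-in for space.transform, written under the assumption that positions are row vectors mapped as \(\mathbf{r} = \mathbf{f}\,\mathbf{T}\). Its rule forwards the position tangent unchanged instead of multiplying it by \(\mathbf{T}\), so jax.grad with respect to the fractional positions returns \(\partial E / \partial \mathbf{r}\) itself rather than the plain chain-rule result \((\partial E / \partial \mathbf{r})\,\mathbf{T}^\top\).

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def to_real(T, f):
    # Map fractional coordinates f ([N, dim]) to real space (row convention).
    return f @ T


@to_real.defjvp
def to_real_jvp(primals, tangents):
    T, f = primals
    dT, df = tangents
    # Position tangent passes through unchanged; the box tangent is
    # propagated normally so box derivatives are still available.
    return to_real(T, f), df + f @ dT


T = jnp.array([[2.0, 0.0], [0.5, 3.0]])
f = jnp.array([[0.1, 0.2], [0.4, 0.7]])


def energy(f):
    r = to_real(T, f)
    return 0.5 * jnp.sum(r ** 2)  # dE/dr = r


g_custom = jax.grad(energy)(f)                                  # equals r = f @ T
g_chain = jax.grad(lambda f: 0.5 * jnp.sum((f @ T) ** 2))(f)    # equals r @ T.T
```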

Format handling:

  • Sparse: Both \(i \to j\) and \(j \to i\) are stored, so energies are divided by 2. Supports per-particle energies.

  • OrderedSparse: Only one direction is stored per pair, so no division is needed. Does not support per-particle energies (raises ValueError).

  • Dense: Stores neighbors per atom as [N, max_neighbors]. Both directions are stored, so energies are divided by 2. Supports per-particle energies via a sum over axis 1.
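The bookkeeping for the three formats can be checked on a toy system. The sketch omits lattice shifts and padding, and its index arrays are hypothetical rather than fields of NeighborListMultiImage. Halving remains correct with shifts, because the reverse of edge \((i, j, \mathbf{s})\) is \((j, i, -\mathbf{s})\): \(\mathbf{r}_i - \mathbf{s} \cdot \mathbf{T} - \mathbf{r}_j = -(\mathbf{r}_j + \mathbf{s} \cdot \mathbf{T} - \mathbf{r}_i)\), which has the same length.

```python
import jax
import jax.numpy as jnp


def pair_energy(dr):
    # Toy pair energy used only for this illustration.
    return (dr - 1.0) ** 2


N = 3
# Symmetric pair distances for a 3-atom toy system (diagonal unused).
D = jnp.array([[0.0, 1.2, 0.9],
               [1.2, 0.0, 1.5],
               [0.9, 1.5, 0.0]])

# Sparse: every pair stored in both directions, so halve the sum.
i_sp = jnp.array([0, 1, 0, 2, 1, 2])
j_sp = jnp.array([1, 0, 2, 0, 2, 1])
e_sp = pair_energy(D[i_sp, j_sp])
E_sparse = 0.5 * jnp.sum(e_sp)
per_atom_sparse = 0.5 * jax.ops.segment_sum(e_sp, i_sp, num_segments=N)

# OrderedSparse: one direction per pair, no halving (no per-atom split).
i_os = jnp.array([0, 0, 1])
j_os = jnp.array([1, 2, 2])
E_ordered = jnp.sum(pair_energy(D[i_os, j_os]))

# Dense: [N, max_neighbors] neighbor indices, both directions present.
nbr_idx = jnp.array([[1, 2], [0, 2], [0, 1]])
e_dense = pair_energy(D[jnp.arange(N)[:, None], nbr_idx])
per_atom_dense = 0.5 * jnp.sum(e_dense, axis=1)
E_dense = jnp.sum(per_atom_dense)
```

All three totals agree, and the Sparse and Dense per-atom energies coincide.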

Parameters:
  • pair_fn (Callable[..., Array]) – A function that computes pairwise energies from distances. Examples: energy.lennard_jones, energy.morse, energy.soft_sphere. Signature: (dr: Array[capacity], **kwargs) -> Array[capacity].

  • displacement_fn – Ignored. Always uses space.free() internally.

  • species (Array | None) – Optional species array. Shape [N]. If provided, kwargs like sigma and epsilon should have shape [max_species, max_species] and will be indexed per-pair.

  • fractional_coordinates (bool) – If True, positions are in fractional coordinates.

  • reduce_axis (Optional[Tuple[int, ...]]) – Axis over which to reduce the energy. If None (default), sums all pair energies to a scalar. If specified, returns per-atom energies of shape [N]. Note: Per-atom energies are not supported with OrderedSparse format (raises ValueError).

  • **static_kwargs – Static parameters passed to the pair function (e.g., sigma, epsilon). Can be overridden at call time.
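Per-pair indexing of species-dependent parameters can be sketched as follows, assuming edges are given as sender/receiver index arrays (hypothetical names, not the neighbor list's actual fields). Each edge \((i, j)\) picks up `sigma[species[i], species[j]]`, producing a [capacity]-shaped array that a pair function consumes elementwise.

```python
import jax.numpy as jnp

species = jnp.array([0, 1, 1, 0])   # [N] species index of each atom
sigma = jnp.array([[1.0, 1.2],      # [max_species, max_species], symmetric
                   [1.2, 1.5]])

# Hypothetical edge list: edge k connects senders[k] -> receivers[k].
senders = jnp.array([0, 1, 2, 3])
receivers = jnp.array([1, 2, 3, 0])
dr = jnp.array([1.1, 1.4, 1.3, 0.9])  # [capacity] pair distances

# Look up the parameter for each edge from the species of its endpoints.
sigma_pair = sigma[species[senders], species[receivers]]

# Per-pair parameters broadcast elementwise inside a pair function.
repulsion = (sigma_pair / dr) ** 12
```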

Returns:

energy_fn(R, neighbor, **kwargs) -> Array

  • Input R: Positions. Shape [N, dim].

  • Input neighbor: A NeighborListMultiImage.

  • Output: Total energy (scalar) or per-atom energies (shape [N]).

Return type:

Callable[..., Array]

Example

import jax

from jax_md import energy
from jax_md.custom_partition import neighbor_list_multi_image
from jax_md.custom_smap import pair_neighbor_list_multi_image

# Assumed to be defined beforehand: box, r_cutoff, N, and positions
# (fractional coordinates by default, shape [N, dim]).

# Create Lennard-Jones energy function for multi-image neighbors
lj_energy = pair_neighbor_list_multi_image(
    energy.lennard_jones,
    sigma=1.0,
    epsilon=1.0,
)

# Use with multi-image neighbor list
neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff, n_atoms=N)
nbrs = neighbor_fn.allocate(positions)
E = lj_energy(positions, nbrs)

# Compute forces via autodiff (negative gradient w.r.t. the first argument, R)
force_fn = jax.grad(lambda R, nbrs: -lj_energy(R, nbrs))
F = force_fn(positions, nbrs)  # Shape: [N, dim]