Source code for jax_md.custom_smap

"""Structure-mapped functions for multi-image neighbor lists.

This module provides analogues to ``jax_md.smap`` functions that work with
``NeighborListMultiImage`` to correctly handle small periodic boxes where
:math:`r_\\text{cut} > L/2`.

The key difference from standard ``smap`` functions is that these use explicit
lattice shifts stored in the neighbor list rather than relying on the minimum
image convention.
"""

from typing import Callable, Tuple, Optional
import functools
from functools import partial

import jax
import jax.numpy as jnp
from jax import ops
from jax_md import partition, space, util
from jax_md.custom_partition import (
  NeighborListMultiImage,
  NeighborListFormat,
  neighbor_list_multi_image_mask,
)

# Type aliases
# `Array` is the annotation used for array-valued arguments and return values
# in the signatures of this module.
Array = jnp.ndarray
# Short names for the 32-bit float and integer dtypes.
# NOTE(review): neither is referenced in the visible part of this module --
# confirm whether other code imports them before removing.
f32 = jnp.float32
i32 = jnp.int32


def _index_species_parameters(
  params: dict,
  species_i: Array,
  species_j: Array,
) -> dict:
  """Select per-pair values from species-indexed parameter tables.

  Every value in ``params`` for which ``jnp.ndim`` is 2 is treated as a table
  and indexed as ``val[species_i, species_j]``; every other value is passed
  through untouched. A new dict is returned; ``params`` is not modified.

  NOTE(review): the ``ndim == 2`` test is the only criterion, so any other
  2-D kwarg (for instance a ``perturbation`` matrix forwarded at call time)
  is indexed as well -- confirm the pair functions in use ignore such keys.

  Args:
    params: Keyword parameters destined for the pair function.
    species_i: Species of the central atom of each pair.
    species_j: Species of the neighbor atom of each pair. Must broadcast
      against ``species_i``.

  Returns:
    A dict with the same keys as ``params``.
  """
  return {
    key: (val[species_i, species_j] if jnp.ndim(val) == 2 else val)
    for key, val in params.items()
  }


def pair_neighbor_list_multi_image(
  pair_fn: Callable[..., Array],
  displacement_fn=None,  # Ignored; always uses space.free() internally
  species: Array | None = None,  # [N] or None
  reduce_axis: Tuple[int, ...] | None = None,
  ignore_unused_parameters: bool = False,  # For API compatibility
  fractional_coordinates: bool = True,
  **static_kwargs,
) -> Callable[[Array, NeighborListMultiImage], Array]:
  r"""Creates a function for pair potentials using multi-image neighbors.

  This function is analogous to ``jax_md.smap.pair_neighbor_list`` but works
  with ``NeighborListMultiImage`` to correctly handle small periodic boxes
  where :math:`r_\text{cut} > L/2`.

  **Displacement Function:** Always uses ``space.free()`` for displacement
  computation since periodicity is handled via explicit lattice shifts. The
  ``displacement_fn`` parameter is ignored (for API compatibility with
  ``smap.pair_neighbor_list``). Supports ``perturbation`` kwarg for stress
  calculation via ``quantity.stress``.

  For each edge :math:`(i, j)` with shift :math:`\mathbf{s}`, computes:

  .. math::

      E_{ij}^{\mathbf{s}} = f\left(\|\mathbf{r}_j + \mathbf{s} \cdot
      \mathbf{T} - \mathbf{r}_i\|\right)

  where :math:`f` is the pair function and :math:`\mathbf{T}` is the box
  matrix.

  The pair function should have the signature::

      pair_fn(dr, **kwargs) -> energy

  where ``dr`` is an array of scalar pairwise distances of shape
  ``[capacity]``. This matches the signature expected by
  ``smap.pair_neighbor_list``, so the same pair function (e.g.,
  ``energy.lennard_jones``) works with both.

  **Gradient handling:** Uses ``space.transform`` for coordinate
  transformations, which has a custom JVP that keeps gradients in the same
  coordinate system as inputs. When using fractional coordinates,
  ``jax.grad(energy_fn)`` returns forces in fractional coordinates
  (compatible with fractional-coordinate dynamics).

  **Format handling:**

  - ``Sparse``: Both :math:`i \to j` and :math:`j \to i` are stored, so
    energies are divided by 2. Supports per-particle energies.
  - ``OrderedSparse``: Only one direction per pair, no division needed.
    Does **not** support per-particle energies (raises ``ValueError``).
  - ``Dense``: Stores neighbors per atom as ``[N, max_neighbors]``. Both
    directions are stored, so energies are divided by 2. Supports
    per-particle energies via sum over axis 1.

  Args:
    pair_fn: A function that computes pairwise energies from distances.
      Examples: ``energy.lennard_jones``, ``energy.morse``,
      ``energy.soft_sphere``.
      Signature: ``(dr: Array[capacity], **kwargs) -> Array[capacity]``.
    displacement_fn: Ignored. Always uses ``space.free()`` internally.
    species: Optional species array. Shape ``[N]``. If provided, kwargs like
      ``sigma`` and ``epsilon`` should have shape
      ``[max_species, max_species]`` and will be indexed per-pair. Can be
      overridden at call time with a ``species`` keyword argument.
    reduce_axis: Axis over which to reduce the energy. If ``None`` (default),
      sums all pair energies to a scalar. If specified, returns per-atom
      energies of shape ``[N]``; only whether the value is ``None`` is
      inspected. **Note:** Per-atom energies are not supported with
      ``OrderedSparse`` format (raises ``ValueError``).
    ignore_unused_parameters: Accepted for API compatibility; not used by
      this implementation.
    fractional_coordinates: If True, positions are in fractional coordinates.
    **static_kwargs: Static parameters passed to the pair function (e.g.,
      ``sigma``, ``epsilon``). Can be overridden at call time.

  Returns:
    An energy function with signature:
    ``energy_fn(R, neighbor, **kwargs) -> Array``

    - Input ``R``: Positions. Shape ``[N, dim]``.
    - Input ``neighbor``: A ``NeighborListMultiImage``.
    - Output: Total energy (scalar) or per-atom energies (shape ``[N]``).

  Raises:
    ValueError: When the returned function is called with an
      ``OrderedSparse`` neighbor list while ``reduce_axis`` is not ``None``.

  Example:
    .. code-block:: python

      from jax_md import energy
      from jax_md.custom_partition import neighbor_list_multi_image
      from jax_md.custom_smap import pair_neighbor_list_multi_image

      # Create Lennard-Jones energy function for multi-image neighbors
      lj_energy = pair_neighbor_list_multi_image(
        energy.lennard_jones,
        sigma=1.0,
        epsilon=1.0,
      )

      # Use with multi-image neighbor list
      neighbor_fn = neighbor_list_multi_image(None, box, r_cutoff, n_atoms=N)
      nbrs = neighbor_fn.allocate(positions)
      E = lj_energy(positions, nbrs)

      # Compute forces via autodiff
      force_fn = jax.grad(lambda R, nbrs: -lj_energy(R, nbrs))
      F = force_fn(positions, nbrs)  # Shape: [N, dim]
  """
  # Use space.free() displacement which handles perturbation via kwargs.
  # Delete the passed displacement_fn since it is ignored.
  del displacement_fn
  displacement_fn, _ = space.free()

  def energy_fn(
    R: Array,  # [N, dim]
    neighbor: NeighborListMultiImage,
    **kwargs,
  ) -> Array:  # scalar or [N]
    """Compute the total (or per-particle) pair energy."""
    # Call-time kwargs take precedence over the static ones.
    merged_kwargs = {**static_kwargs, **kwargs}
    _species = merged_kwargs.pop('species', species)
    per_particle = reduce_axis is not None

    box = neighbor.box  # [dim, dim]
    N = R.shape[0]

    # Compute Cartesian positions using space.transform for correct gradients.
    # Note: space.transform has a custom JVP that keeps gradients in the same
    # coordinate system as inputs (fractional -> fractional forces).
    R_real = space.transform(box, R) if fractional_coordinates else R

    # Displacement over a flat batch of position pairs. The kwargs are bound
    # here, before any species indexing, so the displacement function always
    # sees the parameters exactly as the caller supplied them.
    d = jax.vmap(partial(displacement_fn, **merged_kwargs))

    if partition.is_sparse(neighbor.format):
      # Sparse/OrderedSparse: edge-list format.
      # idx shape: [2, capacity], shifts shape: [capacity, dim]
      is_ordered = neighbor.format is NeighborListFormat.OrderedSparse
      if per_particle and is_ordered:
        # Fail before doing any work: the request can never be satisfied.
        raise ValueError(
          'Cannot compute per-particle energies with OrderedSparse format. '
          'OrderedSparse stores only one direction per pair, so segment_sum '
          'would assign the full pair energy to the receiver atom only. '
          'Use Sparse or Dense format for per-particle energies.'
        )

      mask = neighbor_list_multi_image_mask(neighbor)  # [capacity]

      # Clip so that padding entries still index a valid row; their energy
      # is zeroed out by `mask` below.
      i_safe = jnp.clip(neighbor.receivers, 0, N - 1)  # [capacity]
      j_safe = jnp.clip(neighbor.senders, 0, N - 1)  # [capacity]

      shifts_real = space.transform(box, neighbor.shifts)  # [capacity, dim]
      dR = d(R_real[i_safe], R_real[j_safe] + shifts_real)  # [capacity, dim]
      dr = space.distance(dR)  # [capacity]

      if _species is not None:
        merged_kwargs = _index_species_parameters(
          merged_kwargs, _species[i_safe], _species[j_safe]
        )

      pair_energies = pair_fn(dr, **merged_kwargs)  # [capacity]
      zero = jnp.zeros((), dtype=pair_energies.dtype)
      pair_energies = jnp.where(mask, pair_energies, zero)  # [capacity]

      # OrderedSparse stores one direction per pair, Sparse stores both.
      normalization = 1.0 if is_ordered else 2.0

      if not per_particle:
        return util.high_precision_sum(pair_energies) / normalization

      # Per-particle energy: accumulate each edge onto its receiver. Masked
      # edges already contribute exactly zero.
      particle_energies = ops.segment_sum(
        pair_energies, neighbor.receivers, N
      )  # [N]
      return particle_energies / normalization

    # Dense format: per-atom neighbor arrays.
    # idx shape: [N, max_neighbors], shifts shape: [N, max_neighbors, dim]
    idx = neighbor.idx  # [N, max_neighbors]
    shifts = neighbor.shifts  # [N, max_neighbors, dim]
    mask = idx < N  # [N, max_neighbors]

    # Get neighbor positions with safe indexing.
    j_safe = jnp.clip(idx, 0, N - 1)  # [N, max_neighbors]
    R_neigh = R_real[j_safe]  # [N, max_neighbors, dim]

    # Transform shifts to real space; flattened to [N*max_neighbors, dim]
    # for the transform and restored afterwards.
    shifts_real = space.transform(
      box, shifts.reshape(-1, shifts.shape[-1])
    ).reshape(shifts.shape)  # [N, max_neighbors, dim]

    # Pair every central atom with each of its (shifted) neighbors.
    Ra = jnp.broadcast_to(
      R_real[:, None, :], R_neigh.shape
    )  # [N, max_neighbors, dim]
    Rb_shifted = R_neigh + shifts_real  # [N, max_neighbors, dim]

    # Flatten for vmap, compute displacements, reshape back. The target shape
    # is spelled out (rather than inferred with -1) so that a neighbor list
    # with zero entries still reshapes.
    dR_flat = d(
      Ra.reshape(-1, Ra.shape[-1]),
      Rb_shifted.reshape(-1, Rb_shifted.shape[-1]),
    )  # [N*max_neighbors, dim]
    dR = dR_flat.reshape(Rb_shifted.shape)  # [N, max_neighbors, dim]
    dr = space.distance(dR)  # [N, max_neighbors]

    if _species is not None:
      species_i = jnp.broadcast_to(
        _species[:, None], idx.shape
      )  # [N, max_neighbors]
      species_j = _species[j_safe]  # [N, max_neighbors]
      merged_kwargs = _index_species_parameters(
        merged_kwargs, species_i, species_j
      )

    pair_energies = pair_fn(dr, **merged_kwargs)  # [N, max_neighbors]
    zero = jnp.zeros((), dtype=pair_energies.dtype)
    pair_energies = jnp.where(mask, pair_energies, zero)  # [N, max_neighbors]

    # Dense stores both directions, so divide by 2.
    normalization = 2.0

    if not per_particle:
      return util.high_precision_sum(pair_energies) / normalization

    # Per-particle energy: sum over neighbors (axis 1).
    particle_energies = util.high_precision_sum(pair_energies, axis=1)  # [N]
    return particle_energies / normalization

  return energy_fn