Download Notebook Download Python Script

Multi-Image Neighbor Lists for Small Periodic Boxes#

This tutorial demonstrates how to use multi-image neighbor lists in JAX-MD for systems where the cutoff radius is larger than half the box length (\(r_\text{cut} > L/2\)).

The Problem with Standard Neighbor Lists#

Standard neighbor lists in JAX-MD use the Minimum Image Convention (MIC), which assumes each atom interacts with at most one periodic image of every other atom. This works well when:

\[r_\text{cut} < \frac{L}{2}\]

However, for small periodic boxes (common in ab initio MD or machine learning potentials with longer cutoffs), an atom may interact with multiple periodic images of the same neighbor. The multi-image neighbor list explicitly enumerates all images within the cutoff.

Imports & Setup#

[1]:
import os

# True when the COLAB_RELEASE_TAG environment variable is present, which is
# used here as the marker for running inside Google Colab.
IN_COLAB = 'COLAB_RELEASE_TAG' in os.environ
if IN_COLAB:
  import subprocess
  import sys

  # Install JAX-MD from its git repository with the interpreter that is
  # running this notebook. check=True makes a failed install raise
  # CalledProcessError right here instead of surfacing later as an
  # ImportError on `jax_md`.
  subprocess.run(
    [
      sys.executable,
      '-m',
      'pip',
      'install',
      '-q',
      'git+https://github.com/jax-md/jax-md.git',
    ],
    check=True,
  )

import numpy as onp

from jax import config

config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax import random, jit, lax

import time
import matplotlib.pyplot as plt
import seaborn as sns

from jax_md import space, energy, partition, quantity, simulate
from jax_md.custom_partition import neighbor_list_multi_image
from jax_md.custom_smap import pair_neighbor_list_multi_image

# Reduced-size run switch: truthy whenever the READTHEDOCS environment
# variable is set to a non-empty string (presumably by the docs build --
# confirm). Later cells read it to pick a smaller box and fewer steps.
SMOKE_TEST = os.environ.get('READTHEDOCS', False)

# Use seaborn's plain 'white' style for every figure drawn below.
sns.set_style(style='white')


def format_plot(x, y):
  """Label both axes of the current matplotlib axes.

  Args:
    x: Text for the x-axis label.
    y: Text for the y-axis label.
  """
  label_size = 20
  plt.xlabel(x, fontsize=label_size)
  plt.ylabel(y, fontsize=label_size)


def finalize_plot(shape=(1, 1)):
  """Resize the current figure and tighten its layout.

  Both the new width and the new height are derived from the figure's
  current *height*, scaled by 1.5 and by the matching entry of `shape`.

  Args:
    shape: Pair of (width, height) multipliers.
  """
  fig = plt.gcf()
  # Read the height once, before resizing, so both dimensions use the same
  # reference value.
  height = fig.get_size_inches()[1]
  fig.set_size_inches(shape[0] * 1.5 * height, shape[1] * 1.5 * height)
  plt.tight_layout()

Helper: Create Crystal Structures#

[2]:
def make_fcc(n_cells, a=1.0):
  """Build an FCC supercell in fractional coordinates.

  Args:
    n_cells: Number of conventional unit cells along each axis.
    a: Lattice constant of one conventional cell.

  Returns:
    R: Fractional positions, 4 atoms per cell, ordered cell by cell with the
      last cell index varying fastest.
    box: Diagonal [3, 3] box matrix with `n_cells * a` on the diagonal.
  """
  # Conventional-cell basis: the corner atom plus three face centres.
  fcc_basis = onp.array(
    [
      [0.0, 0.0, 0.0],
      [0.5, 0.5, 0.0],
      [0.5, 0.0, 0.5],
      [0.0, 0.5, 0.5],
    ]
  )

  # Shift the basis into every cell, then divide by n_cells so the result is
  # expressed as a fraction of the whole supercell.
  cell_range = range(n_cells)
  frac = [
    (onp.array([ix, iy, iz]) + offset) / n_cells
    for ix in cell_range
    for iy in cell_range
    for iz in cell_range
    for offset in fcc_basis
  ]

  side = n_cells * a
  return jnp.array(onp.array(frac)), jnp.array(onp.eye(3) * side)


def make_diamond_cubic(n_cells, a=5.43):
  """Build a diamond-cubic supercell from the 2-atom primitive cell.

  The cell is the FCC primitive cell with lattice vectors
  a1=(0,1,1)a/2, a2=(1,0,1)a/2, a3=(1,1,0)a/2 and a two-atom basis at
  (0,0,0) and (1/4,1/4,1/4) in fractional coordinates of that cell.

  Args:
    n_cells: Number of primitive cells along each lattice direction.
    a: Conventional cubic lattice constant (the default 5.43 is the value
      used for silicon in this tutorial, in Angstrom).

  Returns:
    R: Fractional positions, 2 atoms per cell, ordered cell by cell with the
      last cell index varying fastest.
    box: [3, 3] box matrix holding the primitive vectors scaled by
      `n_cells` (the matrix is symmetric, so rows and columns coincide).
  """
  # Two-atom basis in fractional coordinates of the primitive cell.
  two_atom_basis = onp.array(
    [
      [0.0, 0.0, 0.0],
      [0.25, 0.25, 0.25],
    ]
  )

  # Shift the basis into every cell, then rescale to fractions of the
  # supercell.
  cell_range = range(n_cells)
  frac = [
    (onp.array([ix, iy, iz]) + offset) / n_cells
    for ix in cell_range
    for iy in cell_range
    for iz in cell_range
    for offset in two_atom_basis
  ]

  # FCC primitive lattice vectors, first for one cell, then for the supercell.
  half_a = a / 2.0
  primitive = half_a * onp.array(
    [
      [0.0, 1.0, 1.0],
      [1.0, 0.0, 1.0],
      [1.0, 1.0, 0.0],
    ]
  )
  supercell = primitive * n_cells

  return jnp.array(onp.array(frac)), jnp.array(supercell)

Example 1: Lennard-Jones with All Three Neighbor List Formats#

We compute LJ energy for a small FCC argon crystal using all three neighbor list formats to verify they produce identical results:

  • Dense: Per-atom neighbor arrays [N, max_neighbors]

  • Sparse: Edge list [2, capacity] with both directions

  • OrderedSparse: Edge list with one direction per pair (most efficient)

[3]:
# Argon LJ parameters (reduced units: sigma=1, epsilon=1)
sigma = 1.0  # Length unit
epsilon = 1.0  # Energy unit
r_cutoff = 2.5 * sigma
r_onset = 2.0 * sigma

# Create small FCC argon crystal where r_cut > L/2
# In reduced units, equilibrium nearest-neighbor distance ≈ 2^(1/6) * sigma ≈ 1.12
# FCC lattice constant a = sqrt(2) * nearest_neighbor ≈ 1.58 in reduced units
n_cells = 2
a_reduced = 1.55  # Small box to test multi-image (r_cut/L > 0.5)
R, box = make_fcc(n_cells, a=a_reduced)
N = len(R)
# make_fcc returns a diagonal box, so one diagonal entry is the edge length.
L = float(box[0, 0])

print(f'System: {N} Ar atoms in {n_cells}x{n_cells}x{n_cells} FCC')
print(f'Box size L = {L:.2f}sigma, r_cutoff = {r_cutoff:.2f}sigma')
print(f'r_cutoff / L = {r_cutoff / L:.2f} (> 0.5: multi-image needed)')

# Setup displacement function
# (R holds fractional coordinates, hence fractional_coordinates=True.)
displacement_fn, shift_fn = space.periodic_general(
  box, fractional_coordinates=True
)

# Test all three formats
formats = [
  ('Dense', partition.Dense),
  ('Sparse', partition.Sparse),
  ('OrderedSparse', partition.OrderedSparse),
]

# Energy per format name, compared against the Dense value at the end.
energies = {}
for name, fmt in formats:
  # r_onset and r_cutoff are passed divided by sigma -- presumably the energy
  # function expects them in units of sigma; confirm against its docstring.
  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
    displacement_fn,
    box,
    sigma=sigma,
    epsilon=epsilon,
    r_onset=r_onset / sigma,
    r_cutoff=r_cutoff / sigma,
    fractional_coordinates=True,
    neighbor_list_fn=neighbor_list_multi_image,
    pair_neighbor_list_fn=pair_neighbor_list_multi_image,
    format=fmt,
  )

  nbrs = neighbor_fn.allocate(R)
  E = float(energy_fn(R, nbrs))
  energies[name] = E

  # Get neighbor count
  # Only entries with an index below N are counted -- presumably padding
  # slots carry the index N; confirm. Sparse formats are counted from the
  # first row of the edge list, Dense from the whole neighbor array.
  if partition.is_sparse(fmt):
    n_edges = int(jnp.sum(nbrs.idx[0] < N))
  else:
    n_edges = int(jnp.sum(nbrs.idx < N))

  print(f'{name:15s}: E = {E:12.6f}, edges = {n_edges}')

# Verify all formats give the same energy
# (Dense is the reference; this loop rebinds `name` and `E`.)
E_ref = energies['Dense']
for name, E in energies.items():
  assert abs(E - E_ref) < 1e-5, f'{name} energy mismatch: {E} vs {E_ref}'
System: 32 Ar atoms in 2x2x2 FCC
Box size L = 3.10sigma, r_cutoff = 2.50sigma
r_cutoff / L = 0.81 (> 0.5: multi-image needed)
Dense          : E =  -250.355417, edges = 4288
/home/docs/checkouts/readthedocs.org/user_builds/jax-md/envs/main/lib/python3.12/site-packages/jax/_src/ops/scatter.py:104: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion=standard. In future JAX releases this will result in an error.
  warnings.warn(
Sparse         : E =  -250.355417, edges = 4288
OrderedSparse  : E =  -250.355417, edges = 2144

Forces and Stress Computation#

We can compute forces using quantity.force and stress using quantity.stress. The multi-image neighbor list with graph_featurizer supports the perturbation kwarg required for stress calculation.

[4]:
# Use Sparse format for force/stress computation
# (Same LJ setup as the three-format comparison, fixed to partition.Sparse.)
neighbor_fn_lj, energy_fn_lj = energy.lennard_jones_neighbor_list(
  displacement_fn,
  box,
  sigma=sigma,
  epsilon=epsilon,
  r_onset=r_onset / sigma,
  r_cutoff=r_cutoff / sigma,
  fractional_coordinates=True,
  neighbor_list_fn=neighbor_list_multi_image,
  pair_neighbor_list_fn=pair_neighbor_list_multi_image,
  format=partition.Sparse,
)

# Perturb positions slightly from equilibrium to get non-zero forces
# R is fractional, so the 0.01-scaled Gaussian noise is in fractions of the
# box edge, not in units of sigma.
key = random.PRNGKey(42)
R_perturbed = R + random.normal(key, R.shape) * 0.01

# The neighbor list is allocated from the perturbed positions it is used with.
nbrs_lj = neighbor_fn_lj.allocate(R_perturbed)
E_lj = float(energy_fn_lj(R_perturbed, nbrs_lj))

# Compute forces
force_fn = quantity.force(energy_fn_lj)
F = force_fn(R_perturbed, neighbor=nbrs_lj)
# Largest absolute force *component* over all atoms (not a vector norm).
max_force = float(jnp.max(jnp.abs(F)))
print(f'Perturbed energy: {E_lj:.6f}')
print(f'Max force magnitude: {max_force:.6f}')

# Compute stress (3x3 tensor)
stress = quantity.stress(energy_fn_lj, R_perturbed, box, neighbor=nbrs_lj)
print(f'Stress tensor (diagonal): [{stress[0,0]:.4f}, {stress[1,1]:.4f}, {stress[2,2]:.4f}]')
# NOTE(review): the pressure is printed as minus one third of the trace; this
# assumes quantity.stress uses the sign convention where that is the
# pressure -- confirm against the quantity.stress documentation.
print(f'Pressure: {-jnp.trace(stress) / 3:.6f}')
Perturbed energy: -232.244742
Max force magnitude: 51.575608
Stress tensor (diagonal): [2.8778, 3.3587, 3.1045]
Pressure: -3.113656

Example 2: Stillinger-Weber (Three-Body Potential)#

Stillinger-Weber is a three-body potential for silicon that requires Dense format for the angular terms. We use a 3x3x3 supercell of the 2-atom primitive cell (54 atoms).

Note: Stillinger-Weber internally uses space.map_neighbor for displacement computation, which applies the minimum image convention (MIC). For small boxes where r_cut > L/2, the multi-image neighbor list finds the correct neighbors, but the energy computation would still use MIC displacements. Therefore, we use a larger box where MIC is valid.

[5]:
# Stillinger-Weber parameters for silicon
# These two values are only used for the MIC check printed below; they are
# not passed to the energy function, which is called with its defaults.
sw_sigma = 2.0951  # Angstrom
sw_cutoff = 1.8 * sw_sigma  # ~3.77 Angstrom

# Create 3x3x3 supercell so that MIC is valid
# For SW, the box must be large enough that r_cut < L/2
n_cells_sw = 3
a_sw = 5.43  # Si lattice constant
R_sw, box_sw = make_diamond_cubic(n_cells_sw, a=a_sw)
N_sw = len(R_sw)

# For non-cubic boxes, compute minimum perpendicular height
# The height along each lattice direction is the reciprocal of the norm of
# the matching row of inv(box); the transpose puts those rows into columns
# so that axis=0 takes their norms.
inv_box_T = jnp.linalg.inv(box_sw).T
heights_sw = 1.0 / jnp.linalg.norm(inv_box_T, axis=0)
L_min_sw = float(jnp.min(heights_sw))

print(f'System: {N_sw} Si atoms in {n_cells_sw}x{n_cells_sw}x{n_cells_sw} diamond cubic supercell')
print(f'Min box height = {L_min_sw:.2f} Angstrom, SW cutoff = {sw_cutoff:.2f} Angstrom')
print(f'cutoff / L_min = {sw_cutoff / L_min_sw:.2f} (< 0.5: MIC is valid)')

displacement_sw, shift_sw = space.periodic_general(
  box_sw, fractional_coordinates=True
)

# Stillinger-Weber only supports Dense format (three-body terms)
# Note: SW uses MIC internally, so multi-image NL only helps with neighbor finding
neighbor_fn_sw, energy_fn_sw = energy.stillinger_weber_neighbor_list(
  displacement_sw,
  box_sw,
  neighbor_list_fn=neighbor_list_multi_image,
  format=partition.Dense,
  fractional_coordinates=True,
)

nbrs_sw = neighbor_fn_sw.allocate(R_sw)
E_sw = float(energy_fn_sw(R_sw, nbrs_sw))
# Count neighbor slots holding an index below N_sw -- presumably padding
# slots carry the index N_sw; confirm.
n_edges_sw = int(jnp.sum(nbrs_sw.idx < N_sw))

print(f'Stillinger-Weber energy: {E_sw:.6f} eV')
print(f'Number of edges: {n_edges_sw}')
print('Stillinger-Weber computes correctly (MIC valid for this box size).')
System: 54 Si atoms in 3x3x3 diamond cubic supercell
Min box height = 9.41 Angstrom, SW cutoff = 3.77 Angstrom
cutoff / L_min = 0.40 (< 0.5: MIC is valid)
Stillinger-Weber energy: -234.171986 eV
Number of edges: 864
Stillinger-Weber computes correctly (MIC valid for this box size).

Example 3: NVE Molecular Dynamics#

We run NVE (constant energy) molecular dynamics with the multi-image neighbor list. This demonstrates rebuild tracking and overflow handling following the pattern recommended in partition.neighbor_list documentation.

[6]:

# Simulation parameters
N_md = 500
dimension = 2
box_size = 40.0 if SMOKE_TEST else 60.0

# Create box matrix for 2D
box_md = jnp.eye(dimension) * box_size

# Random initial positions (fractional coordinates in [0, 1))
key = random.PRNGKey(0)
R_md = random.uniform(key, (N_md, dimension), minval=0.0, maxval=1.0)

# 50:50 mixture of two species
# sigma_md[i, j] is the diameter used between species i and species j; the
# first N_half atoms are species 0, the rest species 1.
sigma_md = jnp.array([[1.0, 1.2], [1.2, 1.4]])
N_half = N_md // 2
species = jnp.where(jnp.arange(N_md) < N_half, 0, 1)

# Cutoff
# NOTE: r_cutoff_md is only used in the printout below; it is not passed to
# the energy / neighbor-list constructor.
r_cutoff_md = 2.5

print(f'System: {N_md} atoms in {dimension}D box of size {box_size}')
print(f'Cutoff: {r_cutoff_md}, cutoff/L = {r_cutoff_md / box_size:.3f}')

# Setup displacement function for fractional coordinates
displacement_md, shift_md = space.periodic_general(
  box_md, fractional_coordinates=True
)

# For random positions, use generous capacity to avoid overflow
# Random positions can cluster, requiring more capacity than uniform estimates
# NOTE: no explicit capacity argument is passed here; overflow is caught by
# the check right after allocation.

# Use soft sphere potential with multi-image neighbor list
neighbor_fn_md, energy_fn_md = energy.soft_sphere_neighbor_list(
  displacement_md,
  box_md,
  species=species,
  sigma=sigma_md,
  fractional_coordinates=True,
  neighbor_list_fn=neighbor_list_multi_image,
  pair_neighbor_list_fn=pair_neighbor_list_multi_image,
  format=partition.Sparse,
)

# Initialize neighbor list
nbrs_md = neighbor_fn_md.allocate(R_md)
if nbrs_md.did_buffer_overflow:
  raise RuntimeError('Neighbor list overflowed - increase max_neighbors')

# Setup NVE integrator
dt = 1e-2
init_fn, apply_fn = simulate.nve(energy_fn_md, shift_md, dt)

# Initialize state with zero temperature
state = init_fn(key, R_md, neighbor=nbrs_md, kT=0.0)


# JIT-compiled step function with neighbor list update
@jit
def step_fn(i, state_and_nbrs):
  """Advance one integrator step and refresh the neighbor list.

  Args:
    i: Loop index; unused, present so the function fits a (index, carry)
      loop-body signature.
    state_and_nbrs: Tuple of (simulation state, neighbor list).

  Returns:
    The updated (state, neighbor list) tuple.
  """
  state, nbrs = state_and_nbrs
  state = apply_fn(state, neighbor=nbrs)
  nbrs = nbrs.update(state.position)
  return state, nbrs
System: 500 atoms in 2D box of size 40.0
Cutoff: 2.5, cutoff/L = 0.062
[7]:
# Run simulation following JAX-MD's recommended pattern for overflow handling.
# See partition.neighbor_list docstring for the canonical example.
N_steps = 200 if SMOKE_TEST else 1000  # number of outer iterations
print_every = 20  # outer iterations between progress lines
inner_steps = 10  # integrator steps per outer iteration

PE = []
KE = []
rebuild_count = 0
realloc_count = 0

# NOTE: the "dt" column reports wall-clock seconds per integrator step, not
# the integration time step.
print(f'{"Step":>4} {"KE":>5} {"PE":>6} {"Total":>6} {"dt":>6} {"rebuild":>7} {"realloc":>7}')
# perf_counter is a monotonic, high-resolution clock, so the measured
# intervals cannot be distorted by system clock adjustments.
old_time = time.perf_counter()

for i in range(N_steps):
  # Track reference position before inner loop
  old_ref_pos = nbrs_md.reference_position

  # Run inner_steps using fori_loop for efficiency
  new_state, new_nbrs = lax.fori_loop(0, inner_steps, step_fn, (state, nbrs_md))

  # Check for buffer overflow after the loop
  # If overflow: discard new state, reallocate with extra capacity
  # If no overflow: accept new state
  if new_nbrs.did_buffer_overflow:
    # Reallocate with extra capacity (10 more neighbors per atom)
    nbrs_md = neighbor_fn_md.allocate(state.position, extra_capacity=10)
    realloc_count += 1
    print(f'  [Overflow at step {i * inner_steps}! Reallocating with extra capacity...]')
    # Don't advance state - retry from last good state
  else:
    # Accept the new state
    state = new_state
    nbrs_md = new_nbrs

    # Check if rebuild happened (reference position changed)
    # This detects at most one rebuild per outer iteration.
    new_ref_pos = nbrs_md.reference_position
    if not jnp.allclose(old_ref_pos, new_ref_pos):
      rebuild_count += 1

  # Energies are recorded once per outer iteration, including iterations that
  # were rejected because of overflow, so PE and KE always hold N_steps
  # samples.
  pe = float(energy_fn_md(state.position, nbrs_md))
  ke = float(quantity.kinetic_energy(momentum=state.momentum))
  PE.append(pe)
  KE.append(ke)

  if i % print_every == 0 and i > 0:
    new_time = time.perf_counter()
    # Average wall time per integrator step since the previous report.
    step_time = (new_time - old_time) / print_every / inner_steps
    print(
      f'{i * inner_steps:4d} {ke:5.1f} {pe:6.1f} {ke + pe:6.2f} '
      f'{step_time:6.4f} {rebuild_count:7d} {realloc_count:7d}'
    )
    old_time = new_time

PE = jnp.array(PE)
KE = jnp.array(KE)

# Drift is the absolute difference in total energy between the last and the
# first recorded sample.
print(f'Total energy drift: {abs(float(PE[-1] + KE[-1] - PE[0] - KE[0])):.2e}')
print(f'Total rebuilds: {rebuild_count}, reallocs: {realloc_count}')
Step    KE     PE  Total     dt rebuild realloc
 200  26.3    5.5  31.75 0.0224      16       0
 400  27.0    4.8  31.75 0.0139      34       0
 600  26.8    4.9  31.74 0.0143      51       0
 800  26.8    5.0  31.74 0.0130      68       0
1000  25.5    6.3  31.74 0.0141      88       0
1200  27.4    4.4  31.74 0.0138     107       0
1400  25.7    6.1  31.75 0.0141     126       0
1600  26.0    5.8  31.75 0.0153     145       0
1800  26.4    5.4  31.75 0.0127     162       0
Total energy drift: 9.10e-03
Total rebuilds: 178, reallocs: 0

Plot Energy Evolution#

We verify energy conservation by plotting PE, KE, and total energy over time.

[8]:
# Time axis: one sample per outer iteration, spaced by dt * inner_steps.
t = onp.arange(N_steps) * dt * inner_steps

plt.figure(figsize=(10, 6))
# Draw potential, kinetic and total energy in that order; only the total is
# dashed.
for curve, curve_label, extra in (
  (PE, 'PE', {}),
  (KE, 'KE', {}),
  (PE + KE, 'Total Energy', {'linestyle': '--'}),
):
  plt.plot(t, curve, label=curve_label, linewidth=2, **extra)
plt.legend(fontsize=12)
format_plot('Time', 'Energy')
plt.title(
  'NVE Energy Conservation with Multi-Image Neighbor List', fontsize=14
)
finalize_plot()
plt.savefig('nve_multi_image.png', dpi=150, bbox_inches='tight')
plt.show()
../_images/examples_custom_nl_16_0.png

Visualize Final Configuration#

[9]:
# Base marker size; larger in the smoke-test configuration.
ms = 40 if SMOKE_TEST else 15
R_final = onp.array(state.position)

# Positions are fractional; the box is square with edge box_size, so a
# single scale factor converts them to Cartesian coordinates.
R_cart = R_final * box_size

plt.figure(figsize=(8, 8))
# One scatter layer per species: the first N_half atoms, then the rest, each
# with its own marker scale.
for members, marker_scale in (
  (slice(None, N_half), 0.5),
  (slice(N_half, None), 0.7),
):
  plt.plot(
    R_cart[members, 0],
    R_cart[members, 1],
    'o',
    markersize=ms * marker_scale,
    alpha=0.7,
  )
plt.xlim([0, box_size])
plt.ylim([0, box_size])
plt.axis('off')
plt.title('Final Configuration', fontsize=14)
finalize_plot((2, 2))
plt.savefig('nve_final_config.png', dpi=150, bbox_inches='tight')
plt.show()
../_images/examples_custom_nl_18_0.png