[ ]:
# /// script
# dependencies = [
#   "ase",
#   "tqdm",
# ]
# ///

Download Notebook Download Python Script

ReaxFF NVE Simulation#

Imports#

[ ]:
import jax

jax.config.update('jax_enable_x64', True)
import numpy as onp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax_md import space, simulate, quantity
import matplotlib.pyplot as plt
from ase.io import read
from jax_md.mm_forcefields.reaxff.reaxff_interactions import reaxff_inter_list
from jax_md.mm_forcefields.reaxff.reaxff_helper import read_force_field
from jax_md.mm_forcefields.reaxff.reaxff_forcefield import ForceField
from jax_md.quantity import stress
from tqdm import tqdm

Control Variables#

[ ]:
TYPE = jnp.float64
ff_path = 'data/ffield'
cutoff2 = 0.001
geo_path = 'data/geo.pdb'
solver_iter_count = 200
solver_tol = 1e-8
MD_count = 1000
reax_time = 0.2  # timestep in fs

Read and Preprocess the Force Field#

[ ]:
force_field = read_force_field(ff_path, cutoff2=cutoff2, dtype=TYPE)
force_field = ForceField.fill_off_diag(force_field)
force_field = ForceField.fill_symm(force_field)

Read the Geometry File#

[ ]:
# read the geometry file
struct = read(geo_path)
struct.set_positions(struct.get_positions())
pos = struct.get_positions()
pos = pos - onp.min(pos, axis=0)
struct.set_positions(pos)

R = jnp.array(struct.positions, dtype=TYPE)

types = struct.get_chemical_symbols()
types_int = [force_field.name_to_index[t] for t in types]
species = jnp.array(types_int)
atomic_nums = jnp.array(struct.get_atomic_numbers())
box_size = jnp.array(struct.cell.lengths(), dtype=TYPE)
N = len(R)

Setup ReaxFF#

[ ]:
displacement, shift = space.periodic(box_size)
reaxff_inter_fn, energy_fn = reaxff_inter_list(
  displacement,
  box_size,
  species,
  atomic_nums,
  force_field,
  tol=solver_tol,
  solver_model='EEM',
  backprop_solve=False,
  tors_2013=False,
  max_solver_iter=solver_iter_count,
  short_inters_capacity_multiplier=1.5,
)

nbrs = reaxff_inter_fn.allocate(R)
mass = force_field.amas[species]
mass = jnp.array(mass, dtype=TYPE).reshape(-1, 1)
# Required conversion for timestep to match reaxff timestep
base_time = 48.8882129
multip = reax_time / base_time
# Initalize the NVE simulation
init_fn, apply_fn = simulate.nve(energy_fn, shift, multip)  # 4.184e-3
state = init_fn(jax.random.PRNGKey(0), R, kT=1e-8, mass=mass, nbr_lists=nbrs)


# MD Step Function
@jax.jit
def body_fn(i, state):
  state, nbrs = state
  nbrs = reaxff_inter_fn.update(state.position, nbrs)
  state = apply_fn(state, nbr_lists=nbrs)
  return state, nbrs

Run the MD#

[ ]:
# run the MD
PE = []
KE = []
step = 0
data_collect_freq = 10
pbar = tqdm(total=MD_count)
while step < MD_count:
  new_state, nbrs = body_fn(step, (state, nbrs))
  if nbrs.did_buffer_overflow == True:
    print('overflow step: ', step)
    print('Neighbor list overflowed, reallocating.')
    nbrs = reaxff_inter_fn.allocate(state.position)
  else:
    if step % data_collect_freq == 0:
      p_energy = jax.jit(energy_fn)(state.position, nbrs)
      k_energy = jax.jit(quantity.kinetic_energy)(
        velocity=state.velocity, mass=state.mass
      )
      PE += [p_energy]
      KE += [k_energy]
    state = new_state
    pbar.update(1)
    step += 1


PE = onp.array(PE)
KE = onp.array(KE)
total = PE + KE

Plot the Energy#

[ ]:
target_energy = onp.loadtxt('data/reaxff_target.txt')
target_PE = target_energy[:, 0][::data_collect_freq]
target_KE = target_energy[:, 1][::data_collect_freq]
plt.plot(target_PE, label='target PE (PuReMD)')
plt.plot(PE, label='JAX-MD PE')
plt.ylabel('Potential Energy (kcal/mol)')
plt.xlabel(f'MD step / {data_collect_freq}')
plt.legend()

Calculate the Stress#

[ ]:
S_ij = stress(energy_fn, state.position, box_size, nbr_lists=nbrs)
print(f'Stress Tensor: {S_ij}')