Simulation Routines

Code to simulate systems in various statistical ensembles.

This file contains a number of different methods that can be used to simulate systems in a variety of ensembles.

In general, simulation code follows the same overall structure as optimizers in JAX. Simulations are tuples of two functions:

init_fn:

Function that initializes the state of a system. Should take positions as an ndarray of shape [n, spatial_dimension]. Returns a state, which will be a namedtuple.

apply_fn:

Function that takes a state and produces a new state after one step of simulation.
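The (init_fn, apply_fn) pattern can be illustrated with a toy simulator. This is a minimal NumPy sketch, not part of JAX MD: `free_particle` and `FreeState` are hypothetical names for a system of non-interacting particles that move at constant velocity, and `free_shift` stands in for a shift_fn in free (non-periodic) space.

```python
from collections import namedtuple

import numpy as np

# Hypothetical state type; real JAX MD states carry more fields.
FreeState = namedtuple("FreeState", ["position", "velocity"])


def free_particle(shift_fn, dt):
    """Returns an (init_fn, apply_fn) pair for non-interacting particles."""

    def init_fn(position, velocity):
        # position and velocity are arrays of shape [n, spatial_dimension].
        return FreeState(np.asarray(position, dtype=float),
                         np.asarray(velocity, dtype=float))

    def apply_fn(state):
        # One step: displace every particle by velocity * dt via shift_fn.
        new_position = shift_fn(state.position, state.velocity * dt)
        return FreeState(new_position, state.velocity)

    return init_fn, apply_fn


def free_shift(R, dR):
    # Free space: shifting positions is plain addition.
    return R + dR


init_fn, apply_fn = free_particle(free_shift, dt=0.1)
state = init_fn(np.zeros((3, 2)), np.ones((3, 2)))
for _ in range(10):
    state = apply_fn(state)
print(state.position)
```

As with JAX optimizers, the state is plain data, so the step function can be called in a Python loop (or, in JAX, compiled and used inside a loop primitive).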

One open question is whether simulations should also return a function that computes the invariant (conserved quantity) of their ensemble. Such a function is useful for testing, but is rarely needed otherwise.

Deterministic Simulation Environments

jax_md.simulate.nve(energy_or_force_fn, shift_fn, dt=0.001, **sim_kwargs)

Simulates a system in the NVE ensemble.

Samples from the microcanonical ensemble in which the number of particles (N), the system volume (V), and the energy (E) are held constant. We use a standard velocity Verlet integration scheme.

Parameters
  • energy_or_force_fn – A function that produces either an energy or a force from a set of particle positions specified as an ndarray of shape [n, spatial_dimension].

  • shift_fn – A function that displaces positions, R, by an amount dR. Both R and dR should be ndarrays of shape [n, spatial_dimension].

  • dt – Floating point number specifying the timescale (step size) of the simulation.

Returns

A pair of functions (init_fn, apply_fn), as described in the introduction.
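The velocity Verlet scheme behind NVE sampling can be sketched in NumPy for a single harmonic oscillator. This illustrates the integrator only, not the JAX MD source: the force `-k x`, the helper names, and the use of plain addition in place of a shift_fn are assumptions made for the example. The check at the end shows the hallmark of NVE integration, namely that the total energy stays nearly constant.

```python
import numpy as np


def harmonic_force(x, k=1.0):
    # Force from the potential U(x) = k * x**2 / 2.
    return -k * x


def velocity_verlet_step(x, p, f, mass, dt, force_fn):
    p = p + 0.5 * dt * f      # half kick with the current force
    x = x + dt * p / mass     # full drift (plain addition in free space)
    f = force_fn(x)           # force at the new positions
    p = p + 0.5 * dt * f      # second half kick
    return x, p, f


def total_energy(x, p, mass, k=1.0):
    return 0.5 * np.sum(p**2) / mass + 0.5 * k * np.sum(x**2)


mass, dt = 1.0, 0.01
x, p = np.array([1.0]), np.array([0.0])
f = harmonic_force(x)
e0 = total_energy(x, p, mass)
for _ in range(10_000):
    x, p, f = velocity_verlet_step(x, p, f, mass, dt, harmonic_force)
drift = abs(total_energy(x, p, mass) - e0)
print(drift)  # small: the energy error stays bounded, of order dt**2
```

Passing the force from one step to the next means each step costs a single force evaluation, which is why NVEState stores the force from the previous step.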

jax_md.simulate.nvt_nose_hoover(energy_or_force_fn, shift_fn, dt, kT, chain_length=5, chain_steps=2, sy_steps=3, tau=None, **sim_kwargs)

Simulation in the NVT ensemble using a Nose-Hoover chain thermostat.

Samples from the canonical ensemble in which the number of particles (N), the system volume (V), and the temperature (T) are held constant. We use a Nose-Hoover chain (NHC) thermostat described in [1] [2] [3]. We follow the direct translation method outlined in Tuckerman et al. [3], and the interested reader may want to consult that paper as a reference.

Parameters
  • energy_or_force_fn – A function that produces either an energy or a force from a set of particle positions specified as an ndarray of shape [n, spatial_dimension].

  • shift_fn (Callable[[Array, Array], Array]) – A function that displaces positions, R, by an amount dR. Both R and dR should be ndarrays of shape [n, spatial_dimension].

  • dt (float) – Floating point number specifying the timescale (step size) of the simulation.

  • kT (float) – Floating point number specifying the temperature in units of Boltzmann constant. To update the temperature dynamically during a simulation one should pass kT as a keyword argument to the step function.

  • chain_length (int) – An integer specifying the number of particles in the Nose-Hoover chain.

  • chain_steps (int) – An integer specifying the number, \(n_c\), of outer substeps.

  • sy_steps (int) – An integer specifying the number of Suzuki-Yoshida steps. This must be one of 1, 3, 5, or 7.

  • tau (Optional[float]) – A floating point timescale over which temperature equilibration occurs. Measured in units of dt. The performance of the Nose-Hoover chain thermostat can be quite sensitive to this choice.

Return type

Tuple[Callable[…, ~T], Callable[[~T], ~T]]

Returns

A pair of functions (init_fn, apply_fn), as described in the introduction.
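A Nose-Hoover thermostat drives the time average of the instantaneous kinetic temperature \(2K/N_f\) toward kT, where \(K\) is the kinetic energy and \(N_f\) the number of degrees of freedom. The NumPy sketch below measures this quantity; `kinetic_temperature` is a hypothetical helper, and taking \(N_f = n \times\) spatial_dimension is an assumption (some codes subtract the spatial dimension to account for a fixed center of mass).

```python
import numpy as np


def kinetic_temperature(momentum, mass):
    """Instantaneous temperature estimate 2K / N_f for momenta of shape [n, d].

    mass may be a float or an array of shape [n]. N_f is taken to be n * d.
    """
    momentum = np.asarray(momentum, dtype=float)
    mass = np.asarray(mass, dtype=float)
    if mass.ndim == 1:
        # Broadcast per-particle masses over the spatial dimension.
        mass = mass[:, None]
    kinetic = 0.5 * np.sum(momentum**2 / mass)
    degrees_of_freedom = momentum.size
    return 2.0 * kinetic / degrees_of_freedom


rng = np.random.default_rng(0)
n, d, kT = 10_000, 3, 2.0
mass = rng.uniform(1.0, 3.0, size=n)
# Maxwell-Boltzmann momenta: each component has variance mass * kT.
p = rng.standard_normal((n, d)) * np.sqrt(kT * mass)[:, None]
print(kinetic_temperature(p, mass))  # close to kT = 2.0
```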

References

[1]

Martyna, Glenn J., Michael L. Klein, and Mark Tuckerman. “Nose-Hoover chains: The canonical ensemble via continuous dynamics.” The Journal of Chemical Physics 97, no. 4 (1992): 2635-2643.

[2]

Martyna, Glenn, Mark Tuckerman, Douglas J. Tobias, and Michael L. Klein. “Explicit reversible integrators for extended systems dynamics.” Molecular Physics 87 (1996): 1117-1157.

[3]

Tuckerman, Mark E., Jose Alejandre, Roberto Lopez-Rendon, Andrea L. Jochim, and Glenn J. Martyna. “A Liouville-operator derived measure-preserving integrator for molecular dynamics simulations in the isothermal-isobaric ensemble.” Journal of Physics A: Mathematical and General 39, no. 19 (2006): 5629.

jax_md.simulate.npt_nose_hoover(energy_fn, shift_fn, dt, pressure, kT, barostat_kwargs=None, thermostat_kwargs=None)

Simulation in the NPT ensemble using a pair of Nose-Hoover chains.

Samples from the isothermal-isobaric ensemble in which the number of particles (N), the system pressure (P), and the temperature (T) are held constant. We use a pair of Nose-Hoover chains (NHC) [1] [2] [3], coupled to the barostat and the thermostat respectively. We follow the direct translation method outlined in Tuckerman et al. [3], and the interested reader may want to consult that paper as a reference.

Parameters
  • energy_fn (Callable[…, Array]) – A function that produces an energy from a set of particle positions specified as an ndarray of shape [n, spatial_dimension].

  • shift_fn (Callable[[Array, Array], Array]) – A function that displaces positions, R, by an amount dR. Both R and dR should be ndarrays of shape [n, spatial_dimension].

  • dt (float) – Floating point number specifying the timescale (step size) of the simulation.

  • pressure (float) – Floating point number specifying the target pressure. To update the pressure dynamically during a simulation one should pass pressure as a keyword argument to the step function.

  • kT (float) – Floating point number specifying the temperature in units of Boltzmann constant. To update the temperature dynamically during a simulation one should pass kT as a keyword argument to the step function.

  • barostat_kwargs (Optional[Dict]) – A dictionary of keyword arguments passed to the barostat NHC. Any parameters not set are drawn from a relatively robust default set.

  • thermostat_kwargs (Optional[Dict]) – A dictionary of keyword arguments passed to the thermostat NHC. Any parameters not set are drawn from a relatively robust default set.

Return type

Tuple[Callable[…, ~T], Callable[[~T], ~T]]

Returns

A pair of functions (init_fn, apply_fn), as described in the introduction.

Stochastic Simulation Environments

jax_md.simulate.nvt_langevin(energy_or_force_fn, shift_fn, dt, kT, gamma=0.1, center_velocity=True, **sim_kwargs)

Simulation in the NVT ensemble using the BAOAB Langevin thermostat.

Samples from the canonical ensemble in which the number of particles (N), the system volume (V), and the temperature (T) are held constant. Langevin dynamics are stochastic: the system is assumed to interact with fictitious microscopic degrees of freedom that are not simulated explicitly. An example would be large particles in a solvent such as water. Langevin dynamics are therefore a stochastic differential equation characterized by a friction coefficient and noise of a given covariance.

Our implementation follows the paper by Davidchack, Ouldridge, and Tretyakov [4].

Parameters
  • energy_or_force_fn – A function that produces either an energy or a force from a set of particle positions specified as an ndarray of shape [n, spatial_dimension].

  • shift_fn (Callable[[Array, Array], Array]) – A function that displaces positions, R, by an amount dR. Both R and dR should be ndarrays of shape [n, spatial_dimension].

  • dt (float) – Floating point number specifying the timescale (step size) of the simulation.

  • kT (float) – Floating point number specifying the temperature in units of Boltzmann constant. To update the temperature dynamically during a simulation one should pass kT as a keyword argument to the step function.

  • gamma (float) – A float specifying the friction coefficient between the particles and the solvent.

  • center_velocity (bool) – A boolean specifying whether or not the center of mass velocity should be subtracted.

Return type

Tuple[Callable[…, ~T], Callable[[~T], ~T]]

Returns

A pair of functions (init_fn, apply_fn), as described in the introduction.
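The name BAOAB describes the order of the sub-steps: half kick (B), half drift (A), an exact Ornstein-Uhlenbeck update of the momenta (O), half drift (A), half kick (B). The NumPy sketch below is a generic version of that splitting, not the JAX MD source; the O-step coefficients \(c_1 = e^{-\gamma \Delta t}\) and \(\sqrt{kT\, m (1 - c_1^2)}\) are the textbook choice, and the harmonic force and helper names are assumptions. Running many independent oscillators lets us check equipartition, \(\langle x^2 \rangle = kT\) for \(U = x^2/2\).

```python
import numpy as np


def baoab_step(x, p, f, mass, dt, kT, gamma, force_fn, rng):
    # B: half kick with the deterministic force.
    p = p + 0.5 * dt * f
    # A: half drift.
    x = x + 0.5 * dt * p / mass
    # O: exact Ornstein-Uhlenbeck update (friction plus thermal noise).
    c1 = np.exp(-gamma * dt)
    c2 = np.sqrt(kT * mass * (1.0 - c1**2))
    p = c1 * p + c2 * rng.standard_normal(p.shape)
    # A: second half drift.
    x = x + 0.5 * dt * p / mass
    # B: recompute the force, then the second half kick.
    f = force_fn(x)
    p = p + 0.5 * dt * f
    return x, p, f


def harmonic_force(x, k=1.0):
    return -k * x


rng = np.random.default_rng(0)
n, mass, dt, kT, gamma = 20_000, 1.0, 0.05, 1.0, 1.0
x, p = np.zeros(n), np.zeros(n)
f = harmonic_force(x)
for _ in range(1_000):
    x, p, f = baoab_step(x, p, f, mass, dt, kT, gamma, harmonic_force, rng)
print(np.mean(x**2))  # close to kT = 1.0
```

Note that the random number generator is carried along explicitly, which mirrors the rng field stored in NVTLangevinState.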

References

[4]

R. L. Davidchack, T. E. Ouldridge, and M. V. Tretyakov. “New Langevin and gradient thermostats for rigid body dynamics.” The Journal of Chemical Physics 142, 144114 (2015)

jax_md.simulate.brownian(energy_or_force, shift, dt, kT, gamma=0.1)

Simulation of Brownian dynamics.

Simulates Brownian dynamics, which correspond to the overdamped regime of Langevin dynamics. In this regime velocities do not need to be tracked and the dynamics simplify. Consequently, when Brownian dynamics can be used, they will be faster than Langevin dynamics. As in the case of Langevin dynamics, our implementation follows Carlon et al. [4].

Parameters
  • energy_or_force (Callable[…, Array]) – A function that produces either an energy or a force from a set of particle positions specified as an ndarray of shape [n, spatial_dimension].

  • shift – A function that displaces positions, R, by an amount dR. Both R and dR should be ndarrays of shape [n, spatial_dimension].

  • dt (float) – Floating point number specifying the timescale (step size) of the simulation.

  • kT (float) – Floating point number specifying the temperature in units of Boltzmann constant. To update the temperature dynamically during a simulation one should pass kT as a keyword argument to the step function.

  • gamma (float) – A float specifying the friction coefficient between the particles and the solvent.

Return type

Tuple[Callable[…, ~T], Callable[[~T], ~T]]

Returns

A pair of functions (init_fn, apply_fn), as described in the introduction.
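An overdamped (Brownian) step can be sketched in NumPy as a deterministic drift along the force plus Gaussian noise. This is an illustration rather than the JAX MD source: using a mobility \(\nu = 1/(m\gamma)\) is an assumption suggested by BrownianState storing a mass, and `brownian_step` is a hypothetical name. For free diffusion the mean squared displacement should follow the Einstein relation \(\langle |x|^2 \rangle = 2 d D t\) with \(D = kT\,\nu\).

```python
import numpy as np


def brownian_step(x, force_fn, mass, dt, kT, gamma, rng):
    # Mobility nu = 1 / (mass * gamma); this choice is an assumption.
    nu = 1.0 / (mass * gamma)
    noise = rng.standard_normal(x.shape)
    # Drift along the force plus noise with variance 2 * kT * nu * dt.
    return x + dt * nu * force_fn(x) + np.sqrt(2.0 * kT * dt * nu) * noise


def zero_force(x):
    # Free diffusion: no interactions.
    return np.zeros_like(x)


rng = np.random.default_rng(0)
n, d, dt, kT, gamma, mass = 5_000, 2, 0.01, 1.0, 1.0, 1.0
steps = 100
x = np.zeros((n, d))
for _ in range(steps):
    x = brownian_step(x, zero_force, mass, dt, kT, gamma, rng)

msd = np.mean(np.sum(x**2, axis=1))
expected = 2 * d * (kT / (mass * gamma)) * steps * dt
print(msd, expected)  # msd is close to expected = 4.0
```

No momenta appear anywhere, which is why BrownianState only needs positions, masses, and the random number generator.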

jax_md.simulate.hybrid_swap_mc(space_fns, energy_fn, neighbor_fn, dt, kT, t_md, N_swap, sigma_fn=None)

Simulation of Hybrid Swap Monte-Carlo.

This code simulates the hybrid Swap Monte Carlo algorithm introduced in Berthier et al. [5]. Here, an NVT simulation is performed for a time t_md, and then N_swap MC moves are performed that swap the radii of randomly chosen particles. Each swap is accepted or rejected with a Metropolis-Hastings step. Each call to the step function runs molecular dynamics for t_md and then performs the swaps.

Note that this code doesn’t feature some of the convenience functions available in the other simulations. In particular, there is no support for dynamic keyword arguments, and the energy function must be a simple callable of two variables: the distance between adjacent particles and the diameter of the particles. If you want support for a better notion of potential or dynamic keyword arguments, please file an issue!

Parameters
  • space_fns (Tuple[Callable[[Array, Array], Array], Callable[[Array, Array], Array]]) – A tuple of a displacement function and a shift function defined in space.py.

  • energy_fn (Callable[[Array, Array], Array]) – A function that computes the energy between one pair of particles as a function of the distance between the particles and the diameter. This function should not have been passed to smap.xxx.

  • neighbor_fn (Callable[[Array, Optional[NeighborList], Optional[int]], NeighborList]) – A function to construct neighbor lists outlined in partition.py.

  • dt (float) – The timestep used for the continuous time MD portion of the simulation.

  • kT (float) – The temperature of the heat bath that the system is coupled to during MD.

  • t_md (float) – The time of each MD block.

  • N_swap (int) – The number of swapping moves between MD blocks.

  • sigma_fn (Optional[Callable[[Array], Array]]) – An optional function for combining radii if they are to be non-additive.

Return type

Tuple[Callable[…, ~T], Callable[[~T], ~T]]

Returns

A pair of functions (init_fn, apply_fn), as described in the introduction.
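The swap stage can be sketched on its own in NumPy, leaving out the NVT block. This is not JAX MD's API: the soft-sphere `pair_energy`, the additive rule \(\sigma_{ij} = (\sigma_i + \sigma_j)/2\), the all-pairs energy in free space (instead of neighbor lists), and the function names are illustrative assumptions. Because choosing a pair uniformly at random is a symmetric proposal, the Metropolis-Hastings acceptance probability reduces to \(\min(1, e^{-\Delta E / kT})\).

```python
import numpy as np


def pair_energy(r, sigma):
    # Hypothetical soft-sphere potential of distance and diameter.
    return np.where(r < sigma, 0.5 * (1.0 - r / sigma) ** 2, 0.0)


def total_energy(position, diameter):
    # All-pairs energy in free space with additive diameters.
    dr = position[:, None, :] - position[None, :, :]
    r = np.sqrt(np.sum(dr**2, axis=-1))
    sigma = 0.5 * (diameter[:, None] + diameter[None, :])
    i, j = np.triu_indices(len(position), k=1)
    return np.sum(pair_energy(r[i, j], sigma[i, j]))


def swap_moves(position, diameter, kT, n_swap, rng):
    """Attempts n_swap diameter swaps, each accepted with Metropolis-Hastings."""
    diameter = diameter.copy()
    energy = total_energy(position, diameter)
    accepted = 0
    for _ in range(n_swap):
        i, j = rng.choice(len(diameter), size=2, replace=False)
        trial = diameter.copy()
        trial[i], trial[j] = diameter[j], diameter[i]
        trial_energy = total_energy(position, trial)
        delta = trial_energy - energy
        # Symmetric proposal: accept with probability min(1, exp(-delta / kT)).
        if delta <= 0.0 or rng.random() < np.exp(-delta / kT):
            diameter, energy = trial, trial_energy
            accepted += 1
    return diameter, accepted


rng = np.random.default_rng(0)
position = rng.uniform(0.0, 3.0, size=(20, 2))
diameter = np.where(np.arange(20) % 2 == 0, 1.0, 1.4)
new_diameter, accepted = swap_moves(position, diameter, 0.5, 50, rng)
```

Swaps only permute the diameters, so the composition of the system is unchanged; only which particle carries which size evolves.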

References

[5]

L. Berthier, E. Flenner, C. J. Fullerton, C. Scalliet, and M. Singh. “Efficient swap algorithms for molecular dynamics simulations of equilibrium supercooled liquids”, J. Stat. Mech. (2019) 064004

Helper Functions

jax_md.simulate.velocity_verlet(force_fn, shift_fn, dt, state, **kwargs)

Apply a single step of velocity Verlet integration to a state.

Return type

~T
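For reference, the textbook velocity Verlet update can be written in terms of the state fields (position \(R\), momentum \(p\), force \(F\), mass \(m\)) and shift_fn. This is the standard scheme; the exact order of operations inside the helper is not spelled out on this page.

```latex
\begin{aligned}
p\!\left(t + \tfrac{\Delta t}{2}\right) &= p(t) + \tfrac{\Delta t}{2}\, F(t) \\
R(t + \Delta t) &= \mathrm{shift}\!\left(R(t),\ \Delta t\, p\!\left(t + \tfrac{\Delta t}{2}\right) / m\right) \\
F(t + \Delta t) &= F\big(R(t + \Delta t)\big) \\
p(t + \Delta t) &= p\!\left(t + \tfrac{\Delta t}{2}\right) + \tfrac{\Delta t}{2}\, F(t + \Delta t)
\end{aligned}
```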

jax_md.simulate.nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau)

Helper function to simulate a Nose-Hoover Chain coupled to a system.

This function is used in simulations that sample from thermal ensembles by coupling the system to one, or more, Nose-Hoover chains. We use the direct translation method outlined in Martyna et al. [1] and the Nose-Hoover chains are updated using two half steps: one at the beginning of a simulation step and one at the end. The masses of the Nose-Hoover chains are updated automatically to enforce a specific period of oscillation, tau. Larger values of tau will yield systems that reach the target temperature more slowly but are also more stable.

As described in Martyna et al. [1], the Nose-Hoover chain often evolves on a faster timescale than the rest of the simulation. Therefore, it is sometimes necessary to integrate the chain over several substeps for each step of MD. To do this we follow the Suzuki-Yoshida scheme. Specifically, we subdivide our chain simulation into \(n_c\) substeps. These substeps are further subdivided into \(n_{sy}\) steps. Each of these steps has length \(\delta_i = \Delta t\, w_i / n_c\), where the \(w_i\) are constants such that \(\sum_i w_i = 1\). The specific values of the Suzuki-Yoshida weights are tabulated in the module source. The number of substeps and the number of Suzuki-Yoshida steps are set using the chain_steps and sy_steps arguments.
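The substep schedule \(\delta_i = \Delta t\, w_i / n_c\) can be made concrete with a short NumPy sketch. The helper `suzuki_yoshida_substeps` is hypothetical, and the weights used are the classic fourth-order Yoshida "triple jump" for \(n_{sy} = 3\), \(w_1 = w_3 = 1/(2 - 2^{1/3})\) and \(w_2 = 1 - 2 w_1\); they are chosen for illustration and may differ from the values this module uses.

```python
import numpy as np


def suzuki_yoshida_substeps(dt, chain_steps, weights):
    """Lengths delta_i = dt * w_i / n_c of every chain substep in one MD step."""
    weights = np.asarray(weights, dtype=float)
    # Each of the n_c outer substeps runs through all n_sy weighted steps.
    return np.tile(dt * weights / chain_steps, chain_steps)


# Classic Yoshida "triple jump" weights for n_sy = 3 (illustrative choice).
w1 = 1.0 / (2.0 - 2.0 ** (1.0 / 3.0))
triple_jump = [w1, 1.0 - 2.0 * w1, w1]

deltas = suzuki_yoshida_substeps(dt=0.005, chain_steps=2, weights=triple_jump)
print(deltas)
```

Because the weights sum to one, the substeps add up to exactly one MD step. The middle weight is negative, so one substep of each triple runs backward in time; this is what cancels the leading error terms.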

Consequently, the Nose-Hoover chains are described by three functions: an init_fn that initializes the state of the chain, a half_step_fn that updates the chain for one half-step, and an update_chain_mass_fn that updates the masses of the chain to enforce the correct period of oscillation.

Note that a system can have many Nose-Hoover chains coupled to it to produce, for example, a temperature gradient. We also note that the NPT ensemble naturally features two chains: one that couples to the thermal degrees of freedom and one that couples to the barostat.

Parameters
  • dt – Floating point number specifying the timescale (step size) of the simulation.

  • chain_length – An integer specifying the number of particles in the Nose-Hoover chain.

  • chain_steps – An integer specifying the number, \(n_c\), of outer substeps.

  • sy_steps – An integer specifying the number of Suzuki-Yoshida steps. This must be one of 1, 3, 5, or 7.

  • tau – A floating point timescale over which temperature equilibration occurs. Measured in units of dt. The performance of the Nose-Hoover chain thermostat can be quite sensitive to this choice.

Return type

NoseHooverChainFns

Returns

A triple of functions that initialize the chain, do a half step of simulation, and update the chain masses respectively.

jax_md.simulate.npt_box(state)

Get the current box from an NPT simulation.

Return type

Array
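NPTNoseHooverState parameterizes the box through box_position \(= \frac{1}{d}\log(V/V_0)\). Under the assumption that the box scales isotropically from reference_box, every length scales by \((V/V_0)^{1/d} = e^{\text{box\_position}}\). The NumPy sketch below uses that relation; `box_from_position` and `volume` are hypothetical helpers, and this page does not state that npt_box is computed exactly this way.

```python
import numpy as np


def box_from_position(reference_box, box_position):
    """Current box from a reference box, assuming isotropic scaling.

    With box_position = (1/d) log(V / V_0), each length scales by
    (V / V_0) ** (1 / d) = exp(box_position).
    """
    return np.asarray(reference_box, dtype=float) * np.exp(box_position)


def volume(box):
    """Volume of a box given as side lengths [d] or box vectors [d, d]."""
    box = np.asarray(box, dtype=float)
    if box.ndim == 2:
        return abs(np.linalg.det(box))
    return float(np.prod(box))


reference_box = np.array([2.0, 3.0, 4.0])  # V_0 = 24
d = reference_box.shape[0]
box_position = np.log(2.0) / d  # a box with twice the reference volume
box = box_from_position(reference_box, box_position)
print(volume(box))  # 48.0 up to rounding
```

Working with a logarithmic volume coordinate keeps the volume positive for any value of box_position.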

Testing Functions

jax_md.simulate.nvt_nose_hoover_invariant(energy_fn, state, kT, **kwargs)

The conserved quantity for the NVT ensemble with a Nose-Hoover thermostat.

This function is normally used for debugging the Nose-Hoover thermostat.

Parameters
  • energy_fn (Callable[…, Array]) – The energy function of the Nose-Hoover system.

  • state (NVTNoseHooverState) – The current state of the system.

  • kT (float) – The current goal temperature of the system.

Return type

float

Returns

The Hamiltonian of the extended NVT dynamics.
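The extended Hamiltonian conserved by Nose-Hoover chain dynamics [1] is \(H' = K + U + \sum_j p_{\xi_j}^2 / (2 Q_j) + N_f\, kT\, \xi_1 + kT \sum_{j \ge 2} \xi_j\). The NumPy sketch below evaluates it from plain arrays that mirror the NoseHooverChain fields (position, momentum, mass, degrees_of_freedom). It is an illustration of the formula; `nhc_extended_energy` is a hypothetical name and does not share the signature of nvt_nose_hoover_invariant.

```python
import numpy as np


def nhc_extended_energy(kinetic, potential, chain_position, chain_momentum,
                        chain_mass, degrees_of_freedom, kT):
    """Conserved quantity of Nose-Hoover chain dynamics."""
    xi = np.asarray(chain_position, dtype=float)
    p_xi = np.asarray(chain_momentum, dtype=float)
    Q = np.asarray(chain_mass, dtype=float)
    # Kinetic energy of the thermostat variables.
    chain_kinetic = np.sum(p_xi**2 / (2.0 * Q))
    # The first thermostat couples to all N_f degrees of freedom;
    # each later thermostat couples to a single variable.
    chain_potential = degrees_of_freedom * kT * xi[0] + kT * np.sum(xi[1:])
    return kinetic + potential + chain_kinetic + chain_potential


H = nhc_extended_energy(kinetic=1.5, potential=-2.0,
                        chain_position=[0.1, 0.2, 0.3],
                        chain_momentum=[1.0, 2.0, 0.0],
                        chain_mass=[2.0, 4.0, 1.0],
                        degrees_of_freedom=3, kT=0.5)
print(H)  # 1.5 - 2.0 + 0.75 + 0.4 = 0.65 up to rounding
```

In a correct simulation this quantity should fluctuate but not drift, which is what makes it useful for debugging the thermostat.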

jax_md.simulate.npt_nose_hoover_invariant(energy_fn, state, pressure, kT, **kwargs)

The conserved quantity for the NPT ensemble with a Nose-Hoover thermostat.

This function is normally used for debugging the NPT simulation.

Parameters
  • energy_fn (Callable[…, Array]) – The energy function of the system.

  • state (NPTNoseHooverState) – The current state of the system.

  • pressure (float) – The current goal pressure of the system.

  • kT (float) – The current goal temperature of the system.

Return type

float

Returns

The Hamiltonian of the extended NPT dynamics.

Data Types

class jax_md.simulate.NoseHooverChain(position, momentum, mass, tau, kinetic_energy, degrees_of_freedom)

State information for a Nose-Hoover chain.

position

An ndarray of shape [chain_length] that stores the position of the chain.

Type

jax.Array

momentum

An ndarray of shape [chain_length] that stores the momentum of the chain.

Type

jax.Array

mass

An ndarray of shape [chain_length] that stores the mass of the chain.

Type

jax.Array

tau

The desired period of oscillation for the chain. Longer periods result in better stability but worse temperature control.

Type

jax.Array

kinetic_energy

A float that stores the current kinetic energy of the system that the chain is coupled to.

Type

jax.Array

degrees_of_freedom

An integer specifying the number of degrees of freedom that the chain is coupled to.

Type

int

class jax_md.simulate.NVEState(position, momentum, force, mass)

A struct containing the state of an NVE simulation.

This tuple stores the state of a simulation that samples from the microcanonical ensemble in which the (N)umber of particles, the (V)olume, and the (E)nergy of the system are held fixed.

position

An ndarray of shape [n, spatial_dimension] storing the position of particles.

Type

jax.Array

momentum

An ndarray of shape [n, spatial_dimension] storing the momentum of particles.

Type

jax.Array

force

An ndarray of shape [n, spatial_dimension] storing the force acting on particles from the previous step.

Type

jax.Array

mass

A float or an ndarray of shape [n] containing the masses of the particles.

Type

jax.Array

class jax_md.simulate.NVTNoseHooverState(position, momentum, force, mass, chain)

State information for an NVT system with a Nose-Hoover chain thermostat.

position

The current position of particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

momentum

The momentum of particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

force

The current force on the particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

mass

The mass of the particles. Can either be a float or an ndarray of floats with shape [n].

Type

jax.Array

chain

The variables describing the Nose-Hoover chain.

Type

jax_md.simulate.NoseHooverChain

class jax_md.simulate.NPTNoseHooverState(position, momentum, force, mass, reference_box, box_position, box_momentum, box_mass, barostat, thermostat)

State information for an NPT system with Nose-Hoover chain thermostats.

position

The current position of particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

momentum

The momentum of particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

force

The current force on the particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

mass

The mass of the particles. Can either be a float or an ndarray of floats with shape [n].

Type

jax.Array

reference_box

A box used to measure relative changes to the simulation environment.

Type

jax.Array

box_position

A positional degree of freedom used to describe the current box. box_position is parameterized as \(\frac{1}{d}\log(V/V_0)\), where \(V\) is the current volume, \(V_0\) is the reference volume, and \(d\) is the spatial dimension.

Type

jax.Array

box_momentum

A momentum degree of freedom for the box.

Type

jax.Array

box_mass

The mass assigned to the box.

Type

jax.Array

barostat

The variables describing the Nose-Hoover chain coupled to the barostat.

Type

jax_md.simulate.NoseHooverChain

thermostat

The variables describing the Nose-Hoover chain coupled to the thermostat.

Type

jax_md.simulate.NoseHooverChain

class jax_md.simulate.NVTLangevinState(position, momentum, force, mass, rng)

A struct containing state information for the Langevin thermostat.

position

The current position of the particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

momentum

The momentum of particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

force

The (non-stochastic) force on particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

mass

The mass of particles. Will either be a float or an ndarray of floats with shape [n].

Type

jax.Array

rng

The current state of the random number generator.

Type

jax.Array

class jax_md.simulate.BrownianState(position, mass, rng)

A tuple containing state information for Brownian dynamics.

position

The current position of the particles. An ndarray of floats with shape [n, spatial_dimension].

Type

jax.Array

mass

The mass of particles. Will either be a float or an ndarray of floats with shape [n].

Type

jax.Array

rng

The current state of the random number generator.

Type

jax.Array