Source code for jax_md.simulate

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code to simulate systems in various statistical ensembles.

  This file contains a number of different methods that can be used to
  simulate systems in a variety of ensembles.

  In general, simulation code follows the same overall structure as optimizers
  in JAX. Simulations are tuples of two functions:

    init_fn:
      Function that initializes the  state of a system. Should take
      positions as an ndarray of shape `[n, output_dimension]`. Returns a state
      which will be a namedtuple.
    apply_fn:
      Function that takes a state and produces a new state after one
      step of simulation.

  One question that we need to think about is whether the simulations should
  also return a function that computes the invariant for that ensemble. This
  can be used for testing purposes, but is not often used otherwise.
"""

from collections import namedtuple

from typing import Any, Callable, TypeVar, Union, Tuple, Dict, Optional

import functools

from jax import grad
from jax import jit
from jax import random
import jax.numpy as jnp
from jax import lax
from jax.tree_util import tree_map, tree_reduce, tree_flatten, tree_unflatten

from jax_md import quantity
from jax_md import util
from jax_md import space
from jax_md import dataclasses
from jax_md import partition
from jax_md import smap

# Module-level alias for `util.static_cast`.
static_cast = util.static_cast


# Types


# Array and floating point aliases re-exported from `util`
# (presumably 32-bit and 64-bit float constructors -- confirm in `util`).
Array = util.Array
f32 = util.f32
f64 = util.f64

# Alias for the box type declared in `space`.
Box = space.Box

# Alias for the shift-function type declared in `space`.
ShiftFn = space.ShiftFn

# `T` is a generic placeholder used below for simulation-state types. A
# `Simulator` is a pair `(init_fn, apply_fn)` where `apply_fn` maps a `T` to
# a `T`.
T = TypeVar('T')
InitFn = Callable[..., T]
ApplyFn = Callable[[T], T]
Simulator = Tuple[InitFn, ApplyFn]


"""Dispatch By State Code.

JAX MD allows for simulations to be extensible using a dispatch strategy where
functions are dispatched to specific cases based on the type of state provided.
In particular, we make decisions about which function to call based on the type
of the position argument. For those familiar with C / C++, our dispatch code is
essentially function overloading based on the type of the positions.

If you are interested in setting up a simulation using a different type of
system you can do so in a relatively light weight manner by introducing a new
type for storing the state that is compatible with the JAX PyTree system
(we usually choose a dataclass) and then overriding the functions below.

These extensions allow a range of simulations to be run by just changing the
type of the position argument. There are essentially two types of functions to
be overloaded. Functions that compute physical quantities, such as the kinetic
energy, and functions that evolve a state according to the Suzuki-Trotter
decomposition. Specifically, one might want to override the position step,
momentum step for deterministic and stochastic simulations or the
`stochastic_step` for stochastic simulations (e.g. Langevin).
"""


class dispatch_by_state:
  """Wrap a function and dispatch based on the type of positions.

  The wrapped function is the default implementation. Alternative
  implementations are registered per position type with `register` and are
  selected when `type(state.position)` matches a registered type exactly
  (subclasses of a registered type fall back to the default).
  """

  def __init__(self, fn):
    self._fn = fn
    self._registry = {}
    # Carry over `__name__`, `__doc__`, etc. so that the dispatcher presents
    # itself like the function it wraps.
    functools.update_wrapper(self, fn)

  def __call__(self, state, *args, **kwargs):
    """Call the implementation registered for `type(state.position)`."""
    handler = self._registry.get(type(state.position), self._fn)
    return handler(state, *args, **kwargs)

  def register(self, oftype):
    """Return a decorator registering an implementation for `oftype`.

    The decorator returns the function it was given, so the decorated name
    stays bound to the implementation instead of being replaced by `None`.
    """
    def register_fn(fn):
      self._registry[oftype] = fn
      return fn
    return register_fn


@dispatch_by_state
def canonicalize_mass(state: T) -> T:
  """Reshape mass vector for broadcasting with positions.

  Every leaf of `state.mass` is canonicalized independently:

    * Python scalars (`int` or `float`) and zero-dimensional arrays are
      returned unchanged.
    * Arrays of shape `[n, 1]` are returned unchanged.
    * Arrays of shape `[n]` are reshaped to `[n, 1]`.

  Args:
    state: A simulation state exposing a `mass` attribute and a `set` method.

  Returns:
    A copy of `state` with the canonicalized mass.

  Raises:
    ValueError: If a mass leaf has any other shape.
  """
  def canonicalize_fn(mass):
    # Python scalars have no `ndim` attribute, so handle them first.
    if isinstance(mass, (int, float)):
      return mass
    if mass.ndim == 2 and mass.shape[1] == 1:
      return mass
    elif mass.ndim == 1:
      return jnp.reshape(mass, (mass.shape[0], 1))
    elif mass.ndim == 0:
      return mass
    msg = (
      'Expected mass to be either a floating point number or a one-dimensional '
      'ndarray. Found {}.'.format(mass)
    )
    raise ValueError(msg)
  return state.set(mass=tree_map(canonicalize_fn, state.mass))

@dispatch_by_state
def initialize_momenta(state: T, key: Array, kT: float) -> T:
  """Sample momenta from the Maxwell-Boltzmann distribution at `kT`.

  The position and mass pytrees are flattened and one PRNG key is split off
  per position leaf. Each leaf gets momenta `sqrt(m * kT) * normal`, and when
  a leaf has more than one row the mean over axis 0 is subtracted.

  Args:
    state: A simulation state with `position` and `mass` attributes.
    key: A JAX PRNG key.
    kT: The temperature used to scale the momenta.

  Returns:
    A copy of `state` whose `momentum` has the same tree structure as
    `state.position`.
  """
  position_leaves, treedef = tree_flatten(state.position)
  mass_leaves, _ = tree_flatten(state.mass)
  leaf_keys = random.split(key, len(position_leaves))

  def sample_leaf(leaf_key, r, m):
    noise = random.normal(leaf_key, r.shape, dtype=r.dtype)
    p = jnp.sqrt(m * kT) * noise
    # A single particle keeps its sampled momentum.
    if r.shape[0] <= 1:
      return p
    # If simulating more than one particle, center the momentum.
    return p - jnp.mean(p, axis=0, keepdims=True)

  momenta = [
    sample_leaf(leaf_key, r, m)
    for leaf_key, r, m in zip(leaf_keys, position_leaves, mass_leaves)
  ]

  return state.set(momentum=tree_unflatten(treedef, momenta))


@dispatch_by_state
def momentum_step(state: T, dt: float) -> T:
  """Advance the momenta by `dt` under the force stored on the state.

  Args:
    state: A simulation state with `momentum` and `force` pytrees.
    dt: The length of the momentum update.

  Returns:
    A copy of `state` with `momentum` replaced by `momentum + dt * force`.
  """
  assert hasattr(state, 'momentum')

  def kick(p, f):
    return p + dt * f

  return state.set(momentum=tree_map(kick, state.momentum, state.force))


@dispatch_by_state
def position_step(state: T, shift_fn: Callable, dt: float, **kwargs) -> T:
  """Advance the positions by `dt` using the current momenta.

  Args:
    state: A simulation state with `position`, `momentum` and `mass` pytrees.
    shift_fn: Either one callable applied to every position leaf, or a pytree
      of callables matching the structure of `state.position`.
    dt: The length of the position update.
    **kwargs: Extra keyword arguments forwarded to every shift function.

  Returns:
    A copy of `state` whose positions were shifted by `dt * momentum / mass`.
  """
  if isinstance(shift_fn, Callable):
    # Broadcast a single shift function over the position tree.
    single_shift = shift_fn
    shift_fn = tree_map(lambda _: single_shift, state.position)

  def drift(s_fn, r, p, m):
    return s_fn(r, dt * p / m, **kwargs)

  shifted = tree_map(drift,
                     shift_fn,
                     state.position,
                     state.momentum,
                     state.mass)
  return state.set(position=shifted)


@dispatch_by_state
def kinetic_energy(state: T) -> Array:
  """Return `quantity.kinetic_energy` for the state's momentum and mass."""
  momentum, mass = state.momentum, state.mass
  return quantity.kinetic_energy(momentum=momentum, mass=mass)


@dispatch_by_state
def temperature(state: T) -> Array:
  """Return `quantity.temperature` for the state's momentum and mass."""
  momentum, mass = state.momentum, state.mass
  return quantity.temperature(momentum=momentum, mass=mass)


"""Deterministic Simulations

JAX MD includes integrators for deterministic simulations of the NVE, NVT, and
NPT ensembles. For a qualitative description of statistical physics ensembles
see the wikipedia article here:
en.wikipedia.org/wiki/Statistical_ensemble_(mathematical_physics)

Integrators are based on the direct translation method outlined in the paper,

"A Liouville-operator derived measure-preserving integrator for molecular
dynamics simulations in the isothermal–isobaric ensemble"

M. E. Tuckerman, J. Alejandre, R. López-Rendón, A. L Jochim, and G. J. Martyna
J. Phys. A: Math. Gen. 39 5629 (2006)

As such, we define several primitives that are generically useful in describing
simulations of this type. Namely, the velocity-Verlet integration step that is
used in the NVE and NVT simulations. We also define a general Nose-Hoover chain
primitive that is used to couple components of the system to a chain that
regulates the temperature. These primitives can be combined to construct more
interesting simulations that involve e.g. temperature gradients.
"""


def velocity_verlet(force_fn: Callable[..., Array],
                    shift_fn: ShiftFn,
                    dt: float,
                    state: T,
                    **kwargs) -> T:
  """Apply a single step of velocity Verlet integration to a state.

  The step is a half momentum update, a full position update, a force
  recomputation at the new positions, and a second half momentum update.

  Args:
    force_fn: A function computing forces from positions.
    shift_fn: A function used by `position_step` to displace positions.
    dt: The step size; cast to `f32` before use.
    state: The state to advance.
    **kwargs: Forwarded to `position_step` and `force_fn`.

  Returns:
    The state after one step.
  """
  full_step = f32(dt)
  half_step = f32(full_step / 2)

  state = momentum_step(state, half_step)
  state = position_step(state, shift_fn, full_step, **kwargs)
  new_force = force_fn(state.position, **kwargs)
  state = state.set(force=new_force)
  return momentum_step(state, half_step)
# Constant Energy Simulations
@dataclasses.dataclass
class NVEState:
  """State of a simulation in the NVE (microcanonical) ensemble.

  In this ensemble the (N)umber of particles, the (V)olume, and the (E)nergy
  of the system are held fixed.

  Attributes:
    position: An ndarray of shape `[n, spatial_dimension]` storing the
      position of particles.
    momentum: An ndarray of shape `[n, spatial_dimension]` storing the
      momentum of particles.
    force: An ndarray of shape `[n, spatial_dimension]` storing the force
      acting on particles from the previous step.
    mass: A float or an ndarray of shape `[n]` containing the masses of the
      particles.
  """
  position: Array
  momentum: Array
  force: Array
  mass: Array

  @property
  def velocity(self) -> Array:
    """The velocity, computed as `momentum / mass`."""
    return self.momentum / self.mass
# pylint: disable=invalid-name
def nve(energy_or_force_fn, shift_fn, dt=1e-3, **sim_kwargs):
  """Simulates a system in the NVE ensemble.

  Samples from the microcanonical ensemble in which the number of particles
  (N), the system volume (V), and the energy (E) are held constant. The step
  function performs one velocity Verlet update.

  Args:
    energy_or_force_fn: A function that produces either an energy or a force
      from a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation. Can be overridden per step with a `dt` keyword argument.
    **sim_kwargs: Accepted but not used by this function.

  Returns:
    A pair `(init_fn, step_fn)`. `init_fn(key, R, kT, mass=f32(1.0), **kwargs)`
    builds an `NVEState` with momenta drawn at temperature `kT`;
    `step_fn(state, **kwargs)` advances the state by one step.
  """
  force_fn = quantity.canonicalize_force(energy_or_force_fn)

  @jit
  def init_fn(key, R, kT, mass=f32(1.0), **kwargs):
    initial_force = force_fn(R, **kwargs)
    state = canonicalize_mass(NVEState(R, None, initial_force, mass))
    return initialize_momenta(state, key, kT)

  @jit
  def step_fn(state, **kwargs):
    # A per-call `dt` takes precedence over the one given at construction.
    step_dt = kwargs.pop('dt', dt)
    return velocity_verlet(force_fn, shift_fn, step_dt, state, **kwargs)

  return init_fn, step_fn
# Constant Temperature Simulations


# Suzuki-Yoshida weights for integrators of different order.
# These are copied from OpenMM at
# https://github.com/openmm/openmm/blob/master/openmmapi/src/NoseHooverChain.cpp
# Each entry maps the number of Suzuki-Yoshida steps to its list of weights;
# the weights of every entry sum to one.
SUZUKI_YOSHIDA_WEIGHTS = {
    1: [1],
    3: [0.828981543588751, -0.657963087177502, 0.828981543588751],
    5: [0.2967324292201065, 0.2967324292201065, -0.186929716880426,
        0.2967324292201065, 0.2967324292201065],
    7: [0.784513610477560, 0.235573213359357, -1.17767998417887,
        1.31518632068391, -1.17767998417887, 0.235573213359357,
        0.784513610477560]
}
@dataclasses.dataclass
class NoseHooverChain:
  """State information for a Nose-Hoover chain.

  Attributes:
    position: An ndarray of shape `[chain_length]` that stores the position of
      the chain.
    momentum: An ndarray of shape `[chain_length]` that stores the momentum of
      the chain.
    mass: An ndarray of shape `[chain_length]` that stores the mass of the
      chain.
    tau: The desired period of oscillation for the chain. Longer periods
      result in better stability but worse temperature control.
    kinetic_energy: A float that stores the current kinetic energy of the
      system that the chain is coupled to.
    degrees_of_freedom: An integer specifying the number of degrees of freedom
      that the chain is coupled to. Declared as a static field.
  """
  position: Array
  momentum: Array
  mass: Array
  tau: Array
  kinetic_energy: Array
  degrees_of_freedom: int=dataclasses.static_field()
@dataclasses.dataclass
class NoseHooverChainFns:
  """The three functions returned by `nose_hoover_chain`.

  Attributes:
    initialize: Builds the initial chain state.
    half_step: Advances the chain and rescales the coupled momenta.
    update_mass: Recomputes the chain masses for a given temperature.
  """
  initialize: Callable
  half_step: Callable
  update_mass: Callable
def nose_hoover_chain(dt: float,
                      chain_length: int,
                      chain_steps: int,
                      sy_steps: int,
                      tau: float
                      ) -> NoseHooverChainFns:
  r"""Helper function to simulate a Nose-Hoover Chain coupled to a system.

  This function is used in simulations that sample from thermal ensembles by
  coupling the system to one, or more, Nose-Hoover chains. We use the direct
  translation method outlined in Martyna et al. [#martyna92]_ and the
  Nose-Hoover chains are updated using two half steps: one at the beginning of
  a simulation step and one at the end. The masses of the Nose-Hoover chains
  are updated automatically to enforce a specific period of oscillation, `tau`.
  Larger values of `tau` will yield systems that reach the target temperature
  more slowly but are also more stable.

  As described in Martyna et al. [#martyna92]_, the Nose-Hoover chain often
  evolves on a faster timescale than the rest of the simulation. Therefore, it
  is sometimes necessary to integrate the chain over several substeps for each
  step of MD. To do this we follow the Suzuki-Yoshida scheme. Specifically, we
  subdivide our chain simulation into :math:`n_c` substeps. These substeps are
  further subdivided into :math:`n_sy` steps. Each :math:`n_sy` step has length
  :math:`\delta_i = \Delta t w_i / n_c` where :math:`w_i` are constants such
  that :math:`\sum_i w_i = 1`. See the table of Suzuki-Yoshida weights above
  for specific values. The number of substeps and the number of Suzuki-Yoshida
  steps are set using the `chain_steps` and `sy_steps` arguments.

  Consequently, the Nose-Hoover chains are described by three functions: an
  `init_fn` that initializes the state of the chain, a `half_step_fn` that
  updates the chain for one half-step, and an `update_chain_mass_fn` that
  updates the masses of the chain to enforce the correct period of oscillation.

  Note that a system can have many Nose-Hoover chains coupled to it to produce,
  for example, a temperature gradient. We also note that the NPT ensemble
  naturally features two chains: one that couples to the thermal degrees of
  freedom and one that couples to the barostat.

  Args:
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    chain_length: An integer specifying the number of particles in
      the Nose-Hoover chain.
    chain_steps: An integer specifying the number :math:`n_c` of outer
      substeps.
    sy_steps: An integer specifying the number of Suzuki-Yoshida steps. This
      must be either `1`, `3`, `5`, or `7` (the keys of
      `SUZUKI_YOSHIDA_WEIGHTS`).
    tau: A floating point timescale over which temperature equilibration
      occurs. The chain masses are set proportional to `kT * tau ** 2`. The
      performance of the Nose-Hoover chain thermostat can be quite sensitive
      to this choice.

  Returns:
    A `NoseHooverChainFns` holding functions that initialize the chain, do a
    half step of simulation, and update the chain masses respectively.
  """

  def init_fn(degrees_of_freedom, KE, kT):
    # Chain positions and momenta start at zero, in the dtype of `KE`.
    xi = jnp.zeros(chain_length, KE.dtype)
    p_xi = jnp.zeros(chain_length, KE.dtype)

    # Every link gets mass kT * tau^2; the first link is additionally scaled
    # by the number of degrees of freedom it is coupled to.
    Q = kT * tau ** f32(2) * jnp.ones(chain_length, dtype=f32)
    Q = Q.at[0].multiply(degrees_of_freedom)
    return NoseHooverChain(xi, p_xi, Q, tau, KE, degrees_of_freedom)

  def substep_fn(delta, P, state, kT):
    """Apply a single update to the chain parameters and rescales velocity."""

    xi, p_xi, Q, _tau, KE, DOF = dataclasses.astuple(state)

    delta_2 = delta / f32(2.0)
    delta_4 = delta_2 / f32(2.0)
    delta_8 = delta_4 / f32(2.0)

    # Index of the last link of the chain.
    M = chain_length - 1

    G = (p_xi[M - 1] ** f32(2) / Q[M - 1] - kT)
    p_xi = p_xi.at[M].add(delta_4 * G)

    # Sweep from the second-to-last link down to link 1.
    def backward_loop_fn(p_xi_new, m):
      G = p_xi[m - 1] ** 2 / Q[m - 1] - kT
      scale = jnp.exp(-delta_8 * p_xi_new / Q[m + 1])
      p_xi_new = scale * (scale * p_xi[m] + delta_4 * G)
      return p_xi_new, p_xi_new
    idx = jnp.arange(M - 1, 0, -1)
    _, p_xi_update = lax.scan(backward_loop_fn, p_xi[M], idx, unroll=2)
    p_xi = p_xi.at[idx].set(p_xi_update)

    # The first link is driven by the kinetic energy of the coupled system.
    G = f32(2.0) * KE - DOF * kT
    scale = jnp.exp(-delta_8 * p_xi[1] / Q[1])
    p_xi = p_xi.at[0].set(scale * (scale * p_xi[0] + delta_4 * G))

    # Rescale the coupled momenta and keep the cached kinetic energy in sync.
    scale = jnp.exp(-delta_2 * p_xi[0] / Q[0])
    KE = KE * scale ** f32(2)
    P = tree_map(lambda p: p * scale, P)

    xi = xi + delta_2 * p_xi / Q

    # Sweep forward from link 0 up to the second-to-last link.
    G = f32(2) * KE - DOF * kT
    def forward_loop_fn(G, m):
      scale = jnp.exp(-delta_8 * p_xi[m + 1] / Q[m + 1])
      p_xi_update = scale * (scale * p_xi[m] + delta_4 * G)
      G = p_xi_update ** 2 / Q[m] - kT
      return G, p_xi_update
    idx = jnp.arange(M)
    G, p_xi_update = lax.scan(forward_loop_fn, G, idx, unroll=2)
    p_xi = p_xi.at[idx].set(p_xi_update)
    p_xi = p_xi.at[M].add(delta_4 * G)

    # `kT` is returned so the triple can be used as a `lax.scan` carry.
    return P, NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF), kT

  def half_step_chain_fn(P, state, kT):
    # With a single substep and a single weight no scan is needed.
    if chain_steps == 1 and sy_steps == 1:
      P, state, _ = substep_fn(dt, P, state, kT)
      return P, state

    delta = dt / chain_steps
    ws = jnp.array(SUZUKI_YOSHIDA_WEIGHTS[sy_steps])
    def body_fn(cs, i):
      # Cycle through the weights once per outer substep.
      d = f32(delta * ws[i % sy_steps])
      return substep_fn(d, *cs), 0
    P, state, _ = lax.scan(body_fn,
                           (P, state, kT),
                           jnp.arange(chain_steps * sy_steps))[0]
    return P, state

  def update_chain_mass_fn(state, kT):
    xi, p_xi, Q, _tau, KE, DOF = dataclasses.astuple(state)

    # Same mass rule as in `init_fn`, using the `tau` stored on the state.
    Q = kT * _tau ** f32(2) * jnp.ones(chain_length, dtype=f32)
    Q = Q.at[0].multiply(DOF)

    return NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF)

  return NoseHooverChainFns(init_fn, half_step_chain_fn, update_chain_mass_fn)
def default_nhc_kwargs(tau: float, overrides: Optional[Dict]=None) -> Dict:
  """Build keyword arguments for `nose_hoover_chain` from defaults.

  Args:
    tau: The default value used for the `'tau'` entry.
    overrides: An optional dictionary whose values replace the defaults. Only
      the keys `'chain_length'`, `'chain_steps'`, `'sy_steps'` and `'tau'`
      are read; any other key is ignored. `None` means no overrides.

  Returns:
    A new dictionary with exactly the four keys listed above.
  """
  defaults = {
      'chain_length': 3,
      'chain_steps': 2,
      'sy_steps': 3,
      'tau': tau
  }

  if overrides is None:
    return defaults

  return {name: overrides.get(name, value) for name, value in defaults.items()}
@dataclasses.dataclass
class NVTNoseHooverState:
  """State information for an NVT system with a Nose-Hoover chain thermostat.

  Attributes:
    position: The current position of particles. An ndarray of floats
      with shape `[n, spatial_dimension]`.
    momentum: The momentum of particles. An ndarray of floats
      with shape `[n, spatial_dimension]`.
    force: The current force on the particles. An ndarray of floats with shape
      `[n, spatial_dimension]`.
    mass: The mass of the particles. Can either be a float or an ndarray
      of floats with shape `[n]`.
    chain: The variables describing the Nose-Hoover chain.
  """
  position: Array
  momentum: Array
  force: Array
  mass: Array
  chain: NoseHooverChain

  @property
  def velocity(self):
    """The velocity, computed as `momentum / mass`."""
    return self.momentum / self.mass
def nvt_nose_hoover(energy_or_force_fn: Callable[..., Array],
                    shift_fn: ShiftFn,
                    dt: float,
                    kT: float,
                    chain_length: int=5,
                    chain_steps: int=2,
                    sy_steps: int=3,
                    tau: Optional[float]=None,
                    **sim_kwargs) -> Simulator:
  """Simulation in the NVT ensemble using a Nose Hoover Chain thermostat.

  Samples from the canonical ensemble in which the number of particles (N),
  the system volume (V), and the temperature (T) are held constant. We use a
  Nose Hoover Chain (NHC) thermostat described in [#martyna92]_ [#martyna98]_
  [#tuckerman]_. We follow the direct translation method outlined in
  Tuckerman et al. [#tuckerman]_ and the interested reader might want to look
  at that paper as a reference.

  Args:
    energy_or_force_fn: A function that produces either an energy or a force
      from a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. To update the temperature dynamically during a simulation one
      should pass `kT` as a keyword argument to the step function.
    chain_length: An integer specifying the number of particles in
      the Nose-Hoover chain.
    chain_steps: An integer specifying the number, :math:`n_c`, of outer
      substeps.
    sy_steps: An integer specifying the number of Suzuki-Yoshida steps. This
      must be either `1`, `3`, `5`, or `7`.
    tau: A floating point timescale over which temperature equilibration
      occurs. Defaults to `100 * dt` when `None`. The performance of the
      Nose-Hoover chain thermostat can be quite sensitive to this choice.
    **sim_kwargs: Accepted but not used by this function.

  Returns:
    A pair `(init_fn, apply_fn)`. `init_fn(key, R, mass=f32(1.0), **kwargs)`
    builds an `NVTNoseHooverState`; `apply_fn(state, **kwargs)` advances it by
    one step. Both honor a `kT` keyword argument overriding the default.

  .. rubric:: References
  .. [#martyna92] Martyna, Glenn J., Michael L. Klein, and Mark Tuckerman.
    "Nose-Hoover chains: The canonical ensemble via continuous dynamics."
    The Journal of chemical physics 97, no. 4 (1992): 2635-2643.
  .. [#martyna98] Martyna, Glenn, Mark Tuckerman, Douglas J. Tobias, and
    Michael L. Klein. "Explicit reversible integrators for extended systems
    dynamics." Molecular Physics 87. (1998) 1117-1157.
  .. [#tuckerman] Tuckerman, Mark E., Jose Alejandre, Roberto Lopez-Rendon,
    Andrea L. Jochim, and Glenn J. Martyna. "A Liouville-operator derived
    measure-preserving integrator for molecular dynamics simulations in the
    isothermal-isobaric ensemble." Journal of Physics A: Mathematical and
    General 39, no. 19 (2006): 5629.
  """
  force_fn = quantity.canonicalize_force(energy_or_force_fn)
  dt = f32(dt)
  if tau is None:
    tau = dt * 100
  tau = f32(tau)

  thermostat = nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau)

  @jit
  def init_fn(key, R, mass=f32(1.0), **kwargs):
    _kT = kT if 'kT' not in kwargs else kwargs['kT']

    dof = quantity.count_dof(R)

    # Momentum and chain are filled in below once the mass is canonical.
    state = NVTNoseHooverState(R, None, force_fn(R, **kwargs), mass, None)
    state = canonicalize_mass(state)
    state = initialize_momenta(state, key, _kT)
    KE = kinetic_energy(state)
    return state.set(chain=thermostat.initialize(dof, KE, _kT))

  @jit
  def apply_fn(state, **kwargs):
    _kT = kT if 'kT' not in kwargs else kwargs['kT']

    chain = state.chain

    # Refresh the chain masses in case `kT` changed since the last step.
    chain = thermostat.update_mass(chain, _kT)

    # First thermostat half step.
    p, chain = thermostat.half_step(state.momentum, chain, _kT)
    state = state.set(momentum=p)

    state = velocity_verlet(force_fn, shift_fn, dt, state, **kwargs)

    # The chain caches the kinetic energy, so refresh it after the MD step.
    chain = chain.set(kinetic_energy=kinetic_energy(state))

    # Second thermostat half step.
    p, chain = thermostat.half_step(state.momentum, chain, _kT)
    state = state.set(momentum=p, chain=chain)

    return state
  return init_fn, apply_fn
def nvt_nose_hoover_invariant(energy_fn: Callable[..., Array],
                              state: NVTNoseHooverState,
                              kT: float,
                              **kwargs) -> float:
  """The conserved quantity for the NVT ensemble with a Nose-Hoover thermostat.

  This function is normally used for debugging the Nose-Hoover thermostat.

  Arguments:
    energy_fn: The energy function of the Nose-Hoover system.
    state: The current state of the system.
    kT: The current goal temperature of the system.
    **kwargs: Forwarded to `energy_fn`.

  Returns:
    The Hamiltonian of the extended NVT dynamics.
  """
  potential = energy_fn(state.position, **kwargs)
  kinetic = kinetic_energy(state)
  dof = quantity.count_dof(state.position)

  total = potential + kinetic

  chain = state.chain
  # The first link carries the weight of all coupled degrees of freedom.
  total += (chain.momentum[0] ** 2 / (2 * chain.mass[0])
            + dof * kT * chain.position[0])
  # Every remaining link contributes its kinetic term plus `kT * position`.
  for xi, p_xi, q in zip(chain.position[1:],
                         chain.momentum[1:],
                         chain.mass[1:]):
    total += p_xi ** 2 / (2 * q) + kT * xi
  return total
@dataclasses.dataclass
class NPTNoseHooverState:
  """State information for an NPT system with Nose-Hoover chain thermostats.

  Attributes:
    position: The current position of particles. An ndarray of floats
      with shape `[n, spatial_dimension]`.
    momentum: The momentum of particles. An ndarray of floats
      with shape `[n, spatial_dimension]`.
    force: The current force on the particles. An ndarray of floats with shape
      `[n, spatial_dimension]`.
    mass: The mass of the particles. Can either be a float or an ndarray
      of floats with shape `[n]`.
    reference_box: A box used to measure relative changes to the simulation
      environment.
    box_position: A positional degree of freedom used to describe the current
      box. box_position is parameterized as `box_position = (1/d)log(V/V_0)`
      where `V` is the current volume, `V_0` is the reference volume, and `d`
      is the spatial dimension.
    box_momentum: A momentum degree of freedom for the box.
    box_mass: The mass assigned to the box.
    barostat: The variables describing the Nose-Hoover chain coupled to the
      barostat.
    thermostat: The variables describing the Nose-Hoover chain coupled to the
      thermostat.
  """
  position: Array
  momentum: Array
  force: Array
  mass: Array

  reference_box: Box

  box_position: Array
  box_momentum: Array
  box_mass: Array

  barostat: NoseHooverChain
  thermostat: NoseHooverChain

  @property
  def velocity(self) -> Array:
    """The velocity, computed as `momentum / mass`."""
    return self.momentum / self.mass

  @property
  def box(self) -> Array:
    """Get the current box from an NPT simulation."""
    spatial_dim = self.position.shape[1]
    reference = self.reference_box
    reference_volume = quantity.volume(spatial_dim, reference)
    # Invert box_position = (1/d) log(V / V_0) to recover the volume.
    volume = reference_volume * jnp.exp(spatial_dim * self.box_position)
    return (volume / reference_volume) ** (1 / spatial_dim) * reference
def _npt_box_info(state: NPTNoseHooverState
                  ) -> Tuple[float, Callable[[float], float]]:
  """Gets the current volume and a function to compute the box from volume.

  Args:
    state: The NPT state to inspect.

  Returns:
    A pair `(V, box_fn)` where `V` is the volume implied by
    `state.box_position` and `box_fn(V)` rescales the reference box to a
    given volume.
  """
  spatial_dim = state.position.shape[1]
  reference = state.reference_box
  reference_volume = quantity.volume(spatial_dim, reference)
  volume = reference_volume * jnp.exp(spatial_dim * state.box_position)

  def box_from_volume(V):
    return (V / reference_volume) ** (1 / spatial_dim) * reference

  return volume, box_from_volume
def npt_box(state: NPTNoseHooverState) -> Box:
  """Get the current box from an NPT simulation.

  Args:
    state: The NPT state to inspect.

  Returns:
    The reference box rescaled to the volume implied by `state.box_position`.
  """
  spatial_dim = state.position.shape[1]
  reference = state.reference_box
  reference_volume = quantity.volume(spatial_dim, reference)
  volume = reference_volume * jnp.exp(spatial_dim * state.box_position)
  scale = (volume / reference_volume) ** (1 / spatial_dim)
  return scale * reference
def npt_nose_hoover(energy_fn: Callable[..., Array],
                    shift_fn: ShiftFn,
                    dt: float,
                    pressure: float,
                    kT: float,
                    barostat_kwargs: Optional[Dict]=None,
                    thermostat_kwargs: Optional[Dict]=None) -> Simulator:
  """Simulation in the NPT ensemble using a pair of Nose Hoover Chains.

  Samples from the isothermal-isobaric ensemble in which the number of
  particles (N), the system pressure (P), and the temperature (T) are held
  constant. We use a pair of Nose Hoover Chains (NHC) described in
  [#martyna92]_ [#martyna98]_ [#tuckerman]_ coupled to the barostat and the
  thermostat respectively. We follow the direct translation method outlined in
  Tuckerman et al. [#tuckerman]_ and the interested reader might want to look
  at that paper as a reference.

  Args:
    energy_fn: A function that produces an energy from a set of particle
      positions specified as an ndarray of shape `[n, spatial_dimension]`.
      It is called with `box` and `perturbation` keyword arguments.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    pressure: Floating point number specifying the target pressure. To update
      the pressure dynamically during a simulation one should pass `pressure`
      as a keyword argument to the step function.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. To update the temperature dynamically during a simulation one
      should pass `kT` as a keyword argument to the step function.
    barostat_kwargs: A dictionary of keyword arguments passed to the barostat
      NHC. Any parameters not set are drawn from a relatively robust default
      set (with `tau = 1000 * dt`).
    thermostat_kwargs: A dictionary of keyword arguments passed to the
      thermostat NHC. Any parameters not set are drawn from a relatively
      robust default set (with `tau = 100 * dt`).

  Returns:
    A pair `(init_fn, apply_fn)`. `init_fn(key, R, box, mass=f32(1.0),
    **kwargs)` builds an `NPTNoseHooverState`; `apply_fn(state, **kwargs)`
    advances it by one step.
  """

  # Cast the step size once so that `dt` and `dt_2` are consistent.
  dt = f32(dt)
  dt_2 = f32(dt / 2)

  force_fn = quantity.force(energy_fn)

  barostat_kwargs = default_nhc_kwargs(1000 * dt, barostat_kwargs)
  barostat = nose_hoover_chain(dt, **barostat_kwargs)

  thermostat_kwargs = default_nhc_kwargs(100 * dt, thermostat_kwargs)
  thermostat = nose_hoover_chain(dt, **thermostat_kwargs)

  def init_fn(key, R, box, mass=f32(1.0), **kwargs):
    N, dim = R.shape

    _kT = kT if 'kT' not in kwargs else kwargs['kT']

    # The box position is defined via pos = (1 / d) log V / V_0.
    zero = jnp.zeros((), dtype=R.dtype)
    one = jnp.ones((), dtype=R.dtype)
    box_position = zero
    box_momentum = zero
    # Use the (possibly overridden) temperature, matching `update_box_mass`.
    box_mass = dim * (N + 1) * _kT * barostat_kwargs['tau'] ** 2 * one
    KE_box = quantity.kinetic_energy(momentum=box_momentum, mass=box_mass)

    if jnp.isscalar(box) or box.ndim == 0:
      # TODO(schsam): This is necessary because of JAX issue #5849.
      box = jnp.eye(R.shape[-1]) * box

    # Momentum and thermostat are filled in below.
    state = NPTNoseHooverState(
      R, None, force_fn(R, box=box, **kwargs),
      mass, box, box_position, box_momentum, box_mass,
      barostat.initialize(1, KE_box, _kT),
      None)  # pytype: disable=wrong-arg-count
    state = canonicalize_mass(state)
    state = initialize_momenta(state, key, _kT)
    KE = kinetic_energy(state)
    return state.set(
      thermostat=thermostat.initialize(quantity.count_dof(R), KE, _kT))

  def update_box_mass(state, kT):
    """Recompute the box mass for temperature `kT`."""
    N, dim = state.position.shape
    dtype = state.position.dtype
    box_mass = jnp.array(dim * (N + 1) * kT * state.barostat.tau ** 2, dtype)
    return state.set(box_mass=box_mass)

  def box_force(alpha, vol, box_fn, position, momentum, mass, force, pressure,
                **kwargs):
    """Generalized force acting on the box degree of freedom."""
    N, dim = position.shape

    # Energy as a function of a perturbation of the box, differentiated at 0.
    def U(eps):
      return energy_fn(position, box=box_fn(vol), perturbation=(1 + eps),
                       **kwargs)

    dUdV = grad(U)
    KE2 = util.high_precision_sum(momentum ** 2 / mass)

    return alpha * KE2 - dUdV(0.0) - pressure * vol * dim

  def sinhx_x(x):
    """Taylor series for sinh(x) / x as x -> 0."""
    return (1 + x ** 2 / 6 + x ** 4 / 120 + x ** 6 / 5040 +
            x ** 8 / 362_880 + x ** 10 / 39_916_800)

  def exp_iL1(box, R, V, V_b, **kwargs):
    """Position update for particle velocity `V` and box velocity `V_b`."""
    x = V_b * dt
    x_2 = x / 2
    sinhV = sinhx_x(x_2)  # jnp.sinh(x_2) / x_2
    return shift_fn(R,
                    R * (jnp.exp(x) - 1) + dt * V * jnp.exp(x_2) * sinhV,
                    box=box,
                    **kwargs)  # pytype: disable=wrong-keyword-args

  def exp_iL2(alpha, P, F, V_b):
    """Half-step momentum update under force `F` and box velocity `V_b`."""
    x = alpha * V_b * dt_2
    x_2 = x / 2
    sinhP = sinhx_x(x_2)  # jnp.sinh(x_2) / x_2
    return P * jnp.exp(-x) + dt_2 * F * sinhP * jnp.exp(-x_2)

  def inner_step(state, **kwargs):
    """One step of the particle and box dynamics, without the chains."""
    _pressure = kwargs.pop('pressure', pressure)

    R, P, M, F = state.position, state.momentum, state.mass, state.force
    R_b, P_b, M_b = state.box_position, state.box_momentum, state.box_mass

    N, dim = R.shape

    vol, box_fn = _npt_box_info(state)

    alpha = 1 + 1 / N
    # First half step for the box momentum and the particle momenta.
    G_e = box_force(alpha, vol, box_fn, R, P, M, F, _pressure, **kwargs)
    P_b = P_b + dt_2 * G_e
    P = exp_iL2(alpha, P, F, P_b / M_b)

    # Full step for the box position; the volume is recomputed afterwards.
    R_b = R_b + P_b / M_b * dt
    state = state.set(box_position=R_b)

    vol, box_fn = _npt_box_info(state)

    box = box_fn(vol)
    R = exp_iL1(box, R, P / M, P_b / M_b)
    F = force_fn(R, box=box, **kwargs)

    # Second half step for the particle momenta and the box momentum.
    P = exp_iL2(alpha, P, F, P_b / M_b)
    G_e = box_force(alpha, vol, box_fn, R, P, M, F, _pressure, **kwargs)
    P_b = P_b + dt_2 * G_e

    return state.set(position=R, momentum=P, mass=M, force=F,
                     box_position=R_b, box_momentum=P_b, box_mass=M_b)

  def apply_fn(state, **kwargs):
    S = state
    _kT = kT if 'kT' not in kwargs else kwargs['kT']

    # Refresh every temperature-dependent mass before stepping.
    bc = barostat.update_mass(S.barostat, _kT)
    tc = thermostat.update_mass(S.thermostat, _kT)
    S = update_box_mass(S, _kT)

    # First chain half steps: barostat, then thermostat.
    P_b, bc = barostat.half_step(S.box_momentum, bc, _kT)
    P, tc = thermostat.half_step(S.momentum, tc, _kT)

    S = S.set(momentum=P, box_momentum=P_b)
    S = inner_step(S, **kwargs)

    # The chains cache kinetic energies, so refresh them after the step.
    KE = quantity.kinetic_energy(momentum=S.momentum, mass=S.mass)
    tc = tc.set(kinetic_energy=KE)

    KE_box = quantity.kinetic_energy(momentum=S.box_momentum, mass=S.box_mass)
    bc = bc.set(kinetic_energy=KE_box)

    # Second chain half steps, in the reverse order.
    P, tc = thermostat.half_step(S.momentum, tc, _kT)
    P_b, bc = barostat.half_step(S.box_momentum, bc, _kT)

    S = S.set(thermostat=tc, barostat=bc, momentum=P, box_momentum=P_b)

    return S
  return init_fn, apply_fn
def npt_nose_hoover_invariant(energy_fn: Callable[..., Array],
                              state: NPTNoseHooverState,
                              pressure: float,
                              kT: float,
                              **kwargs) -> float:
  """The conserved quantity for the NPT ensemble with a Nose-Hoover thermostat.

  This function is normally used for debugging the NPT simulation.

  Arguments:
    energy_fn: The energy function of the system.
    state: The current state of the system.
    pressure: The current goal pressure of the system.
    kT: The current goal temperature of the system.
    **kwargs: Forwarded to `energy_fn`.

  Returns:
    The Hamiltonian of the extended NPT dynamics.
  """
  volume, box_fn = _npt_box_info(state)
  potential = energy_fn(state.position, box=box_fn(volume), **kwargs)
  kinetic = kinetic_energy(state)

  dof = state.position.size

  total = potential + kinetic

  thermo = state.thermostat
  # The first thermostat link is weighted by the number of coordinates.
  total += (thermo.momentum[0] ** 2 / (2 * thermo.mass[0])
            + dof * kT * thermo.position[0])
  for xi, p_xi, q in zip(thermo.position[1:],
                         thermo.momentum[1:],
                         thermo.mass[1:]):
    total += p_xi ** 2 / (2 * q) + kT * xi

  baro = state.barostat
  for xi, p_xi, q in zip(baro.position, baro.momentum, baro.mass):
    total += p_xi ** 2 / (2 * q) + kT * xi

  # Pressure-volume work and the kinetic energy of the box itself.
  total += pressure * volume
  total += state.box_momentum ** 2 / (2 * state.box_mass)

  return total
"""Stochastic Simulations

JAX MD includes integrators for stochastic simulations of Langevin dynamics and
Brownian motion for systems in the NVT ensemble with a solvent.
"""


@dataclasses.dataclass
class Normal:
  """A simple normal distribution.

  Attributes:
    mean: The mean of the distribution.
    var: The variance of the distribution.
  """
  mean: jnp.ndarray
  var: jnp.ndarray

  def sample(self, key):
    """Draw one sample with the shape and dtype of `mean`."""
    std = jnp.sqrt(self.var)
    noise = random.normal(key, self.mean.shape, dtype=self.mean.dtype)
    return self.mean + std * noise

  def log_prob(self, x):
    """Evaluate the elementwise log density at `x`."""
    log_normalizer = -0.5 * jnp.log(2 * jnp.pi * self.var)
    return log_normalizer - 1 / (2 * self.var) * (x - self.mean)**2
@dataclasses.dataclass
class NVTLangevinState:
  """A struct containing state information for the Langevin thermostat.

  Attributes:
    position: The current position of the particles. An ndarray of floats with
      shape `[n, spatial_dimension]`.
    momentum: The momentum of particles. An ndarray of floats with shape
      `[n, spatial_dimension]`.
    force: The (non-stochastic) force on particles. An ndarray of floats with
      shape `[n, spatial_dimension]`.
    mass: The mass of particles. Will either be a float or an ndarray of floats
      with shape `[n]`.
    rng: The current state of the random number generator.
  """
  position: Array
  momentum: Array
  force: Array
  mass: Array
  rng: Array

  @property
  def velocity(self) -> Array:
    """The velocity, computed as `momentum / mass`."""
    return self.momentum / self.mass
@dispatch_by_state
def stochastic_step(state: NVTLangevinState, dt: float, kT: float, gamma: float):
  """A single stochastic step (the `O` step).

  The momentum is resampled from a normal distribution whose mean is the
  current momentum damped by `exp(-gamma * dt)` and whose variance is
  proportional to `kT` and the mass. The state's `rng` is advanced.
  """
  damping = jnp.exp(-gamma * dt)
  noise_scale = jnp.sqrt(kT * (1 - damping**2))

  # One half of the split becomes the state's next rng; the other half is
  # consumed by this draw.
  next_rng, draw_key = random.split(state.rng)

  resampled = Normal(damping * state.momentum, noise_scale**2 * state.mass)
  return state.set(momentum=resampled.sample(draw_key), rng=next_rng)
def nvt_langevin(energy_or_force_fn: Callable[..., Array],
                 shift_fn: ShiftFn,
                 dt: float,
                 kT: float,
                 gamma: float=0.1,
                 center_velocity: bool=True,
                 **sim_kwargs) -> Simulator:
  """Simulation in the NVT ensemble using the BAOAB Langevin thermostat.

  Samples from the canonical ensemble in which the number of particles (N),
  the system volume (V), and the temperature (T) are held constant. Langevin
  dynamics are stochastic: the system is assumed to interact with fictitious
  microscopic degrees of freedom, for example large particles suspended in a
  solvent such as water. The dynamics are therefore a stochastic ODE
  characterized by a friction coefficient and noise of a given covariance.

  The implementation follows Davidchack, Ouldridge, and Tretyakov [#carlon]_.

  Args:
    energy_or_force_fn: A function that produces either an energy or a force
      from a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation. May be overridden per step by passing `dt` to the step
      function.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. To update the temperature dynamically during a simulation one
      should pass `kT` as a keyword argument to the step function.
    gamma: A float specifying the friction coefficient between the particles
      and the solvent.
    center_velocity: A boolean described as controlling whether the center of
      mass motion is subtracted. NOTE(review): this argument is not read
      anywhere in this function body.
    **sim_kwargs: Accepted for interface compatibility; not read in this
      function body.

  Returns:
    See above.

  .. rubric:: References
  .. [#carlon] R. L. Davidchack, T. E. Ouldridge, and M. V. Tretyakov.
    "New Langevin and gradient thermostats for rigid body dynamics."
    The Journal of Chemical Physics 142, 144114 (2015)
  """
  force_fn = quantity.canonicalize_force(energy_or_force_fn)

  @jit
  def init_fn(key, R, mass=f32(1.0), **kwargs):
    # `kT` is consumed here so it is not forwarded to the force function.
    init_kT = kwargs.pop('kT', kT)
    state_key, momenta_key = random.split(key)
    state = NVTLangevinState(R, None, force_fn(R, **kwargs), mass, state_key)
    state = canonicalize_mass(state)
    return initialize_momenta(state, momenta_key, init_kT)

  @jit
  def step_fn(state, **kwargs):
    # Per-step overrides are removed from kwargs before the rest is forwarded.
    step_dt = kwargs.pop('dt', dt)
    step_kT = kwargs.pop('kT', kT)
    half_dt = step_dt / 2

    # B: half-step momentum update with the stored force.
    state = momentum_step(state, half_dt)
    # A: half-step position update.
    state = position_step(state, shift_fn, half_dt, **kwargs)
    # O: full-step stochastic momentum update.
    state = stochastic_step(state, step_dt, step_kT, gamma)
    # A: second half-step position update.
    state = position_step(state, shift_fn, half_dt, **kwargs)
    # Recompute the force at the new positions before the closing B step.
    state = state.set(force=force_fn(state.position, **kwargs))
    state = momentum_step(state, half_dt)

    return state

  return init_fn, step_fn
@dataclasses.dataclass
class BrownianState:
  """State carried between steps of Brownian dynamics.

  Attributes:
    position: Current particle positions; an ndarray of floats with shape
      `[n, spatial_dimension]`.
    mass: Particle masses; either a float or an ndarray of floats with shape
      `[n]`.
    rng: The current state of the random number generator.
  """
  position: Array
  mass: Array
  rng: Array
def brownian(energy_or_force: Callable[..., Array],
             shift: ShiftFn,
             dt: float,
             kT: float,
             gamma: float=0.1) -> Simulator:
  """Simulation of Brownian dynamics.

  Brownian dynamics correspond to the overdamped regime of Langevin dynamics.
  Velocities do not have to be tracked in this regime, so the update is
  simpler and, where Brownian dynamics are applicable, faster than Langevin.
  As in the case of Langevin dynamics our implementation follows Carlon et al.
  [#carlon]_

  Args:
    energy_or_force: A function that produces either an energy or a force from
      a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift: A function that displaces positions, `R`, by an amount `dR`. Both
      `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. To update the temperature dynamically during a simulation one
      should pass `kT` as a keyword argument to the step function.
    gamma: A float specifying the friction coefficient between the particles
      and the solvent.

  Returns:
    See above.
  """
  force_fn = quantity.canonicalize_force(energy_or_force)

  dt, gamma = static_cast(dt, gamma)

  def init_fn(key, R, mass=f32(1)):
    return canonicalize_mass(BrownianState(R, mass, key))

  def apply_fn(state, **kwargs):
    # A per-step `kT` is read but left in kwargs, so it is also forwarded to
    # the force and shift functions below.
    step_kT = kwargs.get('kT', kT)

    position, mass, rng = dataclasses.astuple(state)
    rng, noise_key = random.split(rng)

    force = force_fn(position, **kwargs)
    noise = random.normal(noise_key, position.shape, position.dtype)

    # Inverse of (mass * gamma): scales both the drift and the noise term.
    mobility = f32(1) / (mass * gamma)

    displacement = (force * dt * mobility +
                    jnp.sqrt(f32(2) * step_kT * dt * mobility) * noise)
    position = shift(position, displacement, **kwargs)

    return BrownianState(position, mass, rng)  # pytype: disable=wrong-arg-count

  return init_fn, apply_fn
"""Experimental Simulations.

Below are simulation environments whose implementation is somewhat
experimental / preliminary. These environments might not be as ergonomic
as the more polished environments above.
"""


@dataclasses.dataclass
class SwapMCState:
  """State of a Hybrid Swap Monte-Carlo simulation.

  Attributes:
    md: A NVTNoseHooverState holding the continuous molecular dynamics data.
    sigma: An `[n,]` array of particle radii.
    key: A JAX PRNGKey used for random number generation.
    neighbor: A NeighborList for the system.
  """
  md: NVTNoseHooverState
  sigma: Array
  key: Array
  neighbor: partition.NeighborList


# pytype: disable=wrong-arg-count
# pytype: disable=wrong-keyword-args
def hybrid_swap_mc(space_fns: space.Space,
                   energy_fn: Callable[[Array, Array], Array],
                   neighbor_fn: partition.NeighborFn,
                   dt: float,
                   kT: float,
                   t_md: float,
                   N_swap: int,
                   sigma_fn: Optional[Callable[[Array], Array]]=None
                   ) -> Simulator:
  """Simulation of Hybrid Swap Monte-Carlo.

  Implements the hybrid Swap Monte Carlo algorithm introduced in Berthier et
  al. [#berthier]_ An NVT simulation is run for `t_md` time, after which
  `N_swap` Monte-Carlo moves are attempted. Each move swaps the radii of two
  randomly chosen particles and is accepted or rejected with a
  Metropolis-Hastings step. One call to the returned step function performs
  the whole MD block followed by all swap attempts.

  Note that this code doesn't feature some of the convenience functions in the
  other simulations. In particular, there is no support for dynamics keyword
  arguments and the energy function must be a simple callable of two
  variables: the distance between adjacent particles and the diameter of the
  particles. If you want support for a better notion of potential or dynamic
  keyword arguments, please file an issue!

  Args:
    space_fns: A tuple of a displacement function and a shift function defined
      in `space.py`.
    energy_fn: A function that computes the energy between one pair of
      particles as a function of the distance between the particles and the
      diameter. This function should not have been passed to `smap.xxx`.
    neighbor_fn: A function to construct neighbor lists outlined in
      `partition.py`.
    dt: The timestep used for the continuous time MD portion of the simulation.
    kT: The temperature of heat bath that the system is coupled to during MD.
    t_md: The time of each MD block.
    N_swap: The number of swapping moves between MD blocks.
    sigma_fn: An optional function for combining radii if they are to be
      non-additive. Defaults to the arithmetic mean of the two radii.

  Returns:
    See above.

  .. rubric:: References
  .. [#berthier] L. Berthier, E. Flenner, C. J. Fullerton, C. Scalliet, and
    M. Singh. "Efficient swap algorithms for molecular dynamics simulations of
    equilibrium supercooled liquids", J. Stat. Mech. (2019) 064004
  """
  displacement_fn, shift_fn = space_fns
  metric_fn = space.metric(displacement_fn)
  nbr_metric_fn = space.map_neighbor(metric_fn)

  # Number of MD steps per block (floor of t_md / dt).
  md_steps = int(t_md // dt)

  # Canonicalize the argument names to be dr and sigma.
  def wrapped_energy_fn(dr, sigma):
    return energy_fn(dr, sigma)

  if sigma_fn is None:
    def sigma_fn(si, sj):
      return 0.5 * (si + sj)

  nbr_energy_fn = smap.pair_neighbor_list(wrapped_energy_fn,
                                          metric_fn,
                                          sigma=sigma_fn)

  nvt_init_fn, nvt_step_fn = nvt_nose_hoover(nbr_energy_fn,
                                             shift_fn,
                                             dt,
                                             kT=kT,
                                             chain_length=3)

  def init_fn(key, position, sigma, nbrs=None):
    key, md_key = random.split(key)
    nbrs = neighbor_fn(position, nbrs)  # pytype: disable=wrong-arg-count
    md_state = nvt_init_fn(md_key, position, neighbor=nbrs, sigma=sigma)
    return SwapMCState(md_state, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

  def md_step_fn(i, state):
    # One MD step followed by a neighbor list update at the new positions.
    md, sigma, key, nbrs = dataclasses.unpack(state)
    md = nvt_step_fn(md, neighbor=nbrs, sigma=sigma)  # pytype: disable=wrong-keyword-args
    nbrs = neighbor_fn(md.position, nbrs)
    return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

  def swap_step_fn(i, state):
    md, sigma, key, nbrs = dataclasses.unpack(state)

    particle_count = md.position.shape[0]

    # Swap a random pair of particle radii.
    key, pick_key, accept_key = random.split(key, 3)
    pair = random.randint(pick_key, (2,),
                          jnp.array(0), jnp.array(particle_count))
    swapped_sigma = sigma.at[pair].set([sigma[pair[1]], sigma[pair[0]]])

    # Collect neighborhoods around the two swapped particles.
    pair_nbrs = nbrs.idx[pair]

    pair_position = md.position[pair]
    nbr_position = md.position[pair_nbrs]

    old_pair_sigma = sigma[pair][:, None]
    old_nbr_sigma = sigma[pair_nbrs]

    new_pair_sigma = swapped_sigma[pair][:, None]
    new_nbr_sigma = swapped_sigma[pair_nbrs]

    distance = nbr_metric_fn(pair_position, nbr_position)

    # Only neighbor slots whose index is below the particle count contribute.
    occupied = pair_nbrs < particle_count

    # Compute the energy before the swap.
    old_energy = energy_fn(distance, sigma_fn(old_pair_sigma, old_nbr_sigma))
    old_energy = jnp.sum(old_energy * occupied)

    # Compute the energy after the swap.
    new_energy = energy_fn(distance, sigma_fn(new_pair_sigma, new_nbr_sigma))
    new_energy = jnp.sum(new_energy * occupied)

    # Accept or reject with a metropolis probability.
    draw = random.uniform(accept_key, ())
    accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - old_energy) / kT))
    sigma = jnp.where(draw < accept_prob, swapped_sigma, sigma)

    return SwapMCState(md, sigma, key, nbrs)  # pytype: disable=wrong-arg-count

  def block_fn(state):
    # Run the MD block first, then attempt all swaps.
    state = lax.fori_loop(0, md_steps, md_step_fn, state)
    state = lax.fori_loop(0, N_swap, swap_step_fn, state)
    return state

  return init_fn, block_fn
# pytype: enable=wrong-arg-count
# pytype: enable=wrong-keyword-args


def temp_rescale(energy_or_force_fn: Callable[..., Array],
                 shift_fn: ShiftFn,
                 dt: float,
                 kT: float,
                 window: float,
                 fraction: float,
                 **sim_kwargs) -> Simulator:
  """Simulation using explicit velocity rescaling.

  Rescale the velocities of atoms explicitly so that the desired temperature
  is reached [#woodcock71]_.

  Args:
    energy_or_force_fn: A function that produces either an energy or a force
      from a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. To update the temperature dynamically during a simulation one
      should pass `kT` as a keyword argument to the step function.
    window: Floating point number specifying the temperature window outside
      which rescaling is performed. Measured in units of `kT`.
    fraction: Floating point number which determines the amount of rescaling
      applied to the velocities. Takes values from 0.0-1.0.
    **sim_kwargs: Accepted for interface compatibility; not read in this
      function body.

  Returns:
    See above.

  .. rubric:: References
  .. [#woodcock71] Woodcock, L. V. "ISOTHERMAL MOLECULAR DYNAMICS CALCULATIONS
    FOR LIQUID SALTS." Chem. Phys. Lett. 1971, 10, 257–261.
  """
  force_fn = quantity.canonicalize_force(energy_or_force_fn)

  dt = f32(dt)

  def velocity_rescale(state, window, fraction, kT):
    """Rescale the momentum if the difference between the current and target
    temperature is more than the window."""
    kT_current = temperature(state)
    cond = jnp.abs(kT_current - kT) > window
    # Move a `fraction` of the way from the current temperature to the target.
    kT_target = kT_current - fraction * (kT_current - kT)
    lam = jnp.where(cond, jnp.sqrt(kT_target / kT_current), 1)
    new_momentum = tree_map(lambda p: p * lam, state.momentum)
    return state.set(momentum=new_momentum)

  def init_fn(key, R, mass=f32(1.0), **kwargs):
    # Consume `kT` so it sets the initial momenta instead of being forwarded
    # to the force function.
    _kT = kwargs.pop('kT', kT)
    # Reuse the NVEState dataclass
    state = NVEState(R, None, force_fn(R, **kwargs), mass)
    state = canonicalize_mass(state)
    return initialize_momenta(state, key, _kT)

  def apply_fn(state, **kwargs):
    # Honor the documented per-step temperature override.
    _kT = kwargs.pop('kT', kT)
    state = velocity_rescale(state, window, fraction, _kT)
    state = velocity_verlet(force_fn, shift_fn, dt, state, **kwargs)
    return state

  return init_fn, apply_fn


def temp_berendsen(energy_or_force_fn: Callable[..., Array],
                   shift_fn: ShiftFn,
                   dt: float,
                   kT: float,
                   tau: float,
                   **sim_kwargs) -> Simulator:
  """Simulation using the Berendsen thermostat.

  Berendsen (weak coupling) thermostat rescales the velocities of atoms such
  that the desired temperature is reached. This rescaling is performed at each
  timestep (dt) and the rescaling factor is calculated using Eq.10 of
  Berendsen et al. [#berendsen84]_.

  Args:
    energy_or_force_fn: A function that produces either an energy or a force
      from a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. To update the temperature dynamically during a simulation one
      should pass `kT` as a keyword argument to the step function.
    tau: A floating point number determining how fast the temperature
      is relaxed during the simulation. Measured in units of `dt`.
    **sim_kwargs: Accepted for interface compatibility; not read in this
      function body.

  Returns:
    See above.

  .. rubric:: References
  .. [#berendsen84] H. J. C. Berendsen, J. P. M. Postma, W. F. van Gunsteren,
    A. DiNola, J. R. Haak. "Molecular dynamics with coupling to an external
    bath." J. Chem. Phys. 15 October 1984; 81 (8): 3684-3690.
  """
  force_fn = quantity.canonicalize_force(energy_or_force_fn)

  dt = f32(dt)

  def berendsen_update(state, tau, kT, dt):
    """Rescaling the momentum of the particle by the factor lam."""
    _kT = temperature(state)
    lam = jnp.sqrt(1 + ((dt / tau) * ((kT / _kT) - 1)))
    new_momentum = tree_map(lambda p: p * lam, state.momentum)
    return state.set(momentum=new_momentum)

  def init_fn(key, R, mass=f32(1.0), **kwargs):
    # Consume `kT` so it sets the initial momenta instead of being forwarded
    # to the force function.
    _kT = kwargs.pop('kT', kT)
    # Reuse the NVEState dataclass
    state = NVEState(R, None, force_fn(R, **kwargs), mass)
    state = canonicalize_mass(state)
    return initialize_momenta(state, key, _kT)

  def apply_fn(state, **kwargs):
    # Honor the documented per-step temperature override.
    _kT = kwargs.pop('kT', kT)
    state = berendsen_update(state, tau, _kT, dt)
    state = velocity_verlet(force_fn, shift_fn, dt, state, **kwargs)
    return state

  return init_fn, apply_fn


def nvk(energy_or_force_fn: Callable[..., Array],
        shift_fn: ShiftFn,
        dt: float,
        kT: float,
        **sim_kwargs) -> Simulator:
  """Simulation in the NVK (isokinetic) ensemble using the Gaussian thermostat.

  Samples from the isokinetic ensemble in which the number of particles (N),
  the system volume (V), and the kinetic energy (K) are held constant. A
  Gaussian thermostat is used for the integration and the kinetic energy is
  held constant during the simulation. The implementation follows the steps
  described in [#minary2003]_ and [#zhang97]_. See section 4(B) equation
  4.12-4.17 in [#minary2003]_ for detailed description.

  Args:
    energy_or_force_fn: A function that produces either an energy or a force
      from a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. It is used to draw the initial momenta; `init_fn` accepts a
      `kT` keyword argument that overrides it.
    **sim_kwargs: Accepted for interface compatibility; not read in this
      function body.

  Returns:
    See above.

  .. rubric:: References
  .. [#minary2003] Minary, Peter and Martyna, Glenn J. and Tuckerman, Mark E.
    "Algorithms and novel applications based on the isokinetic ensemble. I.
    Biophysical and path integral molecular dynamics"
    J. Chem. Phys., Vol. 118, No. 6, 8 February 2003.
  .. [#zhang97] Zhang, Fei. "Operator-splitting integrators for
    constant-temperature molecular dynamics"
    J. Chem. Phys. 106, 6102–6106 (1997).
  """
  force_fn = quantity.canonicalize_force(energy_or_force_fn)

  dt = f32(dt)
  dt_2 = f32(dt / 2)

  def momentum_update(state, KE):
    # eps to avoid edge cases when forces are zero
    eps = 1e-16

    # Equation 4.13 to compute a and b
    update_fn = (lambda f, p, m: f * p / m)
    a = util.high_precision_sum(
        update_fn(state.force, state.momentum, state.mass)) + eps
    b = util.high_precision_sum(
        update_fn(state.force, state.force, state.mass)) + eps
    a /= (2.0 * KE)
    b /= (2.0 * KE)

    # Equation 4.12 to compute s(t) and s_dot(t)
    b_sqrt = jnp.sqrt(b)
    s_t = (((a / b) * (jnp.cosh(dt_2 * b_sqrt) - 1.0)) +
           jnp.sinh(dt_2 * b_sqrt) / b_sqrt)
    s_dot_t = ((b_sqrt * (a / b) * jnp.sinh(dt_2 * b_sqrt)) +
               jnp.cosh(dt_2 * b_sqrt))

    # Get the new momentum using Equation 4.15
    new_momentum = tree_map(lambda p, f, s, sdot: (p + f * s) / sdot,
                            state.momentum,
                            state.force,
                            s_t,
                            s_dot_t)
    return state.set(momentum=new_momentum)

  def position_update(state, shift_fn, **kwargs):
    # Broadcast a single shift function over the position pytree.
    if isinstance(shift_fn, Callable):
      shift_fn = tree_map(lambda r: shift_fn, state.position)
    # Get the new positions using Equation 4.16 (Should read r = r + dt * p / m)
    new_position = tree_map(lambda s_fn, r, v: s_fn(r, dt * v, **kwargs),
                            shift_fn,
                            state.position,
                            state.velocity)
    return state.set(position=new_position)

  def init_fn(key, R, mass=f32(1.0), **kwargs):
    _kT = kwargs.pop('kT', kT)
    key, split = random.split(key)
    # Reuse the NVEState dataclass
    state = NVEState(R, None, force_fn(R, **kwargs), mass)
    state = canonicalize_mass(state)
    return initialize_momenta(state, split, _kT)

  def apply_fn(state, **kwargs):
    # The kinetic energy measured at the start of the step is used for both
    # momentum half-updates.
    _KE = kinetic_energy(state)
    state = momentum_update(state, _KE)
    # Forward step kwargs to the shift function as well as the force function,
    # as the other integrators in this module do.
    state = position_update(state, shift_fn, **kwargs)
    state = state.set(force=force_fn(state.position, **kwargs))
    state = momentum_update(state, _KE)
    return state

  return init_fn, apply_fn


def temp_csvr(energy_or_force_fn: Callable[..., Array],
              shift_fn: ShiftFn,
              dt: float,
              kT: float,
              tau: float,
              **sim_kwargs) -> Simulator:
  """Simulation using the canonical sampling through velocity rescaling (CSVR)
  thermostat.

  Samples from the canonical ensemble in which the number of particles (N),
  the system volume (V), and the temperature (T) are held constant. The CSVR
  algorithm samples the canonical distribution by rescaling the velocities by
  an appropriately chosen random factor. The rescaling takes place at each
  timestep (dt) and the rescaling factor is calculated using Eq.A7 of Bussi et
  al. [#bussi2007]_. CSVR updates to the velocity are stochastic in nature and
  unlike the Berendsen thermostat it samples the true canonical distribution
  [#Braun2018]_.

  Args:
    energy_or_force_fn: A function that produces either an energy or a force
      from a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt: Floating point number specifying the timescale (step size) of the
      simulation.
    kT: Floating point number specifying the temperature in units of Boltzmann
      constant. To update the temperature dynamically during a simulation one
      should pass `kT` as a keyword argument to the step function.
    tau: A floating point number determining how fast the temperature
      is relaxed during the simulation. Measured in units of `dt`.
    **sim_kwargs: Accepted for interface compatibility; not read in this
      function body.

  Returns:
    See above.

  .. rubric:: References
  .. [#bussi2007] Bussi G, Donadio D, Parrinello M.
    "Canonical sampling through velocity rescaling."
    The Journal of chemical physics, 126(1), 014101.
  .. [#Braun2018] Efrem Braun, Seyed Mohamad Moosavi, and Berend Smit.
    "Anomalous Effects of Velocity Rescaling Algorithms: The Flying Ice Cube
    Effect Revisited."
    Journal of Chemical Theory and Computation 2018 14 (10), 5262-5272.
  """
  force_fn = quantity.canonicalize_force(energy_or_force_fn)

  dt = f32(dt)

  def sum_noises(state, key):
    """Sum of N independent gaussian noises squared.

    Adapted from https://github.com/GiovanniBussi/StochasticVelocityRescaling
    For more details see Eq.A7 Bussi et al. [#bussi2007]_"""
    dof = quantity.count_dof(state.position) - 1
    _dtype = state.position.dtype

    if dof == 0:
      # If there are no terms return zero.
      return 0
    elif dof == 1:
      # For a single noise term, directly calculate the square of the
      # Gaussian noise value.
      rr = random.normal(key, dtype=_dtype)
      return rr * rr
    elif dof % 2 == 0:
      # For an even number of noise terms, use the gamma-distributed random
      # number generator.
      return 2.0 * random.gamma(key, dof // 2, dtype=_dtype)
    else:
      # For an odd number of noise terms, sum two terms: one from the
      # gamma-distributed generator and another from the square of a
      # Gaussian-distributed random number. The two draws use separate
      # subkeys so that they are independent.
      gamma_key, normal_key = random.split(key)
      rr = random.normal(normal_key, dtype=_dtype)
      return (2.0 * random.gamma(gamma_key, (dof - 1) // 2, dtype=_dtype) +
              (rr * rr))

  def csvr_update(state, tau, kT, dt):
    """Update the momentum by a scaling factor as described by Eq.A7 Bussi et
    al. [#bussi2007]_"""
    # Three independent keys: the state's next rng, the single Gaussian r1,
    # and the summed squared noises r2. A key is never used for more than one
    # draw and the stored rng is never one that was already consumed.
    key, r1_key, r2_key = random.split(state.rng, 3)
    dof = quantity.count_dof(state.position)

    _kT = temperature(state)

    KE_old = dof * _kT / 2
    KE_new = dof * kT / 2

    r1 = random.normal(r1_key, dtype=state.position.dtype)
    r2 = sum_noises(state, r2_key)

    c1 = jnp.exp(-dt / tau)
    c2 = (1 - c1) * KE_new / KE_old / dof

    scale = c1 + (c2 * ((r1 * r1) + r2)) + (2 * r1 * jnp.sqrt(c1 * c2))
    lam = jnp.sqrt(scale)

    new_momentum = tree_map(lambda p: p * lam, state.momentum)
    return state.set(momentum=new_momentum, rng=key)

  def init_fn(key, R, mass=f32(1.0), **kwargs):
    _kT = kwargs.pop('kT', kT)
    key, split = random.split(key)
    # Reuse the NVTLangevinState dataclass
    state = NVTLangevinState(R, None, force_fn(R, **kwargs), mass, key)
    state = canonicalize_mass(state)
    return initialize_momenta(state, split, _kT)

  def apply_fn(state, **kwargs):
    # Honor the documented per-step temperature override.
    _kT = kwargs.pop('kT', kT)
    state = csvr_update(state, tau, _kT, dt)
    state = velocity_verlet(force_fn, shift_fn, dt, state, **kwargs)
    return state

  return init_fn, apply_fn