# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions of various standard energy functions."""
from functools import wraps, partial
from typing import Callable, Tuple, TextIO, Dict, Any, Optional
import re
from flax import nnx
import jax
import jax.numpy as jnp
from jax import ops
from jax.tree_util import tree_map
from jax import vmap
from jax.scipy.special import erfc # error function
from jax_md import (
custom_partition,
space,
smap,
partition,
nn,
quantity,
interpolate,
util,
)
from ml_collections import ConfigDict
# Electrostatics
from jax_md._energy.electrostatics import coulomb_direct_pair
from jax_md._energy.electrostatics import coulomb_direct_neighbor_list
from jax_md._energy.electrostatics import coulomb_recip_ewald
from jax_md._energy.electrostatics import coulomb_recip_pme
from jax_md._energy.electrostatics import coulomb_ewald_neighbor_list
from jax_md._energy.electrostatics import coulomb
from jax_md._energy.electrostatics import coulomb_neighbor_list
# Define aliases for the different neural network primitives in `jax_md.nn`.
bp = nn.behler_parrinello
gnome = nn.gnome
nequip = nn.nequip
# Shorthand applied by the convenience wrappers below to user-supplied
# parameters before they are handed to `smap` / `partition`.
maybe_downcast = util.maybe_downcast
# Types
# Short names for the scalar and array types re-exported from `jax_md.util`.
f32 = util.f32
f64 = util.f64
Array = util.Array
ArrayLike = util.ArrayLike
PyTree = Any
# Geometry-related types from `jax_md.space`.
Box = space.Box
DisplacementFn = space.DisplacementFn
DisplacementOrMetricFn = space.DisplacementOrMetricFn
# Neighbor-list types from `jax_md.partition`.
NeighborFn = partition.NeighborFn
NeighborListFns = partition.NeighborListFns
NeighborList = partition.NeighborList
NeighborListFormat = partition.NeighborListFormat
# Energy Functions
[docs]
def simple_spring(
dr: Array,
length: Array | float = 1,
epsilon: Array | float = 1,
alpha: Array | float = 2,
**unused_kwargs,
) -> Array:
"""Isotropic spring potential with a given rest length.
We define `simple_spring` to be a generalized Hookean spring with
agreement when `alpha = 2`.
"""
return epsilon / alpha * jnp.abs(dr - length) ** alpha
[docs]
def simple_spring_bond(
  displacement_or_metric: DisplacementOrMetricFn,
  bond: Array,
  bond_type: Array | None = None,
  length: Array | float = 1,
  epsilon: Array | float = 1,
  alpha: Array | float = 2,
) -> Callable[[Array], Array]:
  """Build an energy function for particles connected by `simple_spring` bonds.

  Args:
    displacement_or_metric: Displacement or metric function of the space; it is
      canonicalized before being handed to `smap.bond`.
    bond: The bonds, forwarded to `smap.bond`.
    bond_type: Optional bond types, forwarded to `smap.bond`.
    length: Rest length of the springs.
    epsilon: Energy scale of the springs.
    alpha: Exponent of the spring energy.

  Returns:
    The function produced by `smap.bond` for `simple_spring`.
  """
  spring_parameters = {
    'length': maybe_downcast(length),
    'epsilon': maybe_downcast(epsilon),
    'alpha': maybe_downcast(alpha),
  }
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.bond(
    simple_spring,
    metric,
    bond,
    bond_type,
    ignore_unused_parameters=True,
    **spring_parameters,
  )
[docs]
def soft_sphere(
  dr: Array,
  sigma: Array | float = 1,
  epsilon: Array | float = 1,
  alpha: Array | float = 2,
  **unused_kwargs,
) -> Array:
  """.. _soft-sphere:

  Finite-ranged repulsion between overlapping soft spheres.

  This is the overlap potential widely used in jamming and soft-matter work:

  .. math::
    U(r) = \\frac{\\epsilon}{\\alpha} \\left(1 - \\frac{r}{\\sigma}\\right)^\\alpha
    \\quad \\text{for } r < \\sigma

  with :math:`U(r) = 0` for :math:`r \\geq \\sigma`.

  Note: this is not the inverse power law potential
  :math:`U(r) = \\epsilon (\\sigma/r)^\\alpha` found in some fluid simulations.

  Reference:
    O'Hern, C. S., Silbert, L. E., Liu, A. J., & Nagel, S. R. (2003).
    Jamming at zero temperature and zero applied stress: The epitome of
    disorder. *Physical Review E*, 68(1), 011306.
    https://doi.org/10.1103/PhysRevE.68.011306

  Args:
    dr: An ndarray of shape `[n, m]` of pairwise distances between particles.
    sigma: Particle diameter; a floating point scalar or an ndarray of shape
      `[n, m]`.
    epsilon: Interaction energy scale; a floating point scalar or an ndarray
      of shape `[n, m]`.
    alpha: Exponent setting the interaction stiffness; a scalar or an ndarray
      of shape `[n, m]`.
    unused_kwargs: Allows extra data (e.g. time) to be passed to the energy.

  Returns:
    Matrix of energies whose shape is `[n, m]`.
  """
  scaled = dr / sigma
  overlapping = scaled < 1.0

  def overlap_energy(x):
    return epsilon / alpha * (f32(1.0) - x) ** alpha

  # With a non-integer exponent the non-overlapping branch raises a negative
  # base to a fractional power, so it is evaluated through `util.safe_mask`
  # instead of a plain `where` (presumably to keep NaNs out of gradients).
  exponent_is_integer = jnp.issubdtype(jnp.result_type(alpha), jnp.integer)
  if not exponent_is_integer:
    return util.safe_mask(overlapping, overlap_energy, scaled, f32(0.0))
  return jnp.where(overlapping, overlap_energy(scaled), f32(0.0))
[docs]
def soft_sphere_pair(
  displacement_or_metric: DisplacementOrMetricFn,
  species: Array | int | None = None,
  sigma: Array | float = 1.0,
  epsilon: Array | float = 1.0,
  alpha: Array | float = 2.0,
  per_particle: bool = False,
) -> Callable[[Array], Array]:
  """Convenience wrapper to compute :ref:`soft sphere energy <soft-sphere>` over a system.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    species: Optional species specification forwarded to `smap.pair`.
    sigma: Particle diameter(s).
    epsilon: Interaction energy scale(s).
    alpha: Stiffness exponent(s).
    per_particle: If True the energy is reduced over axis 1 only; otherwise
      `reduce_axis=None` is passed to `smap.pair`.

  Returns:
    The energy function built by `smap.pair` from `soft_sphere`.
  """
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  reduce_axis = (1,) if per_particle else None
  return smap.pair(
    soft_sphere,
    metric,
    ignore_unused_parameters=True,
    species=species,
    sigma=maybe_downcast(sigma),
    epsilon=maybe_downcast(epsilon),
    alpha=maybe_downcast(alpha),
    reduce_axis=reduce_axis,
  )
[docs]
def soft_sphere_neighbor_list(
  displacement_or_metric: DisplacementOrMetricFn,
  box_size: Box,
  species: Array | int | None = None,
  sigma: Array | float = 1.0,
  epsilon: Array | float = 1.0,
  alpha: Array | float = 2.0,
  dr_threshold: float = 0.2,
  per_particle: bool = False,
  fractional_coordinates: bool = False,
  format: NeighborListFormat = partition.OrderedSparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  pair_neighbor_list_fn: Callable = smap.pair_neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`soft spheres <soft-sphere>` using a neighbor list.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    box_size: Box passed to the neighbor list constructor.
    species: Optional species specification for the pair energy.
    sigma: Particle diameter(s); the largest value is used as list cutoff.
    epsilon: Interaction energy scale(s).
    alpha: Stiffness exponent(s).
    dr_threshold: Forwarded to `neighbor_list_fn`.
    per_particle: If True the energy is reduced over axis 1 only.
    fractional_coordinates: Forwarded to both the neighbor list and the
      energy constructor.
    format: Neighbor list format.
    neighbor_list_fn: Constructor used to build the neighbor list functions.
    pair_neighbor_list_fn: Constructor used to build the energy function.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  sigma = maybe_downcast(sigma)
  epsilon = maybe_downcast(epsilon)
  alpha = maybe_downcast(alpha)
  dr_threshold = maybe_downcast(dr_threshold)

  # The potential is zero at and beyond sigma, so the largest sigma is the
  # interaction range the neighbor list has to cover.
  neighbor_fn = neighbor_list_fn(
    displacement_or_metric,
    box_size,
    jnp.max(sigma),
    dr_threshold,
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_options = dict(
    ignore_unused_parameters=True,
    species=species,
    sigma=sigma,
    epsilon=epsilon,
    alpha=alpha,
    reduce_axis=(1,) if per_particle else None,
    fractional_coordinates=fractional_coordinates,
  )
  energy_fn = pair_neighbor_list_fn(soft_sphere, metric, **energy_options)
  return neighbor_fn, energy_fn
[docs]
def lennard_jones(
dr: Array,
sigma: Array | float = 1,
epsilon: Array | float = 1,
**unused_kwargs,
) -> Array:
""".. _lj-pot:
Lennard-Jones interaction between particles with a minimum at `sigma`.
Args:
dr: An ndarray of shape `[n, m]` of pairwise distances between particles.
sigma: Distance between particles where the energy has a minimum. Should
either be a floating point scalar or an ndarray whose shape is `[n, m]`.
epsilon: Interaction energy scale. Should either be a floating point scalar
or an ndarray whose shape is `[n, m]`.
unused_kwargs: Allows extra data (e.g. time) to be passed to the energy.
Returns:
Matrix of energies of shape `[n, m]`.
"""
idr = sigma / dr
idr = idr * idr
idr6 = idr * idr * idr
idr12 = idr6 * idr6
# TODO(schsam): This seems potentially dangerous. We should do ErrorChecking
# here.
return jnp.nan_to_num(f32(4) * epsilon * (idr12 - idr6))
[docs]
def lennard_jones_pair(
  displacement_or_metric: DisplacementOrMetricFn,
  species: Array | None = None,
  sigma: Array | float = 1.0,
  epsilon: Array | float = 1.0,
  r_onset: Array | float = 2.0,
  r_cutoff: Array | float = 2.5,
  per_particle: bool = False,
) -> Callable[[Array], Array]:
  """Convenience wrapper to compute :ref:`Lennard-Jones energy <lj-pot>` over a system.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    species: Optional species specification forwarded to `smap.pair`.
    sigma: Length scale(s) of the interaction.
    epsilon: Energy scale(s) of the interaction.
    r_onset: Onset of the smooth truncation, in units of the largest `sigma`.
    r_cutoff: Cutoff of the smooth truncation, in units of the largest
      `sigma`.
    per_particle: If True the energy is reduced over axis 1 only.

  Returns:
    The energy function built by `smap.pair` from the truncated potential.
  """
  sigma = maybe_downcast(sigma)
  epsilon = maybe_downcast(epsilon)
  # Onset and cutoff are given relative to the largest particle size.
  largest_sigma = jnp.max(sigma)
  r_onset = maybe_downcast(r_onset) * largest_sigma
  r_cutoff = maybe_downcast(r_cutoff) * largest_sigma

  truncated = multiplicative_isotropic_cutoff(lennard_jones, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  reduce_axis = (1,) if per_particle else None
  return smap.pair(
    truncated,
    metric,
    ignore_unused_parameters=True,
    species=species,
    sigma=sigma,
    epsilon=epsilon,
    reduce_axis=reduce_axis,
  )
[docs]
def lennard_jones_neighbor_list(
  displacement_or_metric: DisplacementOrMetricFn,
  box_size: Box,
  species: Array | None = None,
  sigma: Array | float = 1.0,
  epsilon: Array | float = 1.0,
  r_onset: float = 2.0,
  r_cutoff: float = 2.5,
  dr_threshold: float = 0.5,
  per_particle: bool = False,
  fractional_coordinates: bool = False,
  format: partition.NeighborListFormat = partition.OrderedSparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  pair_neighbor_list_fn: Callable = smap.pair_neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`Lennard-Jones <lj-pot>` using a neighbor list.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    box_size: Box passed to the neighbor list constructor.
    species: Optional species specification for the pair energy.
    sigma: Length scale(s) of the interaction.
    epsilon: Energy scale(s) of the interaction.
    r_onset: Onset of the smooth truncation, in units of the largest `sigma`.
    r_cutoff: Cutoff of the truncation, in units of the largest `sigma`; the
      scaled value is also the neighbor list cutoff.
    dr_threshold: Forwarded to `neighbor_list_fn`.
    per_particle: If True the energy is reduced over axis 1 only.
    fractional_coordinates: Forwarded to both the neighbor list and the
      energy constructor.
    format: Neighbor list format.
    neighbor_list_fn: Constructor used to build the neighbor list functions.
    pair_neighbor_list_fn: Constructor used to build the energy function.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  sigma = maybe_downcast(sigma)
  epsilon = maybe_downcast(epsilon)
  largest_sigma = jnp.max(sigma)
  r_onset = maybe_downcast(r_onset) * largest_sigma
  r_cutoff = maybe_downcast(r_cutoff) * largest_sigma
  dr_threshold = maybe_downcast(dr_threshold)

  neighbor_fn = neighbor_list_fn(
    displacement_or_metric,
    box_size,
    r_cutoff,
    dr_threshold,
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )

  truncated = multiplicative_isotropic_cutoff(lennard_jones, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_options = dict(
    ignore_unused_parameters=True,
    species=species,
    sigma=sigma,
    epsilon=epsilon,
    reduce_axis=(1,) if per_particle else None,
    fractional_coordinates=fractional_coordinates,
  )
  energy_fn = pair_neighbor_list_fn(truncated, metric, **energy_options)
  return neighbor_fn, energy_fn
[docs]
def morse(
dr: Array,
sigma: Array | float = 1.0,
epsilon: Array | float = 5.0,
alpha: Array | float = 5.0,
**unused_kwargs,
) -> Array:
""".. _morse-pot:
Morse interaction between particles with a minimum at `sigma`.
Args:
dr: An ndarray of shape `[n, m]` of pairwise distances between particles.
sigma: Distance between particles where the energy has a minimum. Should
either be a floating point scalar or an ndarray whose shape is `[n, m]`.
epsilon: Interaction energy scale. Should either be a floating point scalar
or an ndarray whose shape is `[n, m]`.
alpha: Range parameter. Should either be a floating point scalar or an
ndarray whose shape is `[n, m]`.
unused_kwargs: Allows extra data (e.g. time) to be passed to the energy.
Returns:
Matrix of energies of shape `[n, m]`.
"""
U = epsilon * (f32(1) - jnp.exp(-alpha * (dr - sigma))) ** f32(2) - epsilon
# TODO(cpgoodri): ErrorChecking following lennard_jones
return jnp.nan_to_num(jnp.array(U, dtype=dr.dtype))
[docs]
def morse_pair(
  displacement_or_metric: DisplacementOrMetricFn,
  species: Array | None = None,
  sigma: Array | float = 1.0,
  epsilon: Array | float = 5.0,
  alpha: Array | float = 5.0,
  r_onset: float = 2.0,
  r_cutoff: float = 2.5,
  per_particle: bool = False,
) -> Callable[[Array], Array]:
  """Convenience wrapper to compute :ref:`Morse energy <morse-pot>` over a system.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    species: Optional species specification forwarded to `smap.pair`.
    sigma: Position(s) of the energy minimum.
    epsilon: Energy scale(s) of the interaction.
    alpha: Range parameter(s).
    r_onset: Distance at which the smooth truncation starts.
    r_cutoff: Distance beyond which the truncated energy is zero.
    per_particle: If True the energy is reduced over axis 1 only.

  Returns:
    The energy function built by `smap.pair` from the truncated potential.
  """
  truncated = multiplicative_isotropic_cutoff(morse, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  reduce_axis = (1,) if per_particle else None
  return smap.pair(
    truncated,
    metric,
    ignore_unused_parameters=True,
    species=species,
    sigma=maybe_downcast(sigma),
    epsilon=maybe_downcast(epsilon),
    alpha=maybe_downcast(alpha),
    reduce_axis=reduce_axis,
  )
[docs]
def morse_neighbor_list(
  displacement_or_metric: DisplacementOrMetricFn,
  box_size: Box,
  species: Array | None = None,
  sigma: Array | float = 1.0,
  epsilon: Array | float = 5.0,
  alpha: Array | float = 5.0,
  r_onset: float = 2.0,
  r_cutoff: float = 2.5,
  dr_threshold: float = 0.5,
  per_particle: bool = False,
  fractional_coordinates: bool = False,
  format: partition.NeighborListFormat = partition.OrderedSparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  pair_neighbor_list_fn: Callable = smap.pair_neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`Morse <morse-pot>` using a neighbor list.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    box_size: Box passed to the neighbor list constructor.
    species: Optional species specification for the pair energy.
    sigma: Position(s) of the energy minimum.
    epsilon: Energy scale(s) of the interaction.
    alpha: Range parameter(s).
    r_onset: Distance at which the smooth truncation starts.
    r_cutoff: Truncation distance; also used as the neighbor list cutoff.
    dr_threshold: Forwarded to `neighbor_list_fn`.
    per_particle: If True the energy is reduced over axis 1 only.
    fractional_coordinates: Forwarded to both the neighbor list and the
      energy constructor.
    format: Neighbor list format.
    neighbor_list_fn: Constructor used to build the neighbor list functions.
    pair_neighbor_list_fn: Constructor used to build the energy function.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  sigma, epsilon, alpha = (
    maybe_downcast(sigma),
    maybe_downcast(epsilon),
    maybe_downcast(alpha),
  )
  r_onset, r_cutoff = maybe_downcast(r_onset), maybe_downcast(r_cutoff)
  dr_threshold = maybe_downcast(dr_threshold)

  neighbor_fn = neighbor_list_fn(
    displacement_or_metric,
    box_size,
    r_cutoff,
    dr_threshold,
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )

  truncated = multiplicative_isotropic_cutoff(morse, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_options = dict(
    ignore_unused_parameters=True,
    species=species,
    sigma=sigma,
    epsilon=epsilon,
    alpha=alpha,
    reduce_axis=(1,) if per_particle else None,
    fractional_coordinates=fractional_coordinates,
  )
  energy_fn = pair_neighbor_list_fn(truncated, metric, **energy_options)
  return neighbor_fn, energy_fn
[docs]
def gupta_potential(displacement, p, q, r_0n, U_n, A, cutoff):
  """.. _gupta-pot:

  Gupta potential, a many-body potential for metallic clusters.

  The potential was introduced by R. P. Gupta [#gupta]_. Parameters fit for
  bulk gold are due to Jellinek [#jellinek]_, and this particular form of the
  potential follows Garzon and Posada-Amarillas [#garzon]_. See
  `GUPTA_GOLD55_DICT` / `gupta_gold55` for the Au_55 parameter set.

  Args:
    displacement: Function to compute displacement between two positions.
    p: Parameter of the repulsive term.
    q: Parameter of the attractive term.
    r_0n: Parameter that determines the length scale of the potential.
    U_n: Parameter that determines the energy scale.
    A: Prefactor of the repulsive term.
    cutoff: Pairwise interactions that are farther than the cutoff distance
      are ignored.

  Returns:
    A function that takes particle positions (shape `[n, 3]` where `n` is the
    number of atoms) and returns the total energy of the system.

  .. rubric:: References
  .. [#gupta] R.P. Gupta, Phys. Rev. B 23, 6265 (1981)
  .. [#jellinek] J. Jellinek, in Metal-Ligand Interactions, edited by N. Russo
    and D. R. Salahub (Kluwer Academic, Dordrecht, 1996), p. 325.
  .. [#garzon] I.L. Garzon, A. Posada-Amarillas, Phys. Rev. B 54, 16 (1996)
  """

  def _pair_terms(dr):
    """Masked repulsive and attractive pair contributions."""
    # Self-pairs (dr == 0) and pairs at or beyond the cutoff do not contribute.
    interacting = (dr > 0) & (dr < cutoff)
    repulsive = jnp.where(interacting, jnp.exp(-1.0 * p * (dr / r_0n - 1)), 0.0)
    attractive = jnp.where(
      interacting, jnp.exp(-2.0 * q * (dr / r_0n - 1)), 0.0
    )
    return repulsive, attractive

  def compute_fn(R):
    dr = space.distance(space.map_product(displacement)(R, R))
    repulsive, attractive = _pair_terms(dr)
    first_term = A * jnp.sum(repulsive, axis=1)
    # Safe sqrt used in order to ensure that force calculations are not nan
    # when the particles are too widely separated at initialization
    # (corresponding to the case where the attractive term is 0.).
    attractive_sum = jnp.sum(attractive, axis=1)
    second_term = util.safe_mask(attractive_sum > 0, jnp.sqrt, attractive_sum)
    return U_n / 2.0 * jnp.sum(first_term - second_term)

  return compute_fn
# Keyword arguments for `gupta_potential` describing the Au_55 gold cluster
# (every argument except `displacement` and `cutoff`).
GUPTA_GOLD55_DICT = dict(
  p=10.15,
  q=4.13,
  r_0n=2.96,
  U_n=3.454,
  A=0.118428,
)
[docs]
def gupta_gold55(displacement, cutoff=8.0):
  """Gupta potential preconfigured with the `GUPTA_GOLD55_DICT` parameters.

  Args:
    displacement: Function to compute displacement between two positions.
    cutoff: Pairwise interactions farther than this distance are ignored.

  Returns:
    A function `energy_fn(R, **unused_kwargs)` returning the total energy; any
    extra keyword arguments are accepted and dropped.
  """
  total_energy = gupta_potential(
    displacement, cutoff=cutoff, **GUPTA_GOLD55_DICT
  )

  def energy_fn(R, **unused_kwargs):
    # The underlying Gupta function takes positions only.
    return total_energy(R)

  return energy_fn
[docs]
def multiplicative_isotropic_cutoff(
  fn: Callable[..., Array], r_onset: float, r_cutoff: float
) -> Callable[..., Array]:
  """Takes an isotropic function and constructs a truncated function.

  Given a function `f:R -> R`, we construct a new function `f':R -> R` such
  that `f'(r) = f(r)` for `r < r_onset`, `f'(r) = 0` for `r > r_cutoff`, and
  `f'(r)` is :math:`C^1` everywhere. To do this, we follow the approach
  outlined in HOOMD Blue [#hoomd]_ (thanks to Carl Goodrich for the pointer).
  We construct a switching function `S(r)` with `S(r) = 1` for `r < r_onset`,
  `S(r) = 0` for `r > r_cutoff`, and `S(r)` :math:`C^1`, and set
  `f'(r) = S(r)f(r)`.

  Args:
    fn: A function that takes an ndarray of distances of shape `[n, m]` as
      well as varargs.
    r_onset: A float specifying the distance marking the onset of deformation.
    r_cutoff: A float specifying the cutoff distance.

  Returns:
    A new function with the same signature as fn, with the properties outlined
    above.

  .. rubric:: References
  .. [#hoomd] HOOMD Blue documentation. Accessed on 05/31/2019.
    https://hoomd-blue.readthedocs.io/en/stable/module-md-pair.html#hoomd.md.pair.pair
  """
  # The switching polynomial is expressed in squared distances.
  cutoff_sq = r_cutoff ** f32(2)
  onset_sq = r_onset ** f32(2)

  def switch(dr):
    dr_sq = dr ** f32(2)
    taper = (
      (cutoff_sq - dr_sq) ** 2
      * (cutoff_sq + 2 * dr_sq - 3 * onset_sq)
      / (cutoff_sq - onset_sq) ** 3
    )
    beyond_onset = jnp.where(dr < r_cutoff, taper, 0)
    return jnp.where(dr < r_onset, 1, beyond_onset)

  @wraps(fn)
  def cutoff_fn(dr, *args, **kwargs):
    return switch(dr) * fn(dr, *args, **kwargs)

  return cutoff_fn
def dsf_coulomb(
r: Array, Q_sq: Array, alpha: Array | float = 0.25, cutoff: float = 8.0
) -> Array:
"""Damped-shifted-force approximation of the coulombic interaction."""
qqr2e = 332.06371 # Coulombic conversion factor: 1/(4*pi*epo).
cutoffsq = cutoff * cutoff
erfcc = erfc(alpha * cutoff)
erfcd = jnp.exp(-alpha * alpha * cutoffsq)
f_shift = -(erfcc / cutoffsq + 2 / jnp.sqrt(jnp.pi) * alpha * erfcd / cutoff)
e_shift = erfcc / cutoff - f_shift * cutoff
e = qqr2e * Q_sq / r * (erfc(alpha * r) - r * e_shift - r**2 * f_shift)
return jnp.where(r < cutoff, e, 0.0)
[docs]
def bks(
  dr: Array,
  Q_sq: Array,
  exp_coeff: Array,
  exp_decay: Array,
  attractive_coeff: Array,
  repulsive_coeff: Array,
  coulomb_alpha: Array,
  cutoff: float,
  **unused_kwargs,
) -> Array:
  """.. _bks-pot:

  Beest-Kramer-van Santen (BKS) potential [#bks]_ which is commonly used to
  model silicas. This function computes the interaction between two
  given atoms within the Buckingham form [#carre]_ , following the
  implementation from Liu et al. [#liu]_ .

  The energy is the sum of a damped-shifted-force Coulomb term, an
  exponential short-range term and `dr**-6` and `dr**-24` power laws.

  Args:
    dr: An ndarray of shape `[n, m]` of pairwise distances between particles.
    Q_sq: An ndarray of shape `[n, m]` of pairwise product of partial charges.
    exp_coeff: An ndarray of shape `[n, m]` that sets the scale of the
      exponential decay of the short-range interaction.
    exp_decay: Decay length of the exponential short-range interaction (the
      term is `exp_coeff * exp(-dr / exp_decay)`).
    attractive_coeff: An ndarray of shape `[n, m]` for the coefficient of the
      attractive 6th order term.
    repulsive_coeff: An ndarray of shape `[n, m]` for the coefficient of the
      repulsive 24th order term, to prevent the unphysical fusion of atoms.
    coulomb_alpha: Damping parameter for the approximation of the long-range
      coulombic interactions (a scalar).
    cutoff: Cutoff distance for considering pairwise interactions.
    unused_kwargs: Allows extra data (e.g. time) to be passed to the energy.

  Returns:
    Matrix of energies of shape `[n, m]`.

  .. rubric:: References
  .. [#bks] Van Beest, B. W. H., Gert Jan Kramer, and R. A. Van Santen. "Force fields
    for silicas and aluminophosphates based on ab initio calculations." Physical
    Review Letters 64.16 (1990): 1955.
  .. [#carre] Carré, Antoine, et al. "Developing empirical potentials from ab initio
    simulations: The case of amorphous silica." Computational Materials Science
    124 (2016): 323-334.
  .. [#liu] Liu, Han, et al. "Machine learning Forcefield for silicate glasses."
    arXiv preprint arXiv:1902.03486 (2019).
  """
  electrostatic = dsf_coulomb(dr, Q_sq, coulomb_alpha, cutoff)
  short_range = exp_coeff * jnp.exp(-dr / exp_decay)
  sixth_order = attractive_coeff / dr**6
  hard_core = repulsive_coeff / dr**24
  total = electrostatic + short_range + sixth_order + hard_core
  return jnp.where(dr < cutoff, total, 0.0)
[docs]
def bks_pair(
  displacement_or_metric: DisplacementOrMetricFn,
  species: ArrayLike,
  Q_sq: ArrayLike | list,
  exp_coeff: ArrayLike | list,
  exp_decay: ArrayLike | list,
  attractive_coeff: ArrayLike | list,
  repulsive_coeff: ArrayLike | list,
  coulomb_alpha: ArrayLike,
  cutoff: float,
) -> Callable[[Array], Array]:
  """Convenience wrapper to compute :ref:`BKS energy <bks-pot>` over a system.

  Args:
    displacement_or_metric: Displacement or metric function of the space. It is
      canonicalized here, exactly as in the other pair wrappers of this
      module, so that `bks` always receives scalar distances.
    species: Species of the particles, forwarded to `smap.pair`.
    Q_sq: Pairwise products of partial charges, per species pair.
    exp_coeff: Prefactors of the exponential short-range term.
    exp_decay: Decay lengths of the exponential short-range term.
    attractive_coeff: Coefficients of the 6th order term.
    repulsive_coeff: Coefficients of the 24th order term.
    coulomb_alpha: Damping parameter of the damped-shifted-force Coulomb term.
    cutoff: Cutoff distance for the pairwise interactions.

  Returns:
    The energy function built by `smap.pair` from `bks`.
  """
  Q_sq = maybe_downcast(Q_sq)
  exp_coeff = maybe_downcast(exp_coeff)
  exp_decay = maybe_downcast(exp_decay)
  attractive_coeff = maybe_downcast(attractive_coeff)
  repulsive_coeff = maybe_downcast(repulsive_coeff)
  return smap.pair(
    bks,
    # Without canonicalization a displacement function would hand vectors to
    # `bks`, which expects distances; metrics pass through unchanged.
    space.canonicalize_displacement_or_metric(displacement_or_metric),
    species=species,
    ignore_unused_parameters=True,
    Q_sq=Q_sq,
    exp_coeff=exp_coeff,
    exp_decay=exp_decay,
    attractive_coeff=attractive_coeff,
    repulsive_coeff=repulsive_coeff,
    coulomb_alpha=coulomb_alpha,
    cutoff=cutoff,
  )
[docs]
def bks_neighbor_list(
  displacement_or_metric: DisplacementOrMetricFn,
  box_size: Box,
  species: ArrayLike,
  Q_sq: ArrayLike | list,
  exp_coeff: ArrayLike | list,
  exp_decay: ArrayLike | list,
  attractive_coeff: ArrayLike | list,
  repulsive_coeff: ArrayLike | list,
  coulomb_alpha: ArrayLike,
  cutoff: float,
  dr_threshold: float = 0.8,
  fractional_coordinates: bool = False,
  format: partition.NeighborListFormat = partition.OrderedSparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  pair_neighbor_list_fn: Callable = smap.pair_neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`BKS energy <bks-pot>` using a neighbor list.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    box_size: Box passed to the neighbor list constructor.
    species: Species of the particles.
    Q_sq: Pairwise products of partial charges, per species pair.
    exp_coeff: Prefactors of the exponential short-range term.
    exp_decay: Decay lengths of the exponential short-range term.
    attractive_coeff: Coefficients of the 6th order term.
    repulsive_coeff: Coefficients of the 24th order term.
    coulomb_alpha: Damping parameter of the damped-shifted-force Coulomb term.
    cutoff: Interaction cutoff; also used as the neighbor list cutoff.
    dr_threshold: Forwarded to `neighbor_list_fn`.
    fractional_coordinates: Forwarded to both the neighbor list and the
      energy constructor.
    format: Neighbor list format.
    neighbor_list_fn: Constructor used to build the neighbor list functions.
    pair_neighbor_list_fn: Constructor used to build the energy function.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair.
  """
  potential_parameters = dict(
    Q_sq=maybe_downcast(Q_sq),
    exp_coeff=maybe_downcast(exp_coeff),
    exp_decay=maybe_downcast(exp_decay),
    attractive_coeff=maybe_downcast(attractive_coeff),
    repulsive_coeff=maybe_downcast(repulsive_coeff),
    coulomb_alpha=coulomb_alpha,
    cutoff=cutoff,
  )
  dr_threshold = maybe_downcast(dr_threshold)

  neighbor_fn = neighbor_list_fn(
    displacement_or_metric,
    box_size,
    cutoff,
    dr_threshold,
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )

  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  energy_fn = pair_neighbor_list_fn(
    bks,
    metric,
    species=species,
    ignore_unused_parameters=True,
    fractional_coordinates=fractional_coordinates,
    **potential_parameters,
  )
  return neighbor_fn, energy_fn
# BKS Potential Parameters.
# Coefficients given in kcal/mol.
# Partial charges of the two species. In the 2x2 tables below index 0 is
# silicon and index 1 is oxygen (see the `Q_sq` rows and the self-energy
# bookkeeping in `bks_silica_pair`).
CHARGE_OXYGEN = -0.977476019
CHARGE_SILICON = 1.954952037
# Keyword arguments for `bks_pair` / `bks_neighbor_list` describing silica;
# every table is symmetric and indexed as [species_i][species_j].
BKS_SILICA_DICT: Dict[str, Any] = {
  # Pairwise products of the partial charges.
  'Q_sq': [
    [CHARGE_SILICON**2, CHARGE_SILICON * CHARGE_OXYGEN],
    [CHARGE_SILICON * CHARGE_OXYGEN, CHARGE_OXYGEN**2],
  ],
  # Exponential short-range term: exp_coeff * exp(-dr / exp_decay). The Si-Si
  # coefficient is zero, so its decay length of 1 is only a placeholder.
  'exp_coeff': [[0, 471671.1243], [471671.1243, 23138.64826]],
  'exp_decay': [[1, 0.19173537], [0.19173537, 0.356855265]],
  # Coefficients of the dr**-6 and dr**-24 terms.
  'attractive_coeff': [[0, -2156.074422], [-2156.074422, -1879.223108]],
  'repulsive_coeff': [[78940848.06, 668.7557239], [668.7557239, 2605.841269]],
  # Damping parameter of the damped-shifted-force Coulomb term.
  'coulomb_alpha': 0.25,
}
def _bks_silica_self(
Q_sq: Array | float, alpha: Array | float, cutoff: float
) -> Array:
"""Function for computing the self-energy contributions to BKS."""
cutoffsq = cutoff * cutoff
erfcc = erfc(alpha * cutoff)
erfcd = jnp.exp(-alpha * alpha * cutoffsq)
f_shift = -(
erfcc / cutoffsq + 2.0 / jnp.sqrt(jnp.pi) * alpha * erfcd / cutoff
)
e_shift = erfcc / cutoff - f_shift * cutoff
qqr2e = 332.06371 # kcal/mol coulombic conversion factor: 1/(4*pi*epo)
return -(e_shift / 2.0 + alpha / jnp.sqrt(jnp.pi)) * Q_sq * qqr2e
[docs]
def bks_silica_pair(
  displacement_or_metric: DisplacementOrMetricFn,
  species: Array,
  cutoff: float = 8.0,
):
  """Convenience wrapper to compute :ref:`BKS energy <bks-pot>` for SiO2.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    species: Species of the particles; 0 is counted as silicon and 1 as
      oxygen.
    cutoff: Cutoff distance for the pairwise interactions.

  Returns:
    A function `energy_fn(R, **kwargs)` returning the pair energy plus the
    constant self-energy of all silicon and oxygen atoms.
  """
  pair_energy = bks_pair(
    displacement_or_metric, species, cutoff=cutoff, **BKS_SILICA_DICT
  )
  self_energy = partial(_bks_silica_self, alpha=0.25, cutoff=cutoff)
  n_silicon = jnp.sum(species == 0)
  n_oxygen = jnp.sum(species == 1)

  def energy_fn(R, **kwargs):
    interaction = pair_energy(R, **kwargs)
    silicon_self = n_silicon * self_energy(CHARGE_SILICON**2)
    oxygen_self = n_oxygen * self_energy(CHARGE_OXYGEN**2)
    return interaction + silicon_self + oxygen_self

  return energy_fn
[docs]
def bks_silica_neighbor_list(
  displacement_or_metric: DisplacementOrMetricFn,
  box_size: Box,
  species: Array,
  cutoff: float = 8.0,
  dr_threshold: float = 1.0,
  fractional_coordinates: bool = False,
  format: partition.NeighborListFormat = partition.OrderedSparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  pair_neighbor_list_fn: Callable = smap.pair_neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`BKS energy <bks-pot>` using neighbor lists.

  Args:
    displacement_or_metric: Displacement or metric function of the space.
    box_size: Box passed to the neighbor list constructor.
    species: Species of the particles; 0 is counted as silicon and 1 as
      oxygen.
    cutoff: Cutoff distance for the pairwise interactions.
    dr_threshold: Forwarded to `bks_neighbor_list`.
    fractional_coordinates: Forwarded to `bks_neighbor_list`.
    format: Neighbor list format.
    neighbor_list_fn: Constructor used to build the neighbor list functions.
    pair_neighbor_list_fn: Constructor used to build the energy function.
    **neighbor_kwargs: Extra keyword arguments for `bks_neighbor_list`; on a
      name clash they take precedence over the `BKS_SILICA_DICT` entries.

  Returns:
    A `(neighbor_fn, energy_fn)` pair; `energy_fn(R, neighbor, **kwargs)`
    includes the constant self-energy of all silicon and oxygen atoms.
  """
  silica_options = {**BKS_SILICA_DICT, **neighbor_kwargs}
  neighbor_fn, pair_energy = bks_neighbor_list(
    displacement_or_metric,
    box_size,
    species,
    cutoff=cutoff,
    dr_threshold=dr_threshold,
    fractional_coordinates=fractional_coordinates,
    format=format,
    neighbor_list_fn=neighbor_list_fn,
    pair_neighbor_list_fn=pair_neighbor_list_fn,
    **silica_options,
  )
  self_energy = partial(_bks_silica_self, alpha=0.25, cutoff=cutoff)
  n_silicon = jnp.sum(species == 0)
  n_oxygen = jnp.sum(species == 1)

  def energy_fn(R, neighbor, **kwargs):
    interaction = pair_energy(R, neighbor, **kwargs)
    silicon_self = n_silicon * self_energy(CHARGE_SILICON**2)
    oxygen_self = n_oxygen * self_energy(CHARGE_OXYGEN**2)
    return interaction + silicon_self + oxygen_self

  return neighbor_fn, energy_fn
# Stillinger-Weber Potential
def _sw_angle_interaction(
  gamma: float, sigma: float, cutoff: float, dR12: Array, dR13: Array
) -> Array:
  """The angular interaction for the Stillinger-Weber potential.

  This function handles a single pair of neighbors of one central atom; the
  callers vmap it three times to cover the whole system.

  Args:
    gamma: A scalar used to fit the angle interaction.
    sigma: A scalar that sets the distance scale between neighbors.
    cutoff: The cutoff beyond which the interactions are not considered.
    dR12: A d-dimensional vector that specifies the displacement of the first
      neighbor. This potential is usually used in three dimensions.
    dR13: A d-dimensional vector that specifies the displacement of the second
      neighbor.

  Returns:
    Angular interaction energy for one pair of neighbors.
  """
  a = cutoff / sigma

  def clipped_length(dR):
    # Lengths at or beyond the cutoff are replaced by zero; zero then doubles
    # as the "no interaction" marker tested below.
    length = space.distance(dR)
    return jnp.where(length < cutoff, length, 0)

  dr12 = clipped_length(dR12)
  dr13 = clipped_length(dR13)

  radial_part = jnp.exp(
    gamma / (dr12 / sigma - a) + gamma / (dr13 / sigma - a)
  )
  cos_angle = quantity.cosine_angle_between_two_vectors(dR12, dR13)
  angular_part = (cos_angle + 1.0 / 3) ** 2

  # Both neighbors must be in range and must not be the same neighbor.
  distinct_neighbors = jnp.linalg.norm(dR12 - dR13) > 1e-5
  interacting = (dr12 > 0) & (dr13 > 0) & distinct_neighbors
  return jnp.where(interacting, radial_part * angular_part, 0)
# Triple-vmapped form of `_sw_angle_interaction`.
# NOTE(review): the in_axes tuples here cover only two arguments, while
# `_sw_angle_interaction` takes five (gamma, sigma, cutoff, dR12, dR13); the
# energy functions below first bind the three scalars with `partial` and build
# their own vmapped version, so this module-level alias looks unused/stale —
# confirm before relying on it.
sw_three_body_term = vmap(
  vmap(vmap(_sw_angle_interaction, (0, None)), (None, 0)), 0
)
def _sw_radial_interaction(
sigma: float, B: float, cutoff: float, r: Array
) -> Array:
"""The two body term of the Stillinger-Weber potential."""
a = cutoff / sigma
p = 4
term1 = B * (r / sigma) ** (-p) - 1.0
within_cutoff = (r > 0) & (r < cutoff)
r = jnp.where(within_cutoff, r, 0)
term2 = jnp.exp(1 / (r / sigma - a))
return jnp.where(within_cutoff, term1 * term2, 0.0)
[docs]
def stillinger_weber(
  displacement: DisplacementFn,
  sigma: float = 2.0951,
  A: float = 7.049556277,
  B: float = 0.6022245584,
  lam: float = 21.0,
  gamma: float = 1.2,
  epsilon: float = 2.16826,
  three_body_strength: float = 1.0,
  cutoff: float = 3.77118,
) -> Callable[[Array], Array]:
  """.. _sw-pot:

  Computes the Stillinger-Weber potential.

  The Stillinger-Weber (SW) potential [#stillinger]_ is commonly used to
  model silicon and similar systems. The defaults are the SW parameters from
  the original paper. The SW potential was originally proposed to model
  silicon in the diamond crystal phase and the liquid phase, and is known to
  give unphysical amorphous configurations [#holender]_ [#barkema]_ . For this
  reason, we provide a `three_body_strength` parameter. Changing this number
  to `1.5` or `2.0` has been known to produce a more physical amorphous phase,
  preventing most atoms from having more than four nearest neighbors. Note
  that this function currently assumes nearest-image-convention.

  Args:
    displacement: The displacement function for the space.
    sigma: A scalar that sets the distance scale between neighbors.
    A: A scalar that determines the scale of two-body term.
    B: A scalar that determines the scale of the :math:`1 / r^p` term.
    lam: A scalar that determines the scale of the three-body term.
    gamma: A scalar used to fit the angle interaction.
    epsilon: A scalar that sets the total energy scale.
    three_body_strength: A scalar that determines the relative strength of the
      angular interaction. Default value is `1.0`, which works well for the
      diamond crystal and liquid phases. `1.5` and `2.0` have been used to
      model amorphous silicon.
    cutoff: The distance beyond which interactions are not considered.

  Returns:
    A function that computes the total energy.

  .. rubric:: References
  .. [#stillinger] Stillinger, Frank H., and Thomas A. Weber. "Computer
    simulation of local order in condensed phases of silicon."
    Physical review B 31.8 (1985): 5262.
  .. [#holender] Holender, J. M., and G. J. Morgan. "Generation of a large
    structure (10^5 atoms) of amorphous Si using molecular dynamics." Journal
    of Physics: Condensed Matter 3.38 (1991): 7241.
  .. [#barkema] Barkema, G. T., and Normand Mousseau. "Event-based relaxation of
    continuous disordered systems." Physical review letters 77.21 (1996): 4358.
  """
  two_body_fn = partial(_sw_radial_interaction, sigma, B, cutoff)

  # Lift the single-triplet angular term to all (i, j, k) combinations.
  per_triplet = partial(_sw_angle_interaction, gamma, sigma, cutoff)
  over_first_neighbor = vmap(per_triplet, (0, None))
  over_second_neighbor = vmap(over_first_neighbor, (None, 0))
  three_body_fn = vmap(over_second_neighbor)

  def compute_fn(R, **kwargs):
    pair_displacement = space.map_product(partial(displacement, **kwargs))
    dR = pair_displacement(R, R)
    dr = space.distance(dR)
    # Every pair (and every triplet) is visited twice, hence the factors 1/2.
    two_body = util.high_precision_sum(two_body_fn(dr)) / 2.0 * A
    three_body = lam * util.high_precision_sum(three_body_fn(dR, dR)) / 2.0
    return epsilon * (two_body + three_body_strength * three_body)

  return compute_fn
[docs]
def stillinger_weber_neighbor_list(
  displacement: DisplacementFn,
  box_size: float,
  sigma: float = 2.0951,
  A: float = 7.049556277,
  B: float = 0.6022245584,
  lam: float = 21.0,
  gamma: float = 1.2,
  epsilon: float = 2.16826,
  three_body_strength: float = 1.0,
  cutoff: float = 3.77118,
  dr_threshold: float = 0.5,
  fractional_coordinates: bool = False,
  format: NeighborListFormat = partition.Dense,
  neighbor_list_fn: Callable = partition.neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`Stillinger-Weber <sw-pot>`
  using a neighbor list.

  Args:
    displacement: The displacement function for the space.
    box_size: Box passed to the neighbor list constructor.
    sigma: A scalar that sets the distance scale between neighbors.
    A: A scalar that determines the scale of two-body term.
    B: A scalar that determines the scale of the :math:`1 / r^p` term.
    lam: A scalar that determines the scale of the three-body term.
    gamma: A scalar used to fit the angle interaction.
    epsilon: A scalar that sets the total energy scale.
    three_body_strength: Relative strength of the angular interaction.
    cutoff: Interaction cutoff; also used as the neighbor list cutoff.
    dr_threshold: Forwarded to `neighbor_list_fn`.
    fractional_coordinates: Forwarded to `neighbor_list_fn`, as in the other
      neighbor list wrappers of this module.
    format: Neighbor list format; only `partition.Dense` is supported by the
      returned energy function.
    neighbor_list_fn: Constructor used to build the neighbor list functions.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A `(neighbor_fn, energy_fn)` pair, where `energy_fn(R, neighbor, **kwargs)`
    computes the total energy.

  Raises:
    NotImplementedError: When the returned energy function is called with a
      neighbor list that is not in `Dense` format.
  """
  two_body_fn = partial(_sw_radial_interaction, sigma, B, cutoff)
  three_body_fn = partial(_sw_angle_interaction, gamma, sigma, cutoff)
  neighbor_fn = neighbor_list_fn(
    displacement,
    box_size,
    cutoff,
    dr_threshold,
    # Previously accepted but silently dropped.
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )

  def compute_fn(R, neighbor, **kwargs):
    if neighbor.format is not partition.Dense:
      raise NotImplementedError(
        'Stillinger-Weber potential only implemented with Dense neighbor lists.'
      )
    d = partial(displacement, **kwargs)
    mask = partition.neighbor_list_mask(neighbor)
    _three_body_fn = vmap(vmap(vmap(three_body_fn, (0, None)), (None, 0)))
    dR = space.map_neighbor(d)(R, R[neighbor.idx])
    dr = space.distance(dR)
    first_term = util.high_precision_sum(two_body_fn(dr) * mask) / 2.0 * A
    # A triplet (i, j, k) counts only if both j and k are real neighbors of i.
    mask_ijk = mask[:, None, :] * mask[:, :, None]
    second_term = (
      lam * util.high_precision_sum(_three_body_fn(dR, dR) * mask_ijk) / 2.0
    )
    return epsilon * (first_term + three_body_strength * second_term)

  return neighbor_fn, compute_fn
# Tersoff model
[docs]
def load_lammps_tersoff_parameters(file: TextIO) -> list[Dict[str, Any]]:
  """.. _ts-lammps:

  Reads Tersoff parameters from a LAMMPS file and returns parameter tables.

  This function reads multi-element original Tersoff potential parameters
  from a file. Blank lines are skipped and everything from a `#` to the end of
  a line is treated as a comment. One entry consists of 17 whitespace
  separated tokens (three element names followed by 14 numbers) and may be
  written on one line or spread over several consecutive lines.

  Args:
    file: A parameter file that is written with lammps format.

  Returns:
    params: A list with one dict per entry. The keys are `element1`,
      `element2`, `element3` (strings) followed by `mTf`, `gamma`, `lam3`,
      `cTf`, `dTf`, `hTf`, `nTf`, `beta`, `lam2`, `B`, `R`, `D`, `lam1` and `A`
      (numbers converted with `f64`).

  Raises:
    ValueError: If the tokens cannot be grouped into entries of exactly 17,
      or if a numeric field cannot be parsed.
  """
  # TODO: turn params_per_line into an argument to support the other variants
  # of the Tersoff model.
  params_per_line = 17
  element_keys = ('element1', 'element2', 'element3')
  value_keys = (
    'mTf', 'gamma', 'lam3', 'cTf', 'dTf', 'hTf', 'nTf',
    'beta', 'lam2', 'B', 'R', 'D', 'lam1', 'A',
  )

  def format_error(nwords):
    return ValueError(f'Incorrect format: {nwords} not in {params_per_line}')

  params = []
  # Tokens of the entry currently being assembled; reset after every entry so
  # that consecutive single-line entries do not run into each other.
  pending = []
  for line in file.read().split('\n'):
    words = line.split('#', 1)[0].split()
    if not words:
      continue
    pending.extend(words)
    if len(pending) < params_per_line:
      # The entry continues on the next line.
      continue
    if len(pending) > params_per_line:
      raise format_error(len(pending))
    entry = dict(zip(element_keys, pending[:3]))
    # Parse the text first so `f64` only ever sees plain numbers.
    values = f64([float(word) for word in pending[3:]])
    entry.update(zip(value_keys, values))
    params.append(entry)
    pending = []
  if pending:
    # The file ended in the middle of an entry.
    raise format_error(len(pending))
  return params
def _ters_cutoff(dr, R, D) -> Array:
"""The cut-off function of the Tersoff potential.
Args:
R: A Parameter that is the average of inner and outer cutoff radii
D: A Parameter that is the half of the difference
between inner and outer cutoff radii
Returns:
cut-off values
"""
outer = jnp.where(
dr < R + D, 0.5 * (1 - jnp.sin(jnp.pi / 2 * (dr - R) / D)), 0
)
inner = jnp.where(dr < R - D, 1, outer)
return inner
def _ters_bij(R, D, c, d, h, lam3, beta, n, m, dRij, dRik, mask_ijk) -> Array:
  """The bond-order term of the Tersoff potential.

  The angular penalty evaluated here is
  `g = 1 + c^2 / d^2 - c^2 / (d^2 + (h - cos(theta))^2)`.

  Args:
    R: A Parameter that is the average of inner and outer cutoff radii
    D: A Parameter that is the half of the difference
      between inner and outer cutoff radii
    c: A Parameter that determines angle penalty
    d: A Parameter that determines angle penalty
    h: A cosine value that is a desirable angle between 3 atoms.
    lam3: A Parameter that determines distance penalty value
    beta: A Parameter that determines bond-order value
    n: A Parameter that determines bond-order value
    m: A Parameter that determines distance penalty value
    dRij: A ndarray of shape [n, neighbors, dim] of pairwise displacements
      between particles
    dRik: A ndarray of shape [n, neighbors, dim] of pairwise displacements
      between particles. TODO - Currently, it is the same as the dRij
    mask_ijk: Mask over (i, k, j) triplets; the entries with equal last two
      indices are additionally zeroed here.

  Returns:
    Bond-order values between i and j atoms
  """
  r_ij = space.distance(dRij)
  r_ik = space.distance(dRik)
  # Exclude the diagonal of the two neighbor axes from every triplet sum.
  off_diagonal = 1 - jnp.eye(mask_ijk.shape[-1], dtype=dRij.dtype)
  mask_ijk = mask_ijk * off_diagonal[None, :, :]

  # Angle penalty g_ijk.
  cos_theta = quantity.cosine_angles(dRij)
  g_ijk = 1.0 + (c**2 / d**2) - (c**2 / (d**2 + (h - cos_theta) ** 2))

  # Distance penalty; masked differences are zeroed before exponentiation.
  delta_r = jnp.where(mask_ijk, r_ij[:, None, :] - r_ik[:, :, None], 0)
  distance_penalty = jnp.exp(lam3**m * delta_r**m)

  # Cut-off evaluated on the i-k distances.
  cutoff_ik = _ters_cutoff(r_ik, R, D)

  # zeta_ij: contraction over k of the masked penalty product.
  summand = jnp.where(mask_ijk, g_ijk * distance_penalty, 0)
  zeta_ij = jnp.einsum('ik,ikj->ij', cutoff_ik, summand)
  return (1 + (beta * zeta_ij) ** n) ** (-1 / 2 / n)
def _ters_attractive(
  B: f64,
  lam2: f64,
  R: f64,
  D: f64,
  c: f64,
  d: f64,
  h: f64,
  lam3: f64,
  beta: f64,
  n: f64,
  m: f64,
  dR12: Array,
  dR13: Array,
  mask_ijk,
) -> Array:
  """The attractive term of the Tersoff potential.

  Computes `0.5 * fC * b_ij * fA` with `fA = -B * exp(-lam2 * r)`, where `fC`
  is the Tersoff cut-off and `b_ij` the bond-order term.

  Args:
    B: Prefactor of the attractive exponential.
    lam2: Decay constant of the attractive exponential.
    R: A Parameter that is the average of inner and outer cutoff radii.
    D: A Parameter that is the half of the difference
      between inner and outer cutoff radii.
    c: A Parameter that determines angle penalty.
    d: A Parameter that determines angle penalty.
    h: A cosine value that is a desirable angle between 3 atoms.
    lam3: A Parameter that determines distance penalty value.
    beta: A Parameter that determines bond-order value.
    n: A Parameter that determines bond-order value.
    m: A Parameter that determines distance penalty value.
    dR12: A ndarray of shape [n, neighbors, dim] of pairwise displacements
      between particles.
    dR13: A ndarray of shape [n, neighbors, dim] of pairwise displacements
      between particles. TODO - Currently, it is the same as the dR12
    mask_ijk: Triplet mask forwarded to the bond-order term.

  Returns:
    Attractive interaction energy for one pair of neighbors.
  """
  r12 = space.distance(dR12)
  smooth_cutoff = _ters_cutoff(r12, R, D)
  attraction = -B * jnp.exp(-lam2 * r12)
  bond_order = _ters_bij(
    R, D, c, d, h, lam3, beta, n, m, dR12, dR13, mask_ijk
  )
  return 0.5 * smooth_cutoff * bond_order * attraction
def _ters_repulsive(A: f64, lam1: f64, R: f64, D: f64, dr: Array) -> Array:
  """The repulsive term of the Tersoff potential.

  Computes `0.5 * fC(dr) * A * exp(-lam1 * dr)` where `fC` is the Tersoff
  cut-off function.

  Args:
    A: A scalar that determines repulsive energy (eV).
    lam1: A scalar that determines the scale two-body distance (Angstrom).
    R: A scalar that is the average of inner and outer cutoff radii.
    D: A scalar that is the half of the difference
      between inner and outer cutoff radii.
    dr: Pair distances.

  Returns:
    Repulsive interaction energy for one pair of neighbors.
  """
  smooth_cutoff = _ters_cutoff(dr, R, D)
  repulsion = A * jnp.exp(-lam1 * dr)
  return 0.5 * smooth_cutoff * repulsion
[docs]
def tersoff(
  displacement: DisplacementFn,
  params: list[Dict[str, Any]],
  species: Array | None = None,
) -> Callable[[Array], Array]:
  """Computes the Tersoff potential.

  The Tersoff potential [1] is commonly used to model semiconducting
  materials. It was originally proposed to model various types of lattice with
  a simple functional form, and introduces a bond-order function that
  determines the strength of repulsive and attractive forces between atoms.

  Args:
    displacement: The displacement function for the space.
    params: A list of parameter dictionaries for the tersoff potential; only
      the first entry is used. Usually this should be loaded from lammps using
      the :ref:`load_lammps_tersoff_parameters <ts-lammps>` function.
    species: An array of species. Currently only `None` is supported.

  Returns:
    A function that computes the total energy.

  Raises:
    NotImplementedError: If `species` is not `None`.

  [1] J. Tersoff "New empirical approach for the structure and energy of
  covalent systems" Physical review B 37.12 (1988): 6991.
  """
  if species is not None:
    raise NotImplementedError(
      'Multiple species is not implemented yet. '
      'Please raise an issue if this is important for '
      'you.'
    )
  # Single-species case: the first parameter set describes every interaction.
  p = params[0]

  repulsive_keys = ('A', 'lam1', 'R', 'D')
  attractive_keys = (
    'B', 'lam2', 'R', 'D', 'cTf', 'dTf', 'hTf', 'lam3', 'beta', 'nTf', 'mTf'
  )
  repulsive_fn = partial(_ters_repulsive, *(p[k] for k in repulsive_keys))
  attractive_fn = partial(_ters_attractive, *(p[k] for k in attractive_keys))

  def compute_fn(R, **kwargs):
    pairwise_displacement = space.map_product(partial(displacement, **kwargs))
    dR = pairwise_displacement(R, R)
    dr = space.distance(dR)
    # Pairs within the outer cutoff radius, excluding the diagonal (i == j).
    off_diagonal = 1 - jnp.eye(R.shape[0])
    pair_mask = jnp.where(off_diagonal, dr < p['R'] + p['D'], 0)
    pair_mask = pair_mask.astype(R.dtype)
    triplet_mask = pair_mask[:, None, :] * pair_mask[:, :, None]
    repulsive = util.safe_mask(pair_mask, repulsive_fn, dr)
    attractive = attractive_fn(dR, dR, triplet_mask) * pair_mask
    repulsive_total = util.high_precision_sum(repulsive)
    attractive_total = util.high_precision_sum(attractive)
    return repulsive_total + attractive_total

  return compute_fn
def tersoff_from_lammps_parameters(
  displacement: DisplacementFn,
  f: TextIO,
) -> Callable[[Array], Array]:
  """Builds the Tersoff energy function from a LAMMPS parameter file.

  Args:
    displacement: The displacement function for the space.
    f: File handle of the LAMMPS Tersoff parameter file.

  Returns:
    The energy function returned by `tersoff` for the parsed parameters.
  """
  tersoff_params = load_lammps_tersoff_parameters(f)
  return tersoff(displacement, tersoff_params)
[docs]
def tersoff_neighbor_list(
  displacement: DisplacementFn,
  box_size: float,
  params: list[Dict[str, Any]],
  species: Array | None = None,
  dr_threshold: float = 0.5,
  disable_cell_list: bool = False,
  fractional_coordinates: bool = True,
  format: NeighborListFormat = partition.Dense,
  neighbor_list_fn: Callable = partition.neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Computes the Tersoff potential using a neighbor list.

  The Tersoff potential [1] is commonly used to model semiconducting
  materials. It was originally proposed to model various types of lattice with
  a simple functional form, and introduces a bond-order function that
  determines the strength of repulsive and attractive forces between atoms.

  Args:
    displacement: The displacement function for the space.
    box_size: A float or vector specifying the size of the simulation box.
    params: A list of parameter dictionaries for the tersoff potential; only
      the first entry is used. Usually this should be loaded from lammps using
      the :ref:`load_lammps_tersoff_parameters <ts-lammps>` function.
    species: An array of species. Currently only `None` is supported.
    dr_threshold: A distance threshold that controls how often the neighbor
      list is recomputed.
    disable_cell_list: Forwarded to `neighbor_list_fn`.
    fractional_coordinates: A boolean specifying whether coordinates are stored
      in the unit cube.
    format: Format of the neighbor list. Only `partition.Dense` is supported.
    neighbor_list_fn: Factory that builds the neighbor list functions.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A pair of functions. One that builds the neighbor list and one that
    computes the total energy.

  Raises:
    NotImplementedError: If `species` is not `None` or `format` is not
      `partition.Dense`.

  [1] J. Tersoff "New empirical approach for the structure and energy of
  covalent systems" Physical review B 37.12 (1988): 6991.
  """
  if species is not None:
    raise NotImplementedError('Multiple species were not implemented yet.')
  # Single-species case: the first parameter set describes every interaction.
  p = params[0]

  # define a repulsive and an attractive function with given parameters
  repulsive_fn = partial(_ters_repulsive, p['A'], p['lam1'], p['R'], p['D'])
  attractive_fn = partial(
    _ters_attractive,
    p['B'],
    p['lam2'],
    p['R'],
    p['D'],
    p['cTf'],
    p['dTf'],
    p['hTf'],
    p['lam3'],
    p['beta'],
    p['nTf'],
    p['mTf'],
  )

  # TODO: other neighbor list construction method will be implemented.
  if format is not partition.Dense:
    raise NotImplementedError(
      'Tersoff potential only implemented with Dense neighbor lists.'
    )
  # The neighbor-list cutoff is the outer radius R + D of the cut-off function.
  neighbor_fn = neighbor_list_fn(
    displacement,
    box_size,
    p['R'] + p['D'],
    dr_threshold,
    disable_cell_list=disable_cell_list,
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )

  def compute_fn(R, neighbor, **kwargs):
    d = partial(displacement, **kwargs)
    mask = partition.neighbor_list_mask(neighbor, mask_self=True)
    # A triplet contributes only when both of its neighbor slots are valid.
    mask_ijk = mask[:, None, :] * mask[:, :, None]
    dR = space.map_neighbor(d)(R, R[neighbor.idx])
    dr = space.distance(dR)
    first_term = util.high_precision_sum(repulsive_fn(dr) * mask)
    second_term = util.high_precision_sum(
      attractive_fn(dR, dR, mask_ijk) * mask
    )
    return first_term + second_term

  return neighbor_fn, compute_fn
[docs]
def tersoff_from_lammps_parameters_neighbor_list(
  displacement: DisplacementFn,
  box_size: float,
  f: TextIO,
  dr_threshold: float = 0.5,
  fractional_coordinates=True,
  neighbor_list_fn: Callable = partition.neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Builds the neighbor-list Tersoff energy from a LAMMPS parameter file.

  Args:
    displacement: The displacement function for the space.
    box_size: Size of the simulation box.
    f: File handle of the LAMMPS Tersoff parameter file.
    dr_threshold: Forwarded to `tersoff_neighbor_list`.
    fractional_coordinates: Forwarded to `tersoff_neighbor_list`.
    neighbor_list_fn: Forwarded to `tersoff_neighbor_list`.
    **neighbor_kwargs: Forwarded to `tersoff_neighbor_list`.

  Returns:
    The `(neighbor_fn, energy_fn)` pair returned by `tersoff_neighbor_list`.
  """
  tersoff_params = load_lammps_tersoff_parameters(f)
  return tersoff_neighbor_list(
    displacement,
    box_size,
    tersoff_params,
    dr_threshold=dr_threshold,
    fractional_coordinates=fractional_coordinates,
    neighbor_list_fn=neighbor_list_fn,
    **neighbor_kwargs,
  )
# (EDIP) Environment-dependent interatomic potential
def _edip_cutoff_function(r: Array, cutoff: f64, c: f64, alpha: f64) -> Array:
x_term = (r - c) / (cutoff - c)
expo_term = jnp.exp(alpha / (1 - (x_term ** (-3))))
outer = jnp.where(r > cutoff, 0, expo_term)
inner = jnp.where(r > c, outer, 1)
return inner
def _edip_radial_interaction(
  A: f64,
  B: f64,
  rho: f64,
  sigma: f64,
  c: f64,
  alpha: f64,
  beta: f64,
  cutoff: f64,
  mask,
  r: Array,
) -> Array:
  """Two-body term of the EDIP potential.

  For distances with `0 < r < cutoff` this evaluates
  `A * ((B / r)^rho - exp(-beta * Z^2)) * exp(sigma / (r - cutoff))` and
  returns 0 elsewhere. `Z` is the sum along axis 1 of the masked cutoff
  function.

  Args:
    A: Energy scale of the two-body term.
    B: Length parameter of the repulsive part.
    rho: Exponent of the repulsive part.
    sigma: Length parameter of the radial decay.
    c: Inner cutoff of the cutoff function.
    alpha: Parameter of the cutoff function.
    beta: Parameter of the bond-order function exp(-beta * Z^2).
    cutoff: Outer cutoff radius.
    mask: Mask multiplied onto the cutoff function before summing.
    r: Array of pair distances (summed over axis 1, so at least 2-d).

  Returns:
    Array of two-body energies with the same shape as `r`.
  """
  in_range = (r > 0) & (r < cutoff)
  # The repulsive part is computed from the unmasked distances; out-of-range
  # entries are discarded by the final `where`.
  repulsion = (B / r) ** (rho)
  r = jnp.where(in_range, r, 0)
  coordination = util.high_precision_sum(
    _edip_cutoff_function(r, cutoff, c, alpha) * mask, axis=1, keepdims=True
  )
  bond_order = jnp.exp(-beta * (coordination**2))
  radial_decay = jnp.exp(sigma / (r - cutoff))
  energy = A * (repulsion - bond_order) * radial_decay
  return jnp.where(in_range, energy, 0.0)
def _edip_angle_interaction(
  lam: f64,
  gamma: f64,
  Q_0: f64,
  cutoff: f64,
  u1: f64,
  u2: f64,
  u3: f64,
  u4: f64,
  c: f64,
  eta: f64,
  alpha: f64,
  mu: f64,
  mask,
  dR12: Array,
  dR13: Array,
) -> Array:
  """Three-body term of the EDIP potential for one pair of displacements.

  Evaluates `lam * g(r12) * g(r13) * h(l, Z)` with
  `g(r) = exp(gamma / (r - cutoff))` and
  `h = 1 - exp(-Q(Z) * (l + tau(Z))^2) + eta * Q(Z) * (l + tau(Z))^2`, where
  `l` is the cosine of the angle between `dR12` and `dR13`. The result is 0
  when either distance is 0 or not below `cutoff`, or when the two
  displacements coincide (norm of their difference at most 1e-5).

  Args:
    lam: Energy scale of the three-body term.
    gamma: Length parameter of the radial part.
    Q_0: Prefactor of Q(Z) = Q_0 * exp(-mu * Z).
    cutoff: Outer cutoff radius.
    u1: Parameter of tau(Z).
    u2: Parameter of tau(Z).
    u3: Parameter of tau(Z).
    u4: Parameter of tau(Z).
    c: Inner cutoff of the cutoff function.
    eta: Weight of the term linear in Q(Z) * (l + tau(Z))^2.
    alpha: Parameter of the cutoff function.
    mu: Decay parameter of Q(Z).
    mask: Mask multiplied onto the cutoff function before summing into Z.
    dR12: Displacement to the first neighbor.
    dR13: Displacement to the second neighbor.

  Returns:
    The three-body energy contribution.
  """
  r12 = space.distance(dR12)
  r13 = space.distance(dR13)
  # Distances at or beyond the cutoff are replaced by 0 and rejected below.
  r12 = jnp.where(r12 < cutoff, r12, 0)
  r13 = jnp.where(r13 < cutoff, r13, 0)
  radial = jnp.exp(gamma / (r12 - cutoff) + gamma / (r13 - cutoff))
  cos_angle = quantity.cosine_angle_between_two_vectors(dR12, dR13)

  # Z as computed here: full sum of the masked cutoff function of r13.
  Z = util.high_precision_sum(
    _edip_cutoff_function(r13, cutoff, c, alpha) * mask, keepdims=False
  )
  tau = u1 + u2 * ((u3 * jnp.exp(-u4 * Z)) - jnp.exp(-2 * u4 * Z))
  Q = Q_0 * jnp.exp(-mu * Z)
  shifted_sq = (cos_angle + tau) ** 2
  angular = (1 - jnp.exp(-Q * shifted_sq)) + (eta * Q * shifted_sq)

  valid = (r12 > 0) & (r13 > 0) & (jnp.linalg.norm(dR12 - dR13) > 1e-5)
  return jnp.where(valid, lam * radial * angular, 0)
def edip(
  displacement: DisplacementFn,
  u1: f64 = -0.165799,
  u2: f64 = 32.557,
  u3: f64 = 0.286198,
  u4: f64 = 0.66,
  rho: f64 = 1.2085196,
  eta: f64 = 0.2523244,
  Q_0: f64 = 312.1341346,
  mu: f64 = 0.6966326,
  beta: f64 = 0.0070975,
  alpha: f64 = 3.1083847,
  A: f64 = 7.9821730,
  lam: f64 = 1.4533108,
  B: f64 = 1.5075463,
  gamma: f64 = 1.1247945,
  sigma: f64 = 0.5774108,
  c: f64 = 2.5609104,
  cutoff: f64 = 3.1213820,
) -> Callable[[Array], Array]:
  """Computes the Environment-dependent interatomic potential (EDIP).

  The default parameter values are for bulk Silicon [1,2]. The EDIP potential
  is a bond order potential which depends on the local coordination number of
  the atom.

  Args:
    displacement: displacement function for the space.
    u1: parameter for the three-body bond order function tau(Z) (pure number)
    u2: parameter for the three-body bond order function tau(Z) (pure number)
    u3: parameter for the three-body bond order function tau(Z) (pure number)
    u4: parameter for the three-body bond order function tau(Z) (pure number)
    rho: exponent for the repulsive part of two-body potential (pure number)
    eta: parameter for the three-body term (pure number)
    Q_0: parameter for the three-body bond order function Q(Z) (pure number)
    mu: parameter for the three-body bond order function Q(Z) (pure number)
    beta: parameter for the two-body bond order function p(Z) (pure number)
    alpha: parameter for the cutoff function (pure number)
    A: parameter that determines the energy scale of two-body term (eV)
    lam: parameter that determines the energy scale of three-body term (eV)
    B: parameter for the repulsive part of two-body potential (Angstrom)
    gamma: parameter for the radial part of three-body term (Angstrom)
    sigma: parameter that determines the distance scale between neighbors
      (Angstrom)
    c: inner cutoff for the cutoff function f(r) (Angstrom)
    cutoff: outer cutoff (a) for the cutoff function f(r) (Angstrom)

  Returns:
    A function that computes the potential energy.

  References:
    [1] - Martin Z. Bazant, Efthimios Kaxiras, and J. F. Justo.
    "Environment-dependent interatomic potential for bulk silicon".
    Phys. Rev. B 56, 8542 (1997).

    [2] - João F. Justo, Martin Z. Bazant, Efthimios Kaxiras, V. V. Bulatov,
    and Sidney Yip. "Interatomic potential for silicon defects and
    disordered phases". Phys. Rev. B 58, 2539 (1998).
  """
  two_body_fn = partial(
    _edip_radial_interaction, A, B, rho, sigma, c, alpha, beta, cutoff
  )
  three_body_fn = partial(
    _edip_angle_interaction,
    lam, gamma, Q_0, cutoff, u1, u2, u3, u4, c, eta, alpha, mu,
  )

  def compute_fn(R, **kwargs):
    # Innermost map runs over dR12, the middle one over dR13, and the
    # outermost over the leading axis of all three arguments.
    over_first_neighbor = vmap(three_body_fn, (None, 0, None))
    over_second_neighbor = vmap(over_first_neighbor, (None, None, 0))
    triplet_fn = vmap(over_second_neighbor)

    pairwise_displacement = space.map_product(partial(displacement, **kwargs))
    dR = pairwise_displacement(R, R)
    dr = space.distance(dR)
    # Pairs below the cutoff, excluding the diagonal (i == j).
    off_diagonal = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
    mask = off_diagonal * (dr < cutoff)
    two_body_energy = util.high_precision_sum(two_body_fn(mask, dr))
    three_body_energy = (
      util.high_precision_sum(triplet_fn(mask, dR, dR)) / 2.0
    )
    return two_body_energy + three_body_energy

  return compute_fn
def edip_neighbor_list(
  displacement: DisplacementFn,
  box_size: f64,
  u1: f64 = -0.165799,
  u2: f64 = 32.557,
  u3: f64 = 0.286198,
  u4: f64 = 0.66,
  rho: f64 = 1.2085196,
  eta: f64 = 0.2523244,
  Q_0: f64 = 312.1341346,
  mu: f64 = 0.6966326,
  beta: f64 = 0.0070975,
  alpha: f64 = 3.1083847,
  A: f64 = 7.9821730,
  lam: f64 = 1.4533108,
  B: f64 = 1.5075463,
  gamma: f64 = 1.1247945,
  sigma: f64 = 0.5774108,
  c: f64 = 2.5609104,
  cutoff: f64 = 3.1213820,
  dr_threshold: f64 = 0.0,
  fractional_coordinates: bool = True,
  format: NeighborListFormat = partition.Dense,
  neighbor_list_fn: Callable = partition.neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Computes the EDIP potential using a neighbor list.

  The potential parameters `u1` through `cutoff` have the same meaning and
  defaults as in `edip`.

  Args:
    displacement: displacement function for the space.
    box_size: Size of the simulation box; forwarded to `neighbor_list_fn`.
    dr_threshold: Forwarded to `neighbor_list_fn`.
    fractional_coordinates: Forwarded to `neighbor_list_fn`.
    format: Neighbor list format. The energy function only supports
      `partition.Dense`.
    neighbor_list_fn: Factory that builds the neighbor list functions.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A pair `(neighbor_fn, energy_fn)` where `energy_fn(R, neighbor, **kwargs)`
    returns the total energy.

  Raises:
    NotImplementedError: When `energy_fn` is called and `format` is not
      `partition.Dense`.
  """
  two_body_fn = partial(
    _edip_radial_interaction, A, B, rho, sigma, c, alpha, beta, cutoff
  )
  three_body_fn = partial(
    _edip_angle_interaction,
    lam, gamma, Q_0, cutoff, u1, u2, u3, u4, c, eta, alpha, mu,
  )
  neighbor_fn = neighbor_list_fn(
    displacement,
    box_size,
    cutoff,
    dr_threshold,
    format=format,
    fractional_coordinates=fractional_coordinates,
    **neighbor_kwargs,
  )

  def compute_fn(R, neighbor, **kwargs):
    d = partial(displacement, **kwargs)
    mask = partition.neighbor_list_mask(neighbor, mask_self=True)
    if format is not partition.Dense:
      raise NotImplementedError(
        'EDIP potential only implemented with Dense neighbor lists.'
      )
    # Innermost map runs over dR12, the middle one over dR13, and the
    # outermost over the leading axis of all three arguments.
    over_first_neighbor = vmap(three_body_fn, (None, 0, None))
    over_second_neighbor = vmap(over_first_neighbor, (None, None, 0))
    triplet_fn = vmap(over_second_neighbor)

    dR = space.map_neighbor(d)(R, R[neighbor.idx])
    dr = space.distance(dR)
    two_body_energy = util.high_precision_sum(two_body_fn(mask, dr) * mask)
    # A triplet contributes only when both of its neighbor slots are valid.
    mask_ijk = mask[:, None, :] * mask[:, :, None]
    three_body_energy = (
      util.high_precision_sum(triplet_fn(mask, dR, dR) * mask_ijk) / 2.0
    )
    return two_body_energy + three_body_energy

  return neighbor_fn, compute_fn
# Embedded Atom Method
[docs]
def load_lammps_eam_parameters(
  file: TextIO,
) -> Tuple[
  Callable[[Array], Array],
  Callable[[Array], Array],
  Callable[[Array], Array],
  float,
]:
  """Reads EAM parameters from a LAMMPS file and returns relevant spline fits.

  This function reads single-element EAM potential fit parameters from a file
  in LAMMPS setfl format. In summary, the file contains:

  * Line 1-3: Comments (the first line must contain the text `setfl`)
  * Line 4: Number of elements and the element type
  * Line 5: The number of charge values that the embedding energy is evaluated
    on (`num_drho`), interval between the charge values (`drho`), the number of
    distances the pairwise energy and the charge density is evaluated on
    (`num_dr`), the interval between these distances (`dr`), and the cutoff
    distance (`cutoff`).
  * Line 6: Skipped by this parser.

  The lines that come after are the embedding function evaluated on `num_drho`
  charge values, charge function evaluated at `num_dr` distance values, and
  pairwise energy evaluated at `num_dr` distance values. Note that the pairwise
  energy is multiplied by distance (in units of eV x Angstroms). The values may
  be laid out with any number of whitespace-separated columns per line; blank
  lines are ignored.

  For more details of the DYNAMO file format, see:
  https://sites.google.com/a/ncsu.edu/cjobrien/tutorials-and-guides/eam

  Args:
    file: File handle for the EAM parameters text file.

  Returns:
    A tuple containing three functions and a cutoff distance.

    charge_fn:
      A function that takes an ndarray of shape `[n, m]` of distances
      between particles and returns a matrix of charge contributions.
    embedding_fn:
      Function that takes an ndarray of shape `[n]` of charges and
      returns an ndarray of shape `[n]` of the energy cost of embedding an atom
      into the charge.
    pairwise_fn:
      A function that takes an ndarray of shape `[n, m]` of distances
      and returns an ndarray of shape `[n, m]` of pairwise energies.
    cutoff:
      Cutoff distance for the `embedding_fn` and `pairwise_fn`.

  Raises:
    ValueError: If the first line does not mention `setfl`, or if the file
      holds fewer tabulated values than its header announces.
  """
  raw_text = file.read().split('\n')
  if 'setfl' not in raw_text[0]:
    raise ValueError('File format is incorrect, expected LAMMPS setfl format.')

  header = raw_text[4].split()
  num_drho, num_dr = int(header[0]), int(header[2])
  drho, dr, cutoff = float(header[1]), float(header[3]), float(header[4])

  # Flatten every whitespace-separated token after the header. `str.split()`
  # yields nothing for blank lines (including the empty string that follows a
  # trailing newline) and copes with tabs and any column count.
  values = [float(token) for row in raw_text[6:] for token in row.split()]
  expected = num_drho + 2 * num_dr
  if len(values) < expected:
    raise ValueError(
      f'EAM file is truncated: expected at least {expected} tabulated '
      f'values, found {len(values)}.'
    )
  data = maybe_downcast(values)

  embedding_fn = interpolate.spline(data[:num_drho], drho)
  charge_fn = interpolate.spline(data[num_drho : num_drho + num_dr], dr)

  # LAMMPS EAM parameters file lists pairwise energies after multiplying by
  # distance, in units of eV*Angstrom. We divide the energy by distance below,
  distances = jnp.arange(num_dr) * dr
  # Prevent dividing by zero at zero distance, which will not
  # affect the calculation
  distances = jnp.where(distances == 0, f32(0.001), distances)
  pairwise_fn = interpolate.spline(
    data[num_dr + num_drho : num_drho + 2 * num_dr] / distances, dr
  )
  return charge_fn, embedding_fn, pairwise_fn, cutoff
[docs]
def eam(
  displacement_or_metric: DisplacementOrMetricFn,
  charge_fn: Callable[[Array], Array],
  embedding_fn: Callable[[Array], Array],
  pairwise_fn: Callable[[Array], Array],
  axis: Tuple[int, ...] | None = None,
) -> Callable[[Array], Array]:
  """.. _eam-pot:

  Interatomic potential as approximated by embedded atom model (EAM).

  This code implements the EAM approximation to interactions between metallic
  atoms. In EAM, the potential energy of an atom is given by two terms: a
  pairwise energy and an embedding energy due to the interaction between the
  atom and background charge density. The EAM potential for a single atomic
  species is often determined by three functions:

  1) Charge density contribution of an atom as a function of distance.
  2) Energy of embedding an atom in the background charge density.
  3) Pairwise energy.

  These three functions are usually provided as spline fits, and we follow the
  implementation and spline fits given by Mishin et al. [#mishin]_
  Note that in current implementation, the three functions listed above
  can also be expressed by any function with the correct signature,
  including neural networks.

  Args:
    displacement_or_metric: A displacement or metric function for the space;
      it is canonicalized to a metric before use.
    charge_fn: A function that takes an ndarray of shape `[n, m]` of distances
      between particles and returns a matrix of charge contributions.
    embedding_fn: Function that takes an ndarray of shape `[n]` of charges and
      returns an ndarray of shape `[n]` of the energy cost of embedding an atom
      into the charge.
    pairwise_fn: A function that takes an ndarray of shape `[n, m]` of distances
      and returns an ndarray of shape `[n, m]` of pairwise energies.
    axis: Specifies which axis the total energy should be summed over.

  Returns:
    A function that computes the EAM energy of a set of atoms with positions
    given by an `[n, spatial_dimension]` ndarray.

  .. rubric:: References
  .. [#mishin] Y. Mishin, D. Farkas, M.J. Mehl, DA Papaconstantopoulos, "Interatomic
    potentials for monoatomic metals from experimental data and ab initio
    calculations." Physical Review B, 59 (1999)
  """
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

  def energy(R, **kwargs):
    all_pairs_distance = space.map_product(partial(metric, **kwargs))
    dr = all_pairs_distance(R, R)
    # Charge is summed over the full row of the distance matrix.
    charge_per_atom = util.high_precision_sum(charge_fn(dr), axis=1)
    embedding_energy = embedding_fn(charge_per_atom)
    # The pair term has its diagonal masked out; every pair is visited twice,
    # hence the division by two.
    off_diagonal_pairs = smap._diagonal_mask(pairwise_fn(dr))
    pairwise_energy = util.high_precision_sum(
      off_diagonal_pairs, axis=1
    ) / f32(2.0)
    per_atom_energy = embedding_energy + pairwise_energy
    return util.high_precision_sum(per_atom_energy, axis=axis)

  return energy
[docs]
def eam_from_lammps_parameters(
  displacement: DisplacementFn, f: TextIO
) -> Callable[[Array], Array]:
  """Builds the :ref:`EAM energy <eam-pot>` function from LAMMPS parameters.

  Args:
    displacement: The displacement function for the space.
    f: File handle for the EAM parameters text file.

  Returns:
    The energy function returned by `eam`; the cutoff read from the file is
    not used.
  """
  charge_fn, embedding_fn, pairwise_fn = load_lammps_eam_parameters(f)[:-1]
  return eam(displacement, charge_fn, embedding_fn, pairwise_fn)
[docs]
def eam_neighbor_list(
  displacement_or_metric: DisplacementOrMetricFn,
  box_size: float,
  charge_fn: Callable[[Array], Array],
  embedding_fn: Callable[[Array], Array],
  pairwise_fn: Callable[[Array], Array],
  cutoff: float,
  dr_threshold: float = 0.5,
  axis: Tuple[int, ...] | None = None,
  fractional_coordinates: bool = True,
  format: partition.NeighborListFormat = partition.Sparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`EAM <eam-pot>` using a neighbor list.

  Args:
    displacement_or_metric: A displacement or metric function for the space.
    box_size: The size of the simulation box.
    charge_fn: Maps distances to charge contributions.
    embedding_fn: Maps per-atom charges to embedding energies.
    pairwise_fn: Maps distances to pairwise energies.
    cutoff: A float specifying the maximum interaction distance.
    dr_threshold: A float specifying the halo in the neighbor list.
    axis: Specifies which axis the total energy should be summed over.
    fractional_coordinates: Currently unused by this function.
    format: The format of the neighbor list (`Dense` or `Sparse`).
    neighbor_list_fn: Factory that builds the neighbor list functions.
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A tuple containing a function to build the neighbor list and a function
    that computes the EAM energy.

  Raises:
    NotImplementedError: When the energy function receives a neighbor list
      that is neither `Dense` nor `Sparse`.
  """
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  # NOTE(review): `fractional_coordinates` is not forwarded to
  # `neighbor_list_fn` here, unlike the other builders in this module --
  # confirm whether that is intended before changing it, since its default
  # is True.
  neighbor_fn = neighbor_list_fn(
    displacement_or_metric,
    box_size,
    cutoff,
    dr_threshold,
    mask_self=False,
    format=format,
    **neighbor_kwargs,
  )

  def energy_fn(R, neighbor, **kwargs):
    # The charge sum uses the plain neighbor mask, the pair sum the mask that
    # additionally removes self entries.
    charge_mask = partition.neighbor_list_mask(neighbor)
    pair_mask = partition.neighbor_list_mask(neighbor, mask_self=True)
    d = partial(metric, **kwargs)
    nbr_format = neighbor.format

    if nbr_format is partition.Dense:
      dr = space.map_neighbor(d)(R, R[neighbor.idx])
      charge_per_atom = util.high_precision_sum(
        charge_fn(dr) * charge_mask, axis=1
      )
      embedding_energy = embedding_fn(charge_per_atom)
      pairwise_energy = util.high_precision_sum(
        pairwise_fn(dr) * pair_mask, axis=1
      )
    elif nbr_format is partition.Sparse:
      particle_count = len(R)
      first_idx = neighbor.idx[0]
      dr = space.map_bond(d)(R[first_idx], R[neighbor.idx[1]])
      charge_per_atom = ops.segment_sum(
        charge_fn(dr) * charge_mask, first_idx, particle_count
      )
      embedding_energy = embedding_fn(charge_per_atom)
      pairwise_energy = ops.segment_sum(
        pairwise_fn(dr) * pair_mask, first_idx, particle_count
      )
    else:
      raise NotImplementedError(
        'EAM potential not implemented for OrderedSparse neighbor lists.'
      )

    per_atom_energy = embedding_energy + pairwise_energy / 2.0
    return util.high_precision_sum(per_atom_energy, axis=axis)

  return neighbor_fn, energy_fn
[docs]
def eam_from_lammps_parameters_neighbor_list(
  displacement: DisplacementFn,
  box_size: float,
  f: TextIO,
  axis=None,
  dr_threshold: float = 0.5,
  fractional_coordinates=True,
  neighbor_list_fn: Callable = partition.neighbor_list,
  **neighbor_kwargs,
) -> Tuple[NeighborListFns, Callable[..., Array]]:
  """Convenience wrapper to compute :ref:`EAM energy <eam-pot>`
  with parameters from LAMMPS using a neighbor list.

  Args:
    displacement: The displacement function for the space.
    box_size: The size of the simulation box.
    f: File handle for the EAM parameters text file.
    axis: Forwarded to `eam_neighbor_list`.
    dr_threshold: Forwarded to `eam_neighbor_list`.
    fractional_coordinates: Forwarded to `eam_neighbor_list`.
    neighbor_list_fn: Forwarded to `eam_neighbor_list`.
    **neighbor_kwargs: Forwarded to `eam_neighbor_list`.

  Returns:
    The `(neighbor_fn, energy_fn)` pair returned by `eam_neighbor_list`.
  """
  # The loader returns (charge_fn, embedding_fn, pairwise_fn, cutoff), which
  # lines up with the positional parameters following `box_size`.
  charge_fn, embedding_fn, pairwise_fn, cutoff = load_lammps_eam_parameters(f)
  # `axis` and `fractional_coordinates` used to be accepted and then dropped;
  # they are now passed on.
  return eam_neighbor_list(
    displacement,
    box_size,
    charge_fn,
    embedding_fn,
    pairwise_fn,
    cutoff,
    dr_threshold=dr_threshold,
    axis=axis,
    fractional_coordinates=fractional_coordinates,
    neighbor_list_fn=neighbor_list_fn,
    **neighbor_kwargs,
  )
class BehlerParrinelloEnergy(nnx.Module):
  """Behler-Parrinello symmetry-function energy model.

  Symmetry functions are evaluated for every particle, fed through a shared
  MLP with a single output, and the outputs are either returned as they are
  or summed.

  Callable as ``model(R, neighbor=None, **kwargs) -> scalar``.
  """

  def __init__(
    self,
    sym_fn,
    in_features,
    mlp_sizes,
    *,
    rngs,
    per_particle=False,
    activation=jnp.tanh,
  ):
    """Stores the symmetry function and builds the readout MLP.

    Args:
      sym_fn: Callable producing the symmetry-function features; invoked as
        `sym_fn(R, **kwargs)` or `sym_fn(R, neighbor, **kwargs)`.
      in_features: Input width of the MLP.
      mlp_sizes: Tuple of hidden layer widths; a final layer of width 1 is
        appended.
      rngs: Random number generators used to initialize the MLP.
      per_particle: If True, `__call__` returns the un-summed MLP output.
      activation: Activation function of the MLP.
    """
    self.sym_fn = sym_fn
    self.per_particle = per_particle
    layer_sizes = mlp_sizes + (1,)
    self.mlp = nn.MLP(
      in_features,
      layer_sizes,
      rngs=rngs,
      activate_final=False,
      activation=activation,
    )

  def __call__(self, R, neighbor=None, **kwargs):
    """Evaluates the model at positions `R`, optionally with a neighbor list."""
    # The neighbor list is only passed to the symmetry function when given.
    sym_args = (R,) if neighbor is None else (R, neighbor)
    features = self.sym_fn(*sym_args, **kwargs)
    per_particle_energy = vmap(self.mlp)(features)
    if not self.per_particle:
      return jnp.sum(per_particle_energy)
    return per_particle_energy
def _probe_sym_dim(sym_fn, species, spatial_dimension=3):
n = len(species) if species is not None else 3
dummy = jnp.zeros((n, spatial_dimension))
return sym_fn(dummy).shape[-1]
[docs]
def behler_parrinello(
  displacement: DisplacementFn,
  *,
  key: jax.Array,
  species: Array | None = None,
  mlp_sizes: Tuple[int, ...] = (30, 30),
  sym_kwargs: Dict[str, Any] | None = None,
  per_particle: bool = False,
  spatial_dimension: int = 3,
  activation: Callable = jnp.tanh,
):
  """Build a Behler-Parrinello symmetry-function energy model.

  Args:
    displacement: Function to compute displacement between two positions.
    key: A JAX random key used to initialize model parameters.
    species: None or an integer array specifying the species of each particle.
    mlp_sizes: Layer widths for the per-particle MLP.
    sym_kwargs: Additional kwargs for symmetry function construction.
    per_particle: If True, return per-particle energies instead of the total.
    spatial_dimension: The spatial dimension of the system (default 3).
    activation: Activation function for the MLP (default ``jnp.tanh``).

  Returns:
    An ``energy_fn(R, **kwargs) -> scalar``. The underlying NNX module is
    accessible via ``energy_fn.model`` for training.
  """
  sym_kwargs = {} if sym_kwargs is None else sym_kwargs
  sym_fn = bp.symmetry_functions(displacement, species, **sym_kwargs)
  # The MLP input width is read off a probe evaluation of the symmetry
  # functions.
  feature_dim = _probe_sym_dim(sym_fn, species, spatial_dimension)
  model = BehlerParrinelloEnergy(
    sym_fn,
    feature_dim,
    mlp_sizes,
    rngs=nnx.Rngs(key),
    per_particle=per_particle,
    activation=activation,
  )

  def energy_fn(R, **kwargs):
    return model(R, **kwargs)

  # Expose the module so callers can reach its parameters.
  energy_fn.model = model
  return energy_fn
[docs]
def behler_parrinello_neighbor_list(
  displacement: DisplacementFn,
  box_size: float,
  *,
  key: jax.Array,
  species: Array | None = None,
  mlp_sizes: Tuple[int, ...] = (30, 30),
  sym_kwargs: Dict[str, Any] | None = None,
  dr_threshold: float = 0.5,
  fractional_coordinates: bool = False,
  format: partition.NeighborListFormat = partition.Sparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  activation: Callable = jnp.tanh,
  **neighbor_kwargs,
):
  """Build a Behler-Parrinello energy model with neighbor lists.

  Args:
    displacement: Function to compute displacement between two positions.
    box_size: The size of the simulation volume.
    key: A JAX random key used to initialize model parameters.
    species: None or an integer array specifying the species of each particle.
    mlp_sizes: Layer widths for the per-particle MLP.
    sym_kwargs: Additional kwargs for symmetry function construction. Its
      `cutoff_distance` entry, when present, is also used as the neighbor-list
      cutoff (8.0 otherwise).
    dr_threshold: Halo radius for neighbor list construction.
    fractional_coordinates: Whether coordinates are fractional.
    format: Neighbor list format (``Dense`` or ``Sparse``).
    neighbor_list_fn: Factory that builds the neighbor list functions.
    activation: Activation function for the MLP (default ``jnp.tanh``).
    **neighbor_kwargs: Extra keyword arguments for `neighbor_list_fn`.

  Returns:
    A tuple ``(neighbor_fn, energy_fn)`` where
    ``energy_fn(R, neighbor=nbrs) -> scalar``. The underlying NNX module is
    accessible via ``energy_fn.model`` for training.
  """
  sym_kwargs = {} if sym_kwargs is None else sym_kwargs
  cutoff_distance = sym_kwargs.get('cutoff_distance', 8.0)

  neighbor_fn = neighbor_list_fn(
    displacement,
    box_size,
    cutoff_distance,
    dr_threshold,
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )

  # A vector-valued box fixes the spatial dimension; a scalar box falls back
  # to three dimensions.
  box = jnp.asarray(box_size)
  spatial_dimension = 3 if box.ndim < 1 else box.shape[-1]

  sym_fn = bp.symmetry_functions_neighbor_list(
    displacement, species, **sym_kwargs
  )
  # The feature width is probed with the neighbor-free symmetry functions,
  # which can be called on positions alone.
  probe_fn = bp.symmetry_functions(displacement, species, **sym_kwargs)
  feature_dim = _probe_sym_dim(probe_fn, species, spatial_dimension)

  model = BehlerParrinelloEnergy(
    sym_fn,
    feature_dim,
    mlp_sizes,
    rngs=nnx.Rngs(key),
    activation=activation,
  )

  def energy_fn(R, neighbor=None, **kwargs):
    return model(R, neighbor, **kwargs)

  # Expose the module so callers can reach its parameters.
  energy_fn.model = model
  return neighbor_fn, energy_fn
class EnergyGraphNet(nnx.Module):
  """Graph Neural Network energy model.

  Combines a ``GraphNetEncoder`` with a global decoder to predict a scalar
  energy from positions and (optionally) a neighbor list.

  The model stores the graph-construction config so that ``__call__`` accepts
  ``(R, neighbor=None, **kwargs)`` directly, matching the standard jax_md
  energy function signature.
  """

  def __init__(
    self,
    displacement_fn: space.DisplacementFn,
    r_cutoff: float,
    *,
    key: jax.Array,
    in_node_features: int = 1,
    in_edge_features: int = 3,
    in_global_features: int = 1,
    n_recurrences: int = 2,
    mlp_sizes: Tuple[int, ...] = (64, 64),
    activation: Callable = jax.nn.softplus,
    kernel_init: Callable = nn.DEFAULT_KERNEL_INIT,
    bias_init: Callable = nn.DEFAULT_BIAS_INIT,
    format: partition.NeighborListFormat = partition.Dense,
    dr_threshold: float = 0.0,
    nodes: Array | None = None,
  ):
    """Builds the encoder and the decoder and records the graph settings.

    Args:
      displacement_fn: Displacement function used to build the input graph.
      r_cutoff: Cutoff passed to the graph construction helpers.
      key: A JAX random key used to initialize model parameters.
      in_node_features: Forwarded to the encoder.
      in_edge_features: Forwarded to the encoder.
      in_global_features: Forwarded to the encoder.
      n_recurrences: Forwarded to the encoder.
      mlp_sizes: Layer widths; the decoder input width is the last entry and
        the decoder appends a final layer of width 1.
      activation: Activation shared by the encoder and the decoder.
      kernel_init: Kernel initializer shared by the encoder and the decoder.
      bias_init: Bias initializer shared by the encoder and the decoder.
      format: Neighbor list format used when a neighbor list is supplied.
      dr_threshold: Passed to the neighbor-based graph construction.
      nodes: Optional node state passed to the graph construction helpers.
    """
    rngs = nnx.Rngs(key)

    # Graph-construction configuration, consumed by `_build_graph`.
    self.displacement_fn = displacement_fn
    self.r_cutoff = r_cutoff
    self.dr_threshold = dr_threshold
    self.format = format
    self.nodes = nodes

    shared_mlp_kwargs: Dict[str, Any] = dict(
      activation=activation, kernel_init=kernel_init, bias_init=bias_init
    )
    # The encoder is created before the decoder; both draw from `rngs`.
    self.GraphNetEncoder = nn.GraphNetEncoder(
      in_node_features=in_node_features,
      in_edge_features=in_edge_features,
      in_global_features=in_global_features,
      n_recurrences=n_recurrences,
      mlp_sizes=mlp_sizes,
      rngs=rngs,
      format=format,
      **shared_mlp_kwargs,
    )
    decoder_in_features = mlp_sizes[-1]
    self.GlobalDecoder = nn.MLP(
      decoder_in_features,
      mlp_sizes + (1,),
      rngs=rngs,
      activate_final=False,
      **shared_mlp_kwargs,
    )

  def _build_graph(self, R, neighbor, **kwargs):
    """Builds the input graph, densely when no neighbor list is given."""
    if neighbor is not None:
      return neighbor_graph_input(
        self.displacement_fn,
        self.r_cutoff,
        self.dr_threshold,
        self.format,
        self.nodes,
        R,
        neighbor,
        **kwargs,
      )
    return dense_graph_input(
      self.displacement_fn, self.r_cutoff, self.nodes, R, **kwargs
    )

  def __call__(self, R, neighbor=None, **kwargs):
    """Returns the decoded global feature of the encoded graph."""
    graph = self._build_graph(R, neighbor, **kwargs)
    encoded = self.GraphNetEncoder(graph)
    energy = jnp.squeeze(self.GlobalDecoder(encoded.globals), axis=-1)
    # With the Sparse format only the first entry of the output is returned.
    if self.format is partition.Sparse:
      return energy[0]
    return energy
def canonicalize_node_state(nodes: Array | None) -> Array | None:
if nodes is None:
return nodes
if nodes.ndim == 1:
nodes = nodes[:, jnp.newaxis]
if nodes.ndim != 2:
raise ValueError(
'Nodes must be a [N, node_dim] array. Found {}.'.format(nodes.shape)
)
return nodes
def node_state_or_default(
nodes: Array | None, R: Array, kwargs: Dict[str, Any]
) -> Array:
if 'nodes' in kwargs:
nodes = canonicalize_node_state(kwargs['nodes'])
return jnp.zeros((R.shape[0], 1), R.dtype) if nodes is None else nodes
def dense_graph_input(
  displacement_fn: DisplacementFn,
  r_cutoff: float,
  nodes: Array | None,
  R: Array,
  **kwargs,
) -> nn.GraphsTuple:
  """Builds an all-pairs graph from particle positions.

  Args:
    displacement_fn: Function computing the displacement between two
      positions.
    r_cutoff: Pairs whose squared distance is not below ``r_cutoff**2`` have
      their edge index replaced by ``N``.
    nodes: Default node features, or ``None``.
    R: Particle positions; ``N = R.shape[0]``.
    **kwargs: Bound to ``displacement_fn``; a ``'nodes'`` entry also
      overrides ``nodes``.

  Returns:
    A ``nn.GraphsTuple`` built from the node features, the ``(N, N, ...)``
    displacements as edges, a zero global of shape ``(1,)`` and the
    ``(N, N)`` edge index array.
  """
  n_particles = R.shape[0]
  pairwise_displacement = space.map_product(
    partial(displacement_fn, **kwargs)
  )
  dR = pairwise_displacement(R, R)
  within_cutoff = space.square_distance(dR) < r_cutoff**2
  node_features = node_state_or_default(nodes, R, kwargs)

  # Row i lists every particle index; entries outside the cutoff are set to
  # the out-of-range value N.
  all_indices = jnp.broadcast_to(
    jnp.arange(n_particles)[jnp.newaxis, :], (n_particles, n_particles)
  )
  edge_idx = jnp.where(within_cutoff, all_indices, n_particles)
  return nn.GraphsTuple(
    node_features, dR, jnp.zeros((1,), R.dtype), edge_idx
  )
def neighbor_graph_input(
  displacement_fn: DisplacementFn,
  r_cutoff: float,
  dr_threshold: float,
  format: partition.NeighborListFormat,
  nodes: Array | None,
  R: Array,
  neighbor: NeighborList,
  **kwargs,
):
  """Builds the graph-network input from a neighbor list.

  Args:
    displacement_fn: Function computing the displacement between two
      positions.
    r_cutoff: Cutoff distance used to mask edges.
    dr_threshold: Neighbor list halo radius. In the non-dense branch, edges
      are re-masked against the cutoff only when this is positive.
    format: Neighbor list format. ``partition.Dense`` yields a
      ``nn.GraphsTuple``; any other format goes through
      ``partition.to_jraph``.
    nodes: Default node features, or ``None``.
    R: Particle positions; ``N = R.shape[0]``.
    neighbor: The neighbor list.
    **kwargs: Bound to ``displacement_fn``; a ``'nodes'`` entry also
      overrides ``nodes``.

  Returns:
    The graph to feed to the encoder.
  """
  n_particles = R.shape[0]
  pair_displacement = partial(displacement_fn, **kwargs)
  node_features = node_state_or_default(nodes, R, kwargs)
  graph_globals = jnp.zeros((1,), R.dtype)

  if format is partition.Dense:
    dR = space.map_neighbor(pair_displacement)(R, R[neighbor.idx])
    within_cutoff = space.square_distance(dR) < r_cutoff**2
    # Neighbors outside the cutoff get the out-of-range index N.
    edge_idx = jnp.where(within_cutoff, neighbor.idx, n_particles)
    return nn.GraphsTuple(node_features, dR, graph_globals, edge_idx)

  bond_displacement = space.map_bond(pair_displacement)
  dR = bond_displacement(R[neighbor.idx[0]], R[neighbor.idx[1]])
  if dr_threshold > 0.0:
    within_cutoff = space.square_distance(dR) < r_cutoff**2 + 1e-5
    graph = partition.to_jraph(neighbor, within_cutoff)
    # TODO(schsam): It seems wasteful to recompute dR after we remask the
    # edges. If I can think of a clean way to get rid of this, I should.
    dR = bond_displacement(R[graph.receivers], R[graph.senders])
  else:
    graph = partition.to_jraph(neighbor)

  # One extra all-zero node row is appended and the globals are expanded to
  # two rows before they replace the fields of the jraph graph.
  extra_node = jnp.zeros((1,) + node_features.shape[1:], R.dtype)
  return graph._replace(
    nodes=jnp.concatenate((node_features, extra_node), axis=0),
    edges=dR,
    globals=jnp.broadcast_to(graph_globals[:, None], (2, 1)),
  )
[docs]
def graph_network(
  displacement_fn: DisplacementFn,
  r_cutoff: float,
  *,
  key: jax.Array,
  nodes: Array | None = None,
  spatial_dimension: int = 3,
  n_recurrences: int = 2,
  mlp_sizes: Tuple[int, ...] = (64, 64),
  **model_kwargs,
) -> Callable[..., Array]:
  """Convenience wrapper around the ``EnergyGraphNet`` model.

  Args:
    displacement_fn: Function to compute displacement between two positions.
    r_cutoff: A floating point cutoff; edges will be added to the graph for
      pairs of particles whose separation is smaller than the cutoff.
    key: A JAX random key used to initialize model parameters.
    nodes: None or an ndarray of shape ``[N, node_dim]`` specifying the state
      of the nodes. A 1-D array is reshaped to ``[N, 1]``. If None this is
      set to the zeros vector. Often, for a system with multiple species,
      this could be the species id.
    spatial_dimension: The spatial dimension of the system (default 3); used
      as the number of input edge features.
    n_recurrences: The number of steps of message passing in the graph
      network.
    mlp_sizes: A tuple specifying the layer-widths for the fully-connected
      networks used to update the states in the graph network.
    **model_kwargs: Additional kwargs forwarded to ``EnergyGraphNet``
      (e.g. ``activation``, ``kernel_init``, ``bias_init``).

  Returns:
    An ``energy_fn(R, **kwargs) -> scalar``. The underlying NNX module is
    accessible via ``energy_fn.model`` for training.
  """
  node_state = canonicalize_node_state(nodes)
  model = EnergyGraphNet(
    displacement_fn,
    r_cutoff,
    key=key,
    # Without explicit node features the model runs on one zero feature.
    in_node_features=1 if node_state is None else node_state.shape[-1],
    in_edge_features=spatial_dimension,
    n_recurrences=n_recurrences,
    mlp_sizes=mlp_sizes,
    nodes=node_state,
    **model_kwargs,
  )

  def energy_fn(R, **kwargs):
    return model(R, **kwargs)

  # Expose the module so its parameters can be reached for training.
  setattr(energy_fn, 'model', model)
  return energy_fn
[docs]
def graph_network_neighbor_list(
  displacement_fn: DisplacementFn,
  box_size: Box,
  r_cutoff: float,
  dr_threshold: float,
  *,
  key: jax.Array,
  nodes: Array | None = None,
  spatial_dimension: int = 3,
  n_recurrences: int = 2,
  mlp_sizes: Tuple[int, ...] = (64, 64),
  fractional_coordinates: bool = False,
  format: partition.NeighborListFormat = partition.Sparse,
  neighbor_list_fn: Callable = partition.neighbor_list,
  **neighbor_kwargs,
):
  """Convenience wrapper around ``EnergyGraphNet`` using neighbor lists.

  Args:
    displacement_fn: Function to compute displacement between two positions.
    box_size: The size of the simulation volume, used to construct the
      neighbor list.
    r_cutoff: A floating point cutoff; edges will be added to the graph for
      pairs of particles whose separation is smaller than the cutoff.
    dr_threshold: A floating point number specifying a "halo" radius that we
      use for neighbor list construction. See ``neighbor_list`` for details.
    key: A JAX random key used to initialize model parameters.
    nodes: None or an ndarray of shape ``[N, node_dim]`` specifying the state
      of the nodes. A 1-D array is reshaped to ``[N, 1]``. If None this is
      set to the zeros vector. Often, for a system with multiple species,
      this could be the species id.
    spatial_dimension: The spatial dimension of the system (default 3). It
      is only used when ``box_size`` is a scalar; otherwise the last axis
      length of ``box_size`` replaces it.
    n_recurrences: The number of steps of message passing in the graph
      network.
    mlp_sizes: A tuple specifying the layer-widths for the fully-connected
      networks used to update the states in the graph network.
    fractional_coordinates: A boolean specifying whether or not the
      coordinates will be in the unit cube.
    format: The format of the neighbor list. See
      ``partition.NeighborListFormat`` for details. Only ``Dense`` and
      ``Sparse`` formats are accepted. If the ``Dense`` format is used, then
      the graph network is constructed using the JAX MD backend, otherwise
      Jraph is used.
    neighbor_list_fn: Factory called to build the neighbor list functions
      (default ``partition.neighbor_list``). It is always called with
      ``mask_self=False``.
    **neighbor_kwargs: Additional kwargs. ``activation``, ``kernel_init`` and
      ``bias_init`` are removed and forwarded to ``EnergyGraphNet``; every
      other entry goes to ``neighbor_list_fn``.

  Returns:
    A tuple ``(neighbor_fn, energy_fn)`` where
    ``energy_fn(R, neighbor=nbrs) -> scalar``. The underlying NNX module is
    accessible via ``energy_fn.model`` for training.
  """
  node_state = canonicalize_node_state(nodes)
  node_dim = 1 if node_state is None else node_state.shape[-1]

  # A vector or matrix box fixes the dimensionality of the system.
  box_arr = jnp.asarray(box_size)
  edge_dim = box_arr.shape[-1] if box_arr.ndim >= 1 else spatial_dimension

  # Split the model options out of the catch-all kwargs so that only true
  # neighbor-list options reach ``neighbor_list_fn``.
  model_kwargs = {}
  for option in ('activation', 'kernel_init', 'bias_init'):
    if option in neighbor_kwargs:
      model_kwargs[option] = neighbor_kwargs.pop(option)

  neighbor_fn = neighbor_list_fn(
    displacement_fn,
    box_size,
    r_cutoff,
    dr_threshold,
    mask_self=False,
    fractional_coordinates=fractional_coordinates,
    format=format,
    **neighbor_kwargs,
  )
  model = EnergyGraphNet(
    displacement_fn,
    r_cutoff,
    key=key,
    in_node_features=node_dim,
    in_edge_features=edge_dim,
    n_recurrences=n_recurrences,
    mlp_sizes=mlp_sizes,
    format=format,
    dr_threshold=dr_threshold,
    nodes=node_state,
    **model_kwargs,
  )

  def energy_fn(R, neighbor=None, **kwargs):
    return model(R, neighbor, **kwargs)

  # Expose the module so its parameters can be reached for training.
  setattr(energy_fn, 'model', model)
  return neighbor_fn, energy_fn
def nequip_neighbor_list(
  displacement_fn,
  box,
  cfg: ConfigDict | None = None,
  atoms=None,
  neighbor_list_fn: Callable = partition.neighbor_list,
  featurizer_fn: Callable = nn.util.neighbor_list_featurizer,
  **nl_kwargs,
):
  """Convenience wrapper to compute NequIP energy using a neighbor list.

  Args:
    displacement_fn: Displacement function from `jax_md.space`.
    box: Box matrix with columns as lattice vectors, shape (dim, dim).
    cfg: NequIP configuration. If None, ``nequip.default_config()`` is used.
      ``cfg.r_max`` is the neighbor list cutoff.
    atoms: One-hot encoding of atom types. May instead be supplied per call
      through the ``atoms`` keyword of ``init_fn`` / ``energy_fn``.
    neighbor_list_fn: Neighbor list constructor. Can be
      `partition.neighbor_list` (default, uses MIC) or
      `custom_partition.neighbor_list_multi_image` for small boxes where
      r_cut > L/2. It is always called with ``format=partition.Sparse``.
    featurizer_fn: Function to create a featurizer. Signature:
      `featurizer_fn(displacement_fn) -> featurize(atoms, position, neighbor)`.
      Defaults to `nn.util.neighbor_list_featurizer` for standard MIC.
      Use `custom_partition.graph_featurizer` with
      `neighbor_list_multi_image` for correct multi-image displacement
      computation.
    **nl_kwargs: Additional kwargs for neighbor list (e.g.,
      `fractional_coordinates`).

  Returns:
    Tuple of (neighbor_fn, init_fn, energy_fn), where
    ``init_fn(key, position, neighbor, **kwargs)`` returns model parameters
    and ``energy_fn(params, position, neighbor, **kwargs)`` returns the
    ``[0, 0]`` entry of the model output.
  """
  if cfg is None:
    cfg = nequip.default_config()
  model = nequip.model_from_config(cfg)
  neighbor_fn = neighbor_list_fn(
    displacement_fn, box, cfg.r_max, format=partition.Sparse, **nl_kwargs
  )
  featurizer = featurizer_fn(displacement_fn)

  def _graph_for(position, neighbor, call_kwargs):
    # A per-call ``atoms`` keyword overrides the default given above; the
    # remaining keywords go to the featurizer.
    species = call_kwargs.pop('atoms', atoms)
    if species is None:
      raise ValueError('A one-hot encoding of the atoms is required.')
    return featurizer(species, position, neighbor, **call_kwargs)

  def init_fn(key, position, neighbor, **kwargs):
    # TODO: It would be nicer to do this without computing flops
    # since we really only need the shape of the graph for initialization.
    return model.init(key, _graph_for(position, neighbor, kwargs))

  def energy_fn(params, position, neighbor, **kwargs):
    out: Any = model.apply(params, _graph_for(position, neighbor, kwargs))
    return out[0, 0]

  return neighbor_fn, init_fn, energy_fn
def load_gnome_model_neighbor_list(
  displacement_fn,
  box,
  directory,
  atoms=None,
  neighbor_list_fn: Callable = partition.neighbor_list,
  featurizer_fn: Callable = nn.util.neighbor_list_featurizer,
  **nl_kwargs,
):
  """Load a gnome model from a checkpoint.

  Args:
    displacement_fn: Displacement function from `jax_md.space`.
    box: Box matrix with columns as lattice vectors, shape (dim, dim).
    directory: Directory containing the gnome model checkpoint. The loaded
      configuration's ``r_max`` is the neighbor list cutoff.
    atoms: One-hot encoding of atom types. May instead be supplied per call
      through the ``atoms`` keyword of ``energy_fn``.
    neighbor_list_fn: Neighbor list constructor. Can be
      `partition.neighbor_list` (default, uses MIC) or
      `custom_partition.neighbor_list_multi_image` for small boxes where
      r_cut > L/2. It is always called with ``format=partition.Sparse``.
    featurizer_fn: Function to create a featurizer. Signature:
      `featurizer_fn(displacement_fn) -> featurize(atoms, position, neighbor)`.
      Defaults to `nn.util.neighbor_list_featurizer` for standard MIC.
      Use `custom_partition.graph_featurizer` with
      `neighbor_list_multi_image` for correct multi-image displacement
      computation.
    **nl_kwargs: Additional kwargs for neighbor list (e.g.,
      `fractional_coordinates`).

  Returns:
    Tuple of (neighbor_fn, energy_fn). ``energy_fn(position, neighbor,
    **kwargs)`` applies the loaded parameters and returns the ``[0, 0]``
    entry of the model output.

  Note:
    When using `neighbor_list_multi_image`, you must also use a compatible
    featurizer (e.g., `custom_partition.graph_featurizer`) to correctly
    compute displacements using the stored shifts instead of MIC.
  """
  cfg, model, params = gnome.load_model(directory)
  neighbor_fn = neighbor_list_fn(
    displacement_fn, box, cfg.r_max, format=partition.Sparse, **nl_kwargs
  )
  featurizer = featurizer_fn(displacement_fn)

  def energy_fn(position, neighbor, **kwargs):
    # A per-call ``atoms`` keyword overrides the default given above.
    species = kwargs.pop('atoms', atoms)
    if species is None:
      raise ValueError('A one-hot encoding of the atoms is required.')
    out: Any = model.apply(
      params, featurizer(species, position, neighbor, **kwargs)
    )
    return out[0, 0]

  return neighbor_fn, energy_fn
# TRIANGULATED SURFACE POTENTIALS / MEMBRANE POTENTIALS
def triangle_area_potential(
  R_mem: Array,
  triangles: Array,
  displacement_fn: DisplacementOrMetricFn,
  A_0: Array,
  k: Array,
) -> Array:
  """.. _triangle_area_potential

  Local area conservation of the triangles in the vesicle discretization as
  proposed by Vutukuri HR et al. (Gompper Group)
  https://doi.org/10.1038/s41586-020-2730-x

  The energy is ``0.5 * k * sum((A - A_0)**2 / A_0)`` over all triangles.

  Args:
    R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
      vertices. (R_mem is allowed to contain non-membrane particles too,
      only particles with indices in 'triangles' are considered in the
      calculation)
    triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the
      indices of the vertices that comprise the triangle
    displacement_fn (DisplacementOrMetricFn): Function computing the
      displacement between two positions; used for the triangle edges.
    A_0 (Array): desired triangle area
    k (Array): local-area conservation coefficient

  Returns:
    energy contribution due to local area conservation
  """
  area_deviation = (
    _calc_triangle_areas(R_mem, triangles, displacement_fn) - A_0
  )
  return 0.5 * k * util.high_precision_sum(area_deviation**2 / A_0)
def _calc_triangle_areas(
  R_mem: Array, triangles: Array, displacement_fn: DisplacementOrMetricFn
) -> Array:
  """Calculates the area of every triangle of a triangulated point cloud.

  Each area is ``0.5 * |e1| * |e2| * sin(theta)``, where ``e1`` and ``e2``
  are the edges leaving the triangle's first vertex and ``theta`` is the
  angle between them.

  Args:
    R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
      vertices. (R_mem is allowed to contain non-membrane particles too,
      only particles with indices in 'triangles' are considered in the
      calculation)
    triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the
      indices of the vertices that comprise the triangle
    displacement_fn (DisplacementOrMetricFn): Function computing the
      displacement between two positions.

  Returns:
    Array with one entry per row of ``triangles``; the i-th entry is the
    area of the i-th triangle.
  """
  pair_displacement = vmap(displacement_fn)
  first_vertex = R_mem[triangles[:, 0]]
  edge_1 = pair_displacement(R_mem[triangles[:, 1]], first_vertex)
  edge_2 = pair_displacement(R_mem[triangles[:, 2]], first_vertex)

  cos_theta = vmap(quantity.cosine_angle_between_two_vectors)(edge_1, edge_2)
  sin_theta = jnp.sqrt(1 - cos_theta**2)
  return 0.5 * space.distance(edge_1) * space.distance(edge_2) * sin_theta
def volume_potential(R_mem: Array, triangles: Array, V_0: float, k: float):
  """.. _volume_potential

  Global volume conservation of the vesicle as used by Vutukuri HR et al.
  (Gompper Group) https://doi.org/10.1038/s41586-020-2730-x

  The energy is ``k * (V - V_0)**2 / (2 * V_0)`` with ``V`` the volume
  enclosed by the triangulated surface.

  Args:
    R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
      vertices. (R_mem is allowed to contain non-membrane particles too,
      only particles with indices in 'triangles' are considered in the
      calculation)
    triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the
      indices of the vertices that comprise the triangle
    V_0 (float): desired vesicle volume
    k (float): volume stiffness

  Returns:
    energy contribution due to global volume conservation

  Note:
    This calculation assumes the membrane vertices are in a single unwrapped
    coordinate frame. Wrapped periodic coordinates are not supported because
    the enclosed volume is a global surface property.
  """
  volume_deviation = _calc_volume(R_mem, triangles) - V_0
  return k * volume_deviation**2 / (2 * V_0)
def _calc_volume(R_mem: Array, triangles: Array) -> Array:
  """Calculates the volume enclosed by a triangulated surface.

  Works for arbitrary non-convex polyhedra. Every triangle contributes the
  determinant of its three vertex positions; the absolute value of the sum,
  divided by six, is the volume. Assumes vertices are ordered in the
  triangle list.
  Implementation follows https://doi.org/10.1109/MCG.1984.6429334

  Args:
    R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
      vertices; components 0, 1 and 2 are read. (R_mem is allowed to contain
      non-membrane particles too, only particles with indices in 'triangles'
      are considered in the calculation)
    triangles (Array, shape: (2(N-1), 3)): Array of triangles storing the
      indices of the vertices that comprise the triangle

  Returns:
    Volume enclosed by triangulated surface
  """
  a, b, c = (R_mem[triangles[:, corner]] for corner in range(3))
  ax, ay, az = a[:, 0], a[:, 1], a[:, 2]
  bx, by, bz = b[:, 0], b[:, 1], b[:, 2]
  cx, cy, cz = c[:, 0], c[:, 1], c[:, 2]

  # Cofactor expansion of det([a, b, c]) along the row of ``a``.
  det = (
    ax * (by * cz - cy * bz)
    - ay * (bx * cz - cx * bz)
    + az * (bx * cy - cx * by)
  )
  return jnp.abs(util.high_precision_sum(det)) / 6
def bending_potential(
  R_mem: Array,
  triangles: Array,
  displacement_fn: DisplacementOrMetricFn,
  kappa: Array,
) -> Array:
  """.. _bending_potential

  Calculates the bending potential of a triangulated surface, based on the
  discretization by Gompper: https://doi.org/10.1051/jp1:1996246

  For every triangle the cotangent of the angle opposite each edge is
  computed. Per vertex, ``sigma`` accumulates ``|edge|**2 * cot / 8`` and
  ``rho`` accumulates the cotangent-weighted edge vectors over all incident
  triangles. The energy is ``kappa / 8 * sum(|rho|**2 / sigma)``.

  Args:
    R_mem (Array, shape: (N, spatial dim)): Positions of the membrane
      vertices. The per-vertex vector buffer has three components, so
      three-dimensional positions are expected.
    triangles (Array, shape: (N, 3)): Array of triangles storing the indices
      of the vertices that comprise the triangle
    displacement_fn (DisplacementOrMetricFn): Function computing the
      displacement between two positions.
    kappa (Array, scalar): Bending rigidity

  Returns:
    Energy of the membrane conformation due to the bending potential
  """
  pair_displacement = vmap(displacement_fn)
  cos_between = vmap(quantity.cosine_angle_between_two_vectors)
  n_vertices = R_mem.shape[0]

  idx0, idx1, idx2 = triangles[:, 0], triangles[:, 1], triangles[:, 2]
  R0, R1, R2 = R_mem[idx0], R_mem[idx1], R_mem[idx2]

  # Edge vectors around each triangle.
  dR01 = pair_displacement(R1, R0)  # R1 - R0
  dR12 = pair_displacement(R2, R1)  # R2 - R1
  dR20 = pair_displacement(R0, R2)  # R0 - R2

  def cotangent(u, v):
    # The small offset keeps the division finite when the sine vanishes.
    cos_angle = cos_between(u, v)
    return cos_angle / (jnp.sqrt(1 - cos_angle**2) + 1e-7)

  # cotXY belongs to the angle at the vertex opposite edge XY.
  cot01 = cotangent(dR20, -dR12)
  cot12 = cotangent(dR01, -dR20)
  cot20 = cotangent(dR12, -dR01)

  # Cotangent-weighted squared edge lengths; each vertex collects the two
  # edges that touch it.
  area01 = space.distance(dR01) ** 2 * cot01
  area12 = space.distance(dR12) ** 2 * cot12
  area20 = space.distance(dR20) ** 2 * cot20
  sigma = jnp.zeros(n_vertices)  # sigma per vertex
  sigma = sigma.at[idx0].add(area01 + area20)
  sigma = sigma.at[idx1].add(area01 + area12)
  sigma = sigma.at[idx2].add(area12 + area20)
  sigma = sigma / 8

  # Cotangent-weighted edge vectors, accumulated per vertex.
  weighted01 = cot01[:, None] * dR01
  weighted12 = cot12[:, None] * dR12
  weighted20 = cot20[:, None] * dR20
  rho = jnp.zeros((n_vertices, 3))  # per vertex
  rho = rho.at[idx0].add(weighted20 - weighted01)
  rho = rho.at[idx1].add(weighted01 - weighted12)
  rho = rho.at[idx2].add(weighted12 - weighted20)

  # Vertices that belong to no triangle have rho == 0 and sigma == 0; the
  # offset in the denominator makes their contribution exactly zero.
  per_vertex = jnp.sum(rho * rho, axis=1) / (sigma + 1e-7)
  return (kappa / 8) * util.high_precision_sum(per_vertex)
def uma_neighbor_list(
  displacement_fn,
  box,
  cfg=None,
  atoms=None,
  checkpoint_path=None,
  head_type='mlp',
  neighbor_list_fn: Callable = custom_partition.neighbor_list_multi_image,
  featurizer_fn: Callable | None = None,
  charge=None,
  spin=None,
  dataset_idx=None,
  head_dataset='omat',
  apply_atom_refs=True,
  atom_refs=None,
  use_kernels=True,
  merge_mole=None,
  so2_block_gemm=None,
  **nl_kwargs,
):
  """Convenience wrapper to compute UMA energy using a neighbor list.

  This follows the standard JAX-MD pattern of returning
  (neighbor_fn, init_fn, energy_fn). Forces can be computed via
  ``jax.grad(energy_fn)`` which ensures energy conservation.

  Args:
    displacement_fn: Displacement function from ``jax_md.space``.
    box: Box matrix with columns as lattice vectors, shape ``(dim, dim)``.
      With the default ``neighbor_list_fn`` a 1-D box vector is also
      accepted and expanded to a diagonal matrix.
    cfg: ``UMAConfig`` instance. If None, uses default config. Ignored when
      ``checkpoint_path`` is given (the checkpoint's config is used).
    atoms: Integer atomic numbers, shape ``[num_atoms]``. May instead be
      supplied per call through the ``atoms`` keyword.
    checkpoint_path: Path to a PyTorch checkpoint to load pretrained weights.
      If provided, ``init_fn`` returns the converted weights instead of
      random initialization.
    head_type: Energy head type: ``'mlp'`` or ``'linear'``. Any value other
      than ``'mlp'`` selects the linear head.
    neighbor_list_fn: Neighbor list constructor (default:
      ``custom_partition.neighbor_list_multi_image``).
    featurizer_fn: Function to create a UMA featurizer. Defaults to
      ``uma_multi_image_featurizer`` for multi-image neighbor lists. When
      using ``partition.neighbor_list``, pass ``uma_featurizer`` explicitly.
    charge: System charge(s), shape ``[num_systems]`` (default: ``[0]``).
    spin: System spin(s), shape ``[num_systems]`` (default: ``[0]``).
    dataset_idx: Integer dataset index, shape ``[num_systems]``
      (default: ``[0]``).
    head_dataset: Dataset name for pretrained MoE energy head selection.
    apply_atom_refs: Whether to apply checkpoint task normalizer and element
      references for pretrained checkpoints. Only takes effect when a
      checkpoint or ``atom_refs`` is supplied.
    atom_refs: Optional per-element task reference array, or a dictionary
      with ``element_refs`` (or ``atom_refs``), ``mean``, and ``rmsd``
      values. If only an array is provided, ``mean=0`` and ``rmsd=1`` are
      used.
    use_kernels: UMA kernel toggle. True enables eligible JAX-MD UMA Pallas
      kernels. ``None`` leaves the config value untouched.
    merge_mole: If True, pre-mix MoE expert weights for single-system
      inference and skip runtime routing. If None, disabled by default.
    so2_block_gemm: If True, use a block GEMM formulation for SO2 m>0
      convolutions. If None, enabled for pretrained MoE checkpoints and
      disabled otherwise.
    **nl_kwargs: Additional kwargs for neighbor list (e.g.,
      ``fractional_coordinates``).

  Returns:
    Tuple of ``(neighbor_fn, init_fn, energy_fn)`` where:

    - ``neighbor_fn``: Allocates/updates neighbor list.
    - ``init_fn(key, position, neighbor, **kw)``: Returns model parameters.
    - ``energy_fn(params, position, neighbor, **kw)``: Returns scalar energy.

  Raises:
    ValueError: If the checkpoint has no energy head for ``head_dataset``, if
      ``so2_block_gemm`` or ``merge_mole`` is requested without an MoE
      config, or if the checkpoint's energy correction cannot be loaded.
      ``init_fn`` / ``energy_fn`` raise ``ValueError`` when no atomic numbers
      are available, and ``init_fn`` when ``merge_mole`` is used with more
      than one system.
  """
  # Imported once here; every config override below uses it.
  from dataclasses import replace

  from jax_md._nn.uma.model import UMABackbone, default_config
  from jax_md._nn.uma.heads import MLPEnergyHead, LinearEnergyHead
  from jax_md._nn.uma.featurizer import uma_multi_image_featurizer
  import flax.linen as flax_nn

  # Load pretrained checkpoint (MoE) if provided
  is_moe = False
  pretrained_params = None
  if checkpoint_path is not None:
    from jax_md._nn.uma.model_moe import load_pretrained

    moe_config, backbone_params, head_params = load_pretrained(
      checkpoint_path, head_dataset=head_dataset
    )
    if head_params is None:
      raise ValueError(
        f'Checkpoint {checkpoint_path!r} provides no energy head for '
        f'dataset {head_dataset!r}.'
      )
    cfg = moe_config
    is_moe = True
    pretrained_params = {
      'params': {
        'backbone': backbone_params['params'],
        'energy_head': head_params['params'],
      }
    }
  elif cfg is None:
    cfg = default_config()

  # If callers provide a preloaded MoE config/params pair, use the MoE
  # backbone even when checkpoint_path is not passed here.
  is_moe = is_moe or hasattr(cfg, 'num_experts')
  auto_moe_optimizations = checkpoint_path is not None and is_moe
  merge_mole = False if merge_mole is None else bool(merge_mole)
  so2_block_gemm = (
    auto_moe_optimizations if so2_block_gemm is None else bool(so2_block_gemm)
  )

  if use_kernels is not None:
    cfg = replace(cfg, use_kernels=bool(use_kernels))
  if so2_block_gemm:
    if not is_moe:
      raise ValueError('so2_block_gemm=True requires a UMA MoE config.')
    cfg = replace(cfg, so2_block_gemm=True)
  if merge_mole and not is_moe:
    raise ValueError('merge_mole=True requires a UMA MoE checkpoint/config.')

  # ``cfg`` stays unmerged: it is used for parameter initialization and for
  # the routing pass in ``init_fn``. ``runtime_cfg`` drives energy calls.
  runtime_cfg = replace(cfg, merged_mole=True) if merge_mole else cfg

  resolved_atom_refs = None
  energy_mean = None
  energy_rmsd = None
  should_apply_atom_refs = apply_atom_refs and (
    checkpoint_path is not None or atom_refs is not None
  )
  if should_apply_atom_refs:
    if atom_refs is None:
      from jax_md._nn.uma.pretrained import load_energy_correction

      correction = None
      if checkpoint_path is not None:
        try:
          correction = load_energy_correction(
            checkpoint_path, dataset=head_dataset
          )
        except Exception as exc:
          raise ValueError(
            'Failed to load UMA atom references or energy normalizer for '
            f'{checkpoint_path!r}. Pass apply_atom_refs=False to use raw '
            'checkpoint energies, or pass atom_refs explicitly.'
          ) from exc
      if correction is not None:
        atom_refs = correction.get('element_refs')
        energy_mean = correction.get('mean')
        energy_rmsd = correction.get('rmsd')
    elif isinstance(atom_refs, dict):
      correction = atom_refs
      atom_refs = correction.get('element_refs', correction.get('atom_refs'))
      energy_mean = correction.get('mean')
      energy_rmsd = correction.get('rmsd')

    # Every path falls back to the identity normalizer (mean 0, rmsd 1) when
    # a value is absent, so ``None`` never reaches ``jnp.asarray`` below.
    if energy_mean is None:
      energy_mean = 0.0
    if energy_rmsd is None:
      energy_rmsd = 1.0

    if atom_refs is not None:
      resolved_atom_refs = jnp.asarray(atom_refs, dtype=jnp.float32)
    energy_mean = jnp.asarray(energy_mean, dtype=jnp.float32)
    energy_rmsd = jnp.asarray(energy_rmsd, dtype=jnp.float32)

  # Build neighbor list. The UMA default uses explicit periodic images and
  # expects a box matrix; keep vector boxes working for common orthorhombic
  # use.
  neighbor_box = box
  if neighbor_list_fn is custom_partition.neighbor_list_multi_image:
    neighbor_box = jnp.asarray(box)
    if neighbor_box.ndim == 1:
      neighbor_box = jnp.diag(neighbor_box)
    nl_kwargs.setdefault('fractional_coordinates', False)
  neighbor_fn = neighbor_list_fn(
    displacement_fn,
    neighbor_box,
    cfg.cutoff,
    format=partition.Sparse,
    **nl_kwargs,
  )

  if featurizer_fn is None:
    featurizer_fn = uma_multi_image_featurizer
  featurizer = featurizer_fn(displacement_fn, cutoff=cfg.cutoff)

  # Combined backbone + head as a single Flax module
  class UMAEnergyModel(flax_nn.Module):
    config: Any
    use_moe: bool = False
    atom_refs: Any = None
    energy_mean: Any = None
    energy_rmsd: Any = None

    @flax_nn.compact
    def __call__(self, features):
      if self.use_moe:
        from jax_md._nn.uma.model_moe import UMAMoEBackbone

        backbone = UMAMoEBackbone(config=self.config, name='backbone')
      else:
        backbone = UMABackbone(config=self.config, name='backbone')
      emb = backbone(
        features['positions'],
        features['atomic_numbers'],
        features['batch'],
        features['edge_index'],
        features['edge_distance_vec'],
        features['charge'],
        features['spin'],
        features.get('dataset_idx'),
      )

      if head_type == 'mlp':
        head = MLPEnergyHead(
          sphere_channels=self.config.sphere_channels,
          hidden_channels=self.config.hidden_channels,
          name='energy_head',
        )
      else:
        head = LinearEnergyHead(
          sphere_channels=self.config.sphere_channels,
          name='energy_head',
        )

      num_systems = features['charge'].shape[0]
      result = head(emb['node_embedding'], features['batch'], num_systems)
      energy = result['energy']

      # Undo the training normalization, then add the per-element reference
      # energies summed over the atoms of each system.
      if self.energy_rmsd is not None:
        energy = energy * self.energy_rmsd.astype(energy.dtype)
        energy = energy + self.energy_mean.astype(energy.dtype)
      if self.atom_refs is not None:
        ref_shift = (
          jnp.zeros((num_systems,), dtype=energy.dtype)
          .at[features['batch']]
          .add(self.atom_refs[features['atomic_numbers']].astype(energy.dtype))
        )
        energy = energy + ref_shift
      return energy

  model = UMAEnergyModel(
    config=runtime_cfg,
    use_moe=is_moe,
    atom_refs=resolved_atom_refs,
    energy_mean=energy_mean,
    energy_rmsd=energy_rmsd,
  )
  init_model = model
  if merge_mole:
    # Parameters are initialized with the unmerged config and merged later.
    init_model = UMAEnergyModel(
      config=cfg,
      use_moe=is_moe,
      atom_refs=resolved_atom_refs,
      energy_mean=energy_mean,
      energy_rmsd=energy_rmsd,
    )

  def _build_features(position, neighbor, call_kwargs):
    # Per-call keywords override the defaults captured above; whatever is
    # left in ``call_kwargs`` is forwarded to the featurizer.
    _atoms = call_kwargs.pop('atoms', atoms)
    if _atoms is None:
      raise ValueError('Integer atomic numbers are required.')
    _charge = call_kwargs.pop('charge', charge)
    _spin = call_kwargs.pop('spin', spin)
    _dataset_idx = call_kwargs.pop('dataset_idx', dataset_idx)
    return featurizer(
      _atoms,
      position,
      neighbor,
      charge=_charge,
      spin=_spin,
      dataset_idx=_dataset_idx,
      **call_kwargs,
    )

  def init_fn(key, position, neighbor, **kwargs):
    features = _build_features(position, neighbor, kwargs)
    if pretrained_params is not None:
      result_params = pretrained_params
    else:
      result_params = init_model.init(key, features)

    if merge_mole:
      if features['charge'].shape[0] != 1:
        raise ValueError('merge_mole=True supports only one system.')
      from jax_md._nn.uma.model_moe import (
        UMAMoEBackbone,
        UMAMoEConfig,
        merge_mole_params,
      )

      moe_cfg = cfg
      if not isinstance(moe_cfg, UMAMoEConfig):
        raise ValueError('merge_mole=True requires an MoE config.')
      # One routing pass with the unmerged backbone yields the expert
      # coefficients that are used to merge the expert weights.
      routing_backbone = UMAMoEBackbone(config=moe_cfg)
      routing_out: Any = routing_backbone.apply(
        {'params': result_params['params']['backbone']},
        features['positions'],
        features['atomic_numbers'],
        features['batch'],
        features['edge_index'],
        features['edge_distance_vec'],
        features['charge'],
        features['spin'],
        features.get('dataset_idx'),
      )
      result_params = merge_mole_params(
        result_params, routing_out['expert_coefficients']
      )
    return result_params

  def energy_fn(params, position, neighbor, **kwargs):
    features = _build_features(position, neighbor, kwargs)
    # flax apply returns a (output | variables) union; here the output is the
    # energy array produced by the head for ``num_systems`` systems, which is
    # summed into one scalar.
    out: Any = model.apply(params, features)
    return out.sum()

  return neighbor_fn, init_fn, energy_fn
def mace_neighbor_list(
  displacement_fn,
  box,
  *,
  model,
  config: Dict[str, Any],
  z_atomic,
  r_cutoff: float,
  dr_threshold: float = 0.5,
  fractional_coordinates: bool = False,
  neighbor_list_fn: Callable = partition.neighbor_list,
  featurizer_fn=None,
  head=None,
  **neighbor_kwargs,
):
  """Wraps a MACE-JAX potential as a JAX MD neighbor-list energy.

  Args:
    displacement_fn: Displacement function from ``jax_md.space``.
    box: Periodic box with shape ``(3,)`` or ``(3, 3)``. It is converted to
      an array and used whenever ``energy_fn`` is called without ``box``.
    model: Callable MACE-JAX model accepting a MACE batch dictionary.
    config: Model configuration containing at least ``atomic_numbers``.
    z_atomic: Atomic numbers for real atoms, shape ``(N,)``.
    r_cutoff: Neighbor cutoff radius.
    dr_threshold: Neighbor-list rebuild threshold (skin).
    fractional_coordinates: Whether positions are fractional coordinates.
    neighbor_list_fn: Neighbor-list factory.
    featurizer_fn: Featurizer factory. Defaults to ``mace_featurizer``.
      For multi-image neighbor lists, pass ``mace_multi_image_featurizer``;
      that factory is called without ``displacement_fn``.
    head: Optional head array to include in the MACE batch.
    **neighbor_kwargs: Additional keyword arguments for ``neighbor_list_fn``.

  Returns:
    A pair ``(neighbor_fn, energy_fn)`` following the standard JAX MD
    convention. ``energy_fn`` is jitted and has signature:
    ``energy_fn(R, *, box=None, neighbor=None, perturbation=None, **kwargs)``.
    Extra keyword arguments are forwarded to ``model``, with
    ``compute_force`` and ``compute_stress`` defaulting to ``False``.
  """
  from jax_md._nn.mace.featurizer import (
    mace_featurizer,
    mace_multi_image_featurizer,
  )

  box_default = jnp.asarray(box)
  neighbor_fn = neighbor_list_fn(
    displacement_fn,
    box_default,
    r_cutoff,
    dr_threshold=dr_threshold,
    fractional_coordinates=fractional_coordinates,
    **neighbor_kwargs,
  )

  if featurizer_fn is None:
    featurizer_fn = mace_featurizer
  featurizer_options = {
    'fractional_coordinates': fractional_coordinates,
    'head': head,
  }
  if featurizer_fn is mace_multi_image_featurizer:
    # The multi-image factory is built without a displacement function.
    featurize = featurizer_fn(config, z_atomic, **featurizer_options)
  else:
    featurize = featurizer_fn(
      displacement_fn, config, z_atomic, **featurizer_options
    )

  @jax.jit
  def energy_fn(R, *, box=None, neighbor=None, perturbation=None, **kwargs):
    active_box = box_default if box is None else box
    batch = featurize(R, neighbor, box=active_box, perturbation=perturbation)

    # Only the energy is needed here unless the caller asks for more.
    model_kwargs = {**kwargs}
    for flag in ('compute_force', 'compute_stress'):
      model_kwargs.setdefault(flag, False)

    out = model(batch, **model_kwargs)
    if isinstance(out, dict) and 'energy' in out:
      out = out['energy']
    energy = jnp.asarray(out)
    # A length-one energy vector is returned as a scalar.
    if energy.shape == (1,):
      return jnp.reshape(energy, ())
    return energy

  return neighbor_fn, energy_fn