Source code for jax_md._nn.behler_parrinello

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Tuple, Dict, Any, Optional

import numpy as onp

import jax
from jax import vmap, jit
import jax.numpy as jnp

from jax_md import space, dataclasses, quantity, partition, smap, util
import haiku as hk

from collections import namedtuple
from functools import partial, reduce
from jax.tree_util import tree_map
from jax import ops

import jraph


# Typing


Array = util.Array
f32 = util.f32
f64 = util.f64

InitFn = Callable[..., Array]
CallFn = Callable[..., Array]

DisplacementOrMetricFn = space.DisplacementOrMetricFn
DisplacementFn = space.DisplacementFn
NeighborList = partition.NeighborList

"""
Our neural network force field is based on the Behler-Parrinello neural network
architecture (BP-NN) [1]. The BP-NN architecture uses a relatively simple,
fully connected neural network to predict the local energy for each atom. Then
the total energy is the sum of local energies due to each atom. Atoms of the
same type use the same NN to predict energy.

Each atomic NN is applied to hand-crafted features called symmetry functions.
There are two kinds of symmetry functions: radial and angular. Radial symmetry
functions represent information about two-body interactions of the central
atom, whereas angular symmetry functions represent information about three-body
interactions. Below we implement radial and angular symmetry functions for
arbitrary number of types of atoms (Note that most applications of BP-NN limit
their systems to 1 to 3 types of atoms). We also present a convenience wrapper
that returns radial and angular symmetry functions with symmetry function
parameters that should work reasonably for most systems (the symmetry functions
are taken from reference [2]). Please see references [1, 2] for details about
how the BP-NN works.

[1] Behler, Jörg, and Michele Parrinello. "Generalized neural-network
representation of high-dimensional potential-energy surfaces." Physical
Review Letters 98.14 (2007): 146401.

[2] Artrith, Nongnuch, Björn Hiller, and Jörg Behler. "Neural network
potentials for metals and oxides–First applications to copper clusters at zinc
oxide."
Physica Status Solidi (b) 250.6 (2013): 1191-1203.
"""

def _cutoff_fn(dr: Array,
                                 cutoff_distance: float=8.0) -> Array:
  """Function of pairwise distance that smoothly goes to zero at the cutoff."""
  # Also returns zero if the pairwise distance is zero,
  # to prevent a particle from interacting with itself.
  return jnp.where((dr < cutoff_distance) & (dr > 1e-7),
                  0.5 * (jnp.cos(jnp.pi * dr / cutoff_distance) + 1), 0)


[docs]def radial_symmetry_functions(displacement_or_metric: DisplacementOrMetricFn, species: Optional[Array], etas: Array, cutoff_distance: float ) -> Callable[[Array], Array]: """Returns a function that computes radial symmetry functions. Args: displacement: A function that produces an `[N_atoms, M_atoms, spatial_dimension]` of particle displacements from particle positions specified as an `[N_atoms, spatial_dimension] and `[M_atoms, spatial_dimension]` respectively. species: An `[N_atoms]` that contains the species of each particle. etas: List of radial symmetry function parameters that control the spatial extension. cutoff_distance: Neighbors whose distance is larger than `cutoff_distance` do not contribute to each others symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance. Returns: A function that computes the radial symmetry function from input `[N_atoms, spatial_dimension]` and returns `[N_atoms, N_etas * N_types]` where `N_etas` is the number of eta parameters, `N_types` is the number of types of particles in the system. """ metric = space.canonicalize_displacement_or_metric(displacement_or_metric) def radial_fn(eta: Array, dr: Array) -> Array: return jnp.exp(-eta * dr**2) * _cutoff_fn(dr, cutoff_distance) radial_fn = vmap(radial_fn, (0, None)) if species is None: def compute_fn(R: Array, **kwargs) -> Array: _metric = partial(metric, **kwargs) _metric = space.map_product(_metric) return util.high_precision_sum(radial_fn(etas, _metric(R, R)), axis=1).T elif isinstance(species, jnp.ndarray): species = onp.array(species) def compute_fn(R: Array, **kwargs) -> Array: _metric = partial(metric, **kwargs) _metric = space.map_product(_metric) def return_radial(atom_type): """Returns the radial symmetry functions for neighbor type atom_type.""" R_neigh = R[species == atom_type, :] dr = _metric(R, R_neigh) return util.high_precision_sum(radial_fn(etas, dr), axis=1).T return jnp.hstack([return_radial(atom_type) for atom_type in onp.unique(species)]) return compute_fn
[docs]def radial_symmetry_functions_neighbor_list( displacement_or_metric: DisplacementOrMetricFn, species: Array, etas: Array, cutoff_distance: float) -> Callable[[Array, NeighborList], Array]: """Returns a function that computes radial symmetry functions. Args: displacement: A function that produces an `[N_atoms, M_atoms, spatial_dimension]` of particle displacements from particle positions specified as an `[N_atoms, spatial_dimension] and `[M_atoms, spatial_dimension]` respectively. species: An `[N_atoms]` that contains the species of each particle. etas: List of radial symmetry function parameters that control the spatial extension. cutoff_distance: Neighbors whose distance is larger than `cutoff_distance` do not contribute to each others symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance. Returns: A function that computes the radial symmetry function from input `[N_atoms, spatial_dimension]` and returns `[N_atoms, N_etas * N_types]` where `N_etas` is the number of eta parameters, `N_types` is the number of types of particles in the system. """ metric = space.canonicalize_displacement_or_metric(displacement_or_metric) def radial_fn(eta, dr): return jnp.exp(-eta * dr**2) * _cutoff_fn(dr, cutoff_distance) radial_fn = vmap(radial_fn, (0, None)) def sym_fn(R: Array, neighbor: NeighborList, mask: Array=None, **kwargs) -> Array: _metric = partial(metric, **kwargs) if neighbor.format is partition.Dense: _metric = space.map_neighbor(_metric) R_neigh = R[neighbor.idx] mask = True if mask is None else mask[neighbor.idx] mask = (neighbor.idx < R.shape[0])[None, :, :] & mask dr = _metric(R, R_neigh) return util.high_precision_sum(radial_fn(etas, dr) * mask, axis=2).T elif neighbor.format is partition.Sparse: _metric = space.map_bond(_metric) dr = _metric(R[neighbor.idx[0]], R[neighbor.idx[1]]) radial = radial_fn(etas, dr).T N = R.shape[0] mask = True if mask is None else mask[neighbor.idx[1]] mask = (neighbor.idx[0] < N) & mask return ops.segment_sum(radial * mask[:, None], neighbor.idx[0], N) else: raise ValueError() if species is None: def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array: return sym_fn(R, neighbor, **kwargs) return compute_fn def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array: _metric = partial(metric, **kwargs) def return_radial(atom_type): """Returns the radial symmetry functions for neighbor type atom_type.""" return sym_fn(R, neighbor, species==atom_type, **kwargs) return jnp.hstack([return_radial(atom_type) for atom_type in onp.unique(species)]) return compute_fn
def single_pair_angular_symmetry_function(dR12: Array, dR13: Array, eta: Array, lam: Array, zeta: Array, cutoff_distance: float ) -> Array: """Computes the angular symmetry function due to one pair of neighbors.""" dR23 = dR12 - dR13 dr12_2 = space.square_distance(dR12) dr13_2 = space.square_distance(dR13) dr23_2 = space.square_distance(dR23) dr12 = space.distance(dR12) dr13 = space.distance(dR13) dr23 = space.distance(dR23) triplet_squared_distances = dr12_2 + dr13_2 + dr23_2 triplet_cutoff = reduce( lambda x, y: x * _cutoff_fn(y, cutoff_distance), [dr12, dr13, dr23], 1.0) z = zeta result = 2.0 ** (1.0 - zeta) * (( 1.0 + lam * quantity.cosine_angle_between_two_vectors(dR12, dR13)) ** z * jnp.exp(-eta * triplet_squared_distances) * triplet_cutoff) return result
[docs]def angular_symmetry_functions(displacement: DisplacementFn, species: Array, etas: Array, lambdas: Array, zetas: Array, cutoff_distance: float ) -> Callable[[Array], Array]: """Returns a function that computes angular symmetry functions. Args: displacement: A function that produces an `[N_atoms, M_atoms, spatial_dimension]` of particle displacements from particle positions specified as an `[N_atoms, spatial_dimension] and `[M_atoms, spatial_dimension]` respectively. species: An `[N_atoms]` that contains the species of each particle. eta: Parameter of angular symmetry function that controls the spatial extension. lam: zeta: cutoff_distance: Neighbors whose distance is larger than `cutoff_distance` do not contribute to each others symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance. Returns: A function that computes the angular symmetry function from input `[N_atoms, spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]` where `N_types` is the number of types of particles in the system. """ _angular_fn = vmap(single_pair_angular_symmetry_function, (None, None, 0, 0, 0, None)) _batched_angular_fn = lambda dR12, dR13: _angular_fn(dR12, dR13, etas, lambdas, zetas, cutoff_distance) _all_pairs_angular = vmap( vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0) if species is None: def compute_fn(R, **kwargs): D_fn = partial(displacement, **kwargs) D_fn = space.map_product(D_fn) dR = D_fn(R, R) return jnp.sum(_all_pairs_angular(dR, dR), axis=[1, 2]) return compute_fn if isinstance(species, jnp.ndarray): species = onp.array(species) def compute_fn(R, **kwargs): atom_types = onp.unique(species) D_fn = partial(displacement, **kwargs) D_fn = space.map_product(D_fn) D_different_types = [D_fn(R[species == s, :], R) for s in atom_types] out = [] for i in range(len(atom_types)): for j in range(i, len(atom_types)): out += [ jnp.sum( _all_pairs_angular(D_different_types[i], D_different_types[j]), axis=[1, 2]) ] return jnp.hstack(out) return compute_fn
[docs]def angular_symmetry_functions_neighbor_list( displacement: DisplacementFn, species: Array, etas: Array, lambdas: Array, zetas: Array, cutoff_distance: float ) -> Callable[[Array, NeighborList], Array]: """Returns a function that computes angular symmetry functions. Args: displacement: A function that produces an `[N_atoms, M_atoms, spatial_dimension]` of particle displacements from particle positions specified as an `[N_atoms, spatial_dimension] and `[M_atoms, spatial_dimension]` respectively. species: An `[N_atoms]` that contains the species of each particle. eta: Parameter of angular symmetry function that controls the spatial extension. lam: zeta: cutoff_distance: Neighbors whose distance is larger than `cutoff_distance` do not contribute to each others symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance. Returns: A function that computes the angular symmetry function from input `[N_atoms, spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]` where `N_types` is the number of types of particles in the system. """ _angular_fn = vmap(single_pair_angular_symmetry_function, (None, None, 0, 0, 0, None)) _batched_angular_fn = lambda dR12, dR13: _angular_fn(dR12, dR13, etas, lambdas, zetas, cutoff_distance) _all_pairs_angular = vmap( vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0) def sym_fn(R: Array, neighbor: NeighborList, mask_i: Array=None, mask_j: Array=None, **kwargs) -> Array: D_fn = partial(displacement, **kwargs) if neighbor.format is partition.Dense: D_fn = space.map_neighbor(D_fn) R_neigh = R[neighbor.idx] dR = D_fn(R, R_neigh) _all_pairs_angular = vmap( vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0) all_angular = _all_pairs_angular(dR, dR) mask_i = True if mask_i is None else mask_i[neighbor.idx] mask_j = True if mask_j is None else mask_j[neighbor.idx] mask_i = (neighbor.idx < R.shape[0]) & mask_i mask_i = mask_i[:, :, jnp.newaxis, jnp.newaxis] mask_j = (neighbor.idx < R.shape[0]) & mask_j mask_j = mask_j[:, jnp.newaxis, :, jnp.newaxis] return util.high_precision_sum(all_angular * mask_i * mask_j, axis=[1, 2]) elif neighbor.format is partition.Sparse: D_fn = space.map_bond(D_fn) dR = D_fn(R[neighbor.idx[0]], R[neighbor.idx[1]]) _all_pairs_angular = vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)) all_angular = _all_pairs_angular(dR, dR) N = R.shape[0] mask_i = True if mask_i is None else mask_i[neighbor.idx[1]] mask_j = True if mask_j is None else mask_j[neighbor.idx[1]] mask_i = (neighbor.idx[0] < N) & mask_i mask_j = (neighbor.idx[0] < N) & mask_j mask = mask_i[:, None] & mask_j[None, :] mask = mask[:, :, None, None] all_angular = jnp.reshape(all_angular, (-1,) + all_angular.shape[2:]) neighbor_idx = jnp.repeat(neighbor.idx[0], len(neighbor.idx[0])) out = ops.segment_sum(all_angular, neighbor_idx, N) return out else: raise ValueError() if species is None: def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array: return sym_fn(R, neighbor, **kwargs) return compute_fn def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array: atom_types = onp.unique(species) out = [] for i in range(len(atom_types)): mask_i = species == i for j in range(i, len(atom_types)): mask_j = species == j out += [ sym_fn(R, neighbor, mask_i, mask_j) ] return jnp.hstack(out) return compute_fn
def symmetry_functions_neighbor_list( displacement: DisplacementFn, species: Array, radial_etas: Optional[Array]=None, angular_etas: Optional[Array]=None, lambdas: Optional[Array]=None, zetas: Optional[Array]=None, cutoff_distance: float=8.0) -> Callable[[Array, NeighborList], Array]: if radial_etas is None: radial_etas = jnp.array([9e-4, 0.01, 0.02, 0.035, 0.06, 0.1, 0.2, 0.4], f32) / f32(0.529177 ** 2) if angular_etas is None: angular_etas = jnp.array([1e-4] * 4 + [0.003] * 4 + [0.008] * 2 + [0.015] * 4 + [0.025] * 4 + [0.045] * 4, f32) / f32(0.529177 ** 2) if lambdas is None: lambdas = jnp.array([-1, 1] * 4 + [1] * 14, f32) if zetas is None: zetas = jnp.array([1, 1, 2, 2] * 2 + [1, 2] + [1, 2, 4, 16] * 3, f32) radial_fn = radial_symmetry_functions_neighbor_list( displacement, species, etas=radial_etas, cutoff_distance=cutoff_distance) angular_fn = angular_symmetry_functions_neighbor_list( displacement, species, etas=angular_etas, lambdas=lambdas, zetas=zetas, cutoff_distance=cutoff_distance) return (lambda R, neighbor, **kwargs: jnp.hstack((radial_fn(R, neighbor, **kwargs), angular_fn(R, neighbor, **kwargs)))) def symmetry_functions(displacement: DisplacementFn, species: Optional[Array]=None, radial_etas: Optional[Array]=None, angular_etas: Optional[Array]=None, lambdas: Optional[Array]=None, zetas: Optional[Array]=None, cutoff_distance: float=8.0 ) -> Callable[[Array], Array]: if radial_etas is None: radial_etas = jnp.array([9e-4, 0.01, 0.02, 0.035, 0.06, 0.1, 0.2, 0.4], f32) / f32(0.529177 ** 2) if angular_etas is None: angular_etas = jnp.array([1e-4] * 4 + [0.003] * 4 + [0.008] * 2 + [0.015] * 4 + [0.025] * 4 + [0.045] * 4, f32) / f32(0.529177 ** 2) if lambdas is None: lambdas = jnp.array([-1, 1] * 4 + [1] * 14, f32) if zetas is None: zetas = jnp.array([1, 1, 2, 2] * 2 + [1, 2] + [1, 2, 4, 16] * 3, f32) radial_fn = radial_symmetry_functions(displacement, species, etas=radial_etas, cutoff_distance=cutoff_distance) angular_fn = angular_symmetry_functions(displacement, species, etas=angular_etas, lambdas=lambdas, zetas=zetas, cutoff_distance=cutoff_distance) symmetry_fn = lambda R, **kwargs: jnp.hstack((radial_fn(R, **kwargs), angular_fn(R, **kwargs))) return symmetry_fn