Neural Network Primitives
Behler-Parrinello Networks
JAX MD contains neural network primitives for a common class of fixed-feature neural network architectures known as Behler-Parrinello Neural Networks (BP-NN) [1, 2]. An energy function using this architecture can be found in energy.py.
The BP-NN architecture uses a relatively simple, fully connected neural network to predict the local energy of each atom. The total energy is then the sum of the local energies of all atoms. Atoms of the same type share the same network. Each atomic network is applied to hand-crafted features called symmetry functions, of which there are two kinds: radial and angular. Radial symmetry functions encode two-body interactions of the central atom, whereas angular symmetry functions encode three-body interactions. Below we implement radial and angular symmetry functions for an arbitrary number of atom types (note that most applications of BP-NN limit their systems to one to three atom types). We also present a convenience wrapper that returns radial and angular symmetry functions with parameters that should work reasonably well for most systems (the symmetry functions are taken from reference [2]). Please see references [1, 2] for details on how the BP-NN works.
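The energy decomposition described above can be sketched in plain JAX. This is a minimal illustration, not the code in energy.py: the helpers `init_mlp`, `mlp`, and `bp_energy` are hypothetical names, and random numbers stand in for the symmetry-function features.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Random weights for a small fully connected network (hypothetical helper)."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def mlp(params, x):
    """Maps [N_atoms, N_features] to one scalar local energy per atom."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return (x @ w + b)[..., 0]

def bp_energy(params_per_species, features, species):
    """Total energy as a sum of local energies; one network per species."""
    per_atom = jnp.zeros(features.shape[0])
    for s, params in enumerate(params_per_species):
        # Evaluate network s on every atom, keep it only for atoms of species s.
        per_atom = per_atom + jnp.where(species == s, mlp(params, features), 0.0)
    return jnp.sum(per_atom)

k_feat, k_a, k_b = jax.random.split(jax.random.PRNGKey(0), 3)
features = jax.random.normal(k_feat, (5, 8))   # stand-in for symmetry functions
species = jnp.array([0, 0, 1, 1, 0])
params_per_species = [init_mlp(k_a, [8, 16, 1]), init_mlp(k_b, [8, 16, 1])]
energy = bp_energy(params_per_species, features, species)
```

Because the total is a plain sum over atoms and the networks are shared per species, reordering the atoms (features and species together) leaves the energy unchanged.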
References
- [1] Behler, Jörg, and Michele Parrinello. “Generalized neural-network representation of high-dimensional potential-energy surfaces.” Physical Review Letters 98.14 (2007): 146401.
- [2] Artrith, Nongnuch, Björn Hiller, and Jörg Behler. “Neural network potentials for metals and oxides–First applications to copper clusters at zinc oxide.” Physica Status Solidi (b) 250.6 (2013): 1191-1203.
- jax_md._nn.behler_parrinello.radial_symmetry_functions(displacement_or_metric, species, etas, cutoff_distance)
Returns a function that computes radial symmetry functions.
- Parameters
displacement_or_metric – A function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as an [N_atoms, spatial_dimension] array and an [M_atoms, spatial_dimension] array, respectively.
species (Optional[Array]) – An [N_atoms] array that contains the species of each particle.
etas (Array) – List of radial symmetry function parameters that control the spatial extension.
cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.
- Return type
Callable[[Array], Array]
- Returns
A function that computes the radial symmetry functions from an input of shape [N_atoms, spatial_dimension] and returns an [N_atoms, N_etas * N_types] array, where N_etas is the number of eta parameters and N_types is the number of types of particles in the system.
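To make the output shape concrete, here is a minimal sketch of radial symmetry functions in plain JAX, using the Gaussian-times-cosine-cutoff form from references [1, 2]. It is not the library's implementation: `cutoff_fn` and `radial_features` are hypothetical names, free-space displacements are assumed, and the column ordering (eta-major) is a choice made for this example only.

```python
import jax.numpy as jnp

def cutoff_fn(dr, cutoff_distance):
    """Cosine cutoff: 1 at dr = 0, decaying smoothly to 0 at the cutoff."""
    smooth = 0.5 * (jnp.cos(jnp.pi * dr / cutoff_distance) + 1.0)
    return jnp.where(dr < cutoff_distance, smooth, 0.0)

def radial_features(R, species, etas, cutoff_distance, n_types):
    """Returns an [N_atoms, N_etas * N_types] array of radial features."""
    n = R.shape[0]
    not_self = ~jnp.eye(n, dtype=bool)
    dR = R[:, None, :] - R[None, :, :]                # [N, N, dim], free space
    dr2 = jnp.sum(dR ** 2, axis=-1)
    dr = jnp.sqrt(jnp.where(not_self, dr2, 1.0))      # dummy value on the diagonal
    # One Gaussian per eta, damped by the cutoff: [N_etas, N, N].
    pair = jnp.exp(-etas[:, None, None] * dr[None] ** 2) * cutoff_fn(dr, cutoff_distance)[None]
    pair = jnp.where(not_self[None], pair, 0.0)       # an atom is not its own neighbor
    # Sum neighbor contributions separately for each neighbor species.
    one_hot = (species[:, None] == jnp.arange(n_types)[None, :]).astype(R.dtype)
    out = jnp.einsum('eij,jt->iet', pair, one_hot)    # [N, N_etas, N_types]
    return out.reshape(n, -1)

R = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
species = jnp.array([0, 1, 0])
etas = jnp.array([0.5, 1.0])
feats = radial_features(R, species, etas, cutoff_distance=3.0, n_types=2)
```

With two eta values and two species, each atom gets four features. The features depend only on interatomic distances, so translating all positions leaves them unchanged.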
- jax_md._nn.behler_parrinello.radial_symmetry_functions_neighbor_list(displacement_or_metric, species, etas, cutoff_distance)
Returns a function that computes radial symmetry functions.
- Parameters
displacement_or_metric – A function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as an [N_atoms, spatial_dimension] array and an [M_atoms, spatial_dimension] array, respectively.
species (Array) – An [N_atoms] array that contains the species of each particle.
etas (Array) – List of radial symmetry function parameters that control the spatial extension.
cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.
- Return type
Callable[[Array, NeighborList], Array]
- Returns
A function that computes the radial symmetry functions from an input of shape [N_atoms, spatial_dimension] and returns an [N_atoms, N_etas * N_types] array, where N_etas is the number of eta parameters and N_types is the number of types of particles in the system.
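The sketch below shows how the same radial features could be accumulated from a dense neighbor index array instead of all pairs. It is an illustration under stated assumptions, not the library's code: `radial_features_dense` is a hypothetical name, `idx` is assumed to be an [N_atoms, max_neighbors] integer array in which the value N_atoms marks an empty slot, and displacements are taken in free space.

```python
import jax.numpy as jnp

def radial_features_dense(R, species, idx, etas, cutoff_distance, n_types):
    """Radial features from a dense neighbor index array of shape [N, K]."""
    n = R.shape[0]
    mask = idx < n                                   # True for real neighbors
    safe_idx = jnp.where(mask, idx, 0)               # any valid index for padding
    dR = R[:, None, :] - R[safe_idx]                 # [N, K, dim]
    dr2 = jnp.sum(dR ** 2, axis=-1)
    dr = jnp.sqrt(jnp.where(mask, dr2, 1.0))         # dummy distance for padding
    fc = jnp.where(dr < cutoff_distance,
                   0.5 * (jnp.cos(jnp.pi * dr / cutoff_distance) + 1.0), 0.0)
    pair = jnp.exp(-etas[:, None, None] * dr[None] ** 2) * fc[None]
    pair = jnp.where(mask[None], pair, 0.0)          # [N_etas, N, K], padding zeroed
    # One-hot species of each neighbor slot: [N, K, N_types].
    one_hot = (species[safe_idx][..., None] == jnp.arange(n_types)).astype(R.dtype)
    out = jnp.einsum('eik,ikt->iet', pair, one_hot)  # [N, N_etas, N_types]
    return out.reshape(n, -1)

R = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
species = jnp.array([0, 1, 0])
etas = jnp.array([0.5, 1.0])
# Each atom lists the other two; the last slot is padding (value N_atoms = 3).
idx = jnp.array([[1, 2, 3], [0, 2, 3], [0, 1, 3]])
feats_nl = radial_features_dense(R, species, idx, etas, 3.0, n_types=2)
```

The cost scales with N_atoms * max_neighbors rather than N_atoms squared, which is the usual reason to prefer the neighbor-list variant for large systems.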
- jax_md._nn.behler_parrinello.angular_symmetry_functions(displacement, species, etas, lambdas, zetas, cutoff_distance)
Returns a function that computes angular symmetry functions.
- Parameters
displacement (Callable[[Array, Array], Array]) – A function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as an [N_atoms, spatial_dimension] array and an [M_atoms, spatial_dimension] array, respectively.
species (Array) – An [N_atoms] array that contains the species of each particle.
etas – Parameters of the angular symmetry function that control the spatial extension.
lambdas – Parameters of the angular symmetry function that multiply the cosine of the three-body angle (λ in references [1, 2]).
zetas – Parameters of the angular symmetry function that control the angular resolution (ζ in references [1, 2]).
cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.
- Return type
Callable[[Array], Array]
- Returns
A function that computes the angular symmetry functions from an input of shape [N_atoms, spatial_dimension] and returns an [N_atoms, N_types * (N_types + 1) / 2] array, where N_types is the number of types of particles in the system.
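As an illustration of what a single angular symmetry function measures, the sketch below evaluates the three-body form from reference [1] for one (eta, lambda, zeta) triple on a single-species system. It is not the library's implementation: `cutoff_fn` and `angular_feature` are hypothetical names, free-space displacements are assumed, and the sum runs over ordered neighbor pairs (j, k), so each unordered pair is counted twice.

```python
import jax.numpy as jnp

def cutoff_fn(dr, cutoff_distance):
    """Cosine cutoff: 1 at dr = 0, decaying smoothly to 0 at the cutoff."""
    smooth = 0.5 * (jnp.cos(jnp.pi * dr / cutoff_distance) + 1.0)
    return jnp.where(dr < cutoff_distance, smooth, 0.0)

def angular_feature(R, eta, lam, zeta, cutoff_distance):
    """Returns one angular feature per atom, shape [N_atoms]."""
    n = R.shape[0]
    eye = jnp.eye(n, dtype=bool)
    dR = R[None, :, :] - R[:, None, :]                 # dR[i, j] = R[j] - R[i]
    dr2 = jnp.sum(dR ** 2, axis=-1)
    dr = jnp.sqrt(jnp.where(eye, 1.0, dr2))            # dummy value on the diagonal
    fc = jnp.where(eye, 0.0, cutoff_fn(dr, cutoff_distance))
    # Cosine of the angle at atom i between neighbors j and k: [N, N, N].
    cos = jnp.einsum('ijd,ikd->ijk', dR, dR) / (dr[:, :, None] * dr[:, None, :])
    cos = jnp.clip(cos, -1.0, 1.0)                     # guard against round-off
    radial = jnp.exp(-eta * (dr2[:, :, None] + dr2[:, None, :] + dr2[None, :, :]))
    # The product of cutoffs vanishes whenever two of i, j, k coincide.
    cut = fc[:, :, None] * fc[:, None, :] * fc[None, :, :]
    terms = (1.0 + lam * cos) ** zeta * radial * cut
    return 2.0 ** (1.0 - zeta) * jnp.sum(terms, axis=(1, 2))

R = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
g = angular_feature(R, eta=0.1, lam=1.0, zeta=1.0, cutoff_distance=3.0)
```

For atom 0 the two neighbors form a right angle, so the cosine term contributes a factor of 1 and the value is set entirely by the three distances and the cutoff.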
- jax_md._nn.behler_parrinello.angular_symmetry_functions_neighbor_list(displacement, species, etas, lambdas, zetas, cutoff_distance)
Returns a function that computes angular symmetry functions.
- Parameters
displacement (Callable[[Array, Array], Array]) – A function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as an [N_atoms, spatial_dimension] array and an [M_atoms, spatial_dimension] array, respectively.
species (Array) – An [N_atoms] array that contains the species of each particle.
etas – Parameters of the angular symmetry function that control the spatial extension.
lambdas – Parameters of the angular symmetry function that multiply the cosine of the three-body angle (λ in references [1, 2]).
zetas – Parameters of the angular symmetry function that control the angular resolution (ζ in references [1, 2]).
cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.
- Return type
Callable[[Array, NeighborList], Array]
- Returns
A function that computes the angular symmetry functions from an input of shape [N_atoms, spatial_dimension] and returns an [N_atoms, N_types * (N_types + 1) / 2] array, where N_types is the number of types of particles in the system.
Graph Neural Networks
JAX MD also contains primitives for constructing graph neural networks. These primitives are based on (and are one-to-one with) the excellent Jraph library (www.github.com/deepmind/jraph). Compared to Jraph, these primitives are adapted to work with Dense neighbor lists. However, it is also possible to use Jraph’s primitives directly in combination with Sparse neighbor lists.
Our implementation here is based on the outstanding GraphNets library by DeepMind at www.github.com/deepmind/graph_nets. This implementation was also heavily influenced by work done by Thomas Keck. We implement a subset of the functionality from the graph nets library to be compatible with jax-md states and neighbor lists, end-to-end jit compilation, and easy batching. Graphs are described by node states, edge states, a global state, and outgoing / incoming edges.
We provide two components:
- A GraphIndependent layer that applies a neural network separately to the node states, the edge states, and the globals. This is often used as an encoding or decoding step.
- A GraphNetwork layer that transforms the nodes, edges, and globals using neural networks following Battaglia et al. (https://arxiv.org/abs/1806.01261). Here, we use sum-message-aggregation.
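The first component can be pictured as a map over the three kinds of state. The sketch below is a plain-JAX illustration, not the library's class: `Graph` is a hypothetical stand-in whose fields mirror the graph container documented on this page, `map_features` is a hypothetical helper, and fixed weight matrices stand in for trained networks.

```python
from typing import NamedTuple
import jax.numpy as jnp

class Graph(NamedTuple):
    """Hypothetical stand-in for a graph container."""
    nodes: jnp.ndarray      # [N_nodes, node_dimension]
    edges: jnp.ndarray      # [N_nodes, max_degree, edge_dimension]
    globals: jnp.ndarray    # [global_dimension]
    edge_idx: jnp.ndarray   # [N_nodes, max_degree]

def map_features(graph, edge_fn, node_fn, global_fn):
    """Applies one function to each kind of state; connectivity is untouched."""
    return graph._replace(nodes=node_fn(graph.nodes),
                          edges=edge_fn(graph.edges),
                          globals=global_fn(graph.globals))

graph = Graph(nodes=jnp.ones((3, 2)),
              edges=jnp.ones((3, 2, 4)),
              globals=jnp.zeros((5,)),
              edge_idx=jnp.array([[1, 2], [0, 3], [3, 3]]))
w_node = jnp.full((2, 8), 0.1)   # stand-in for a learned node encoder
w_edge = jnp.full((4, 8), 0.1)   # stand-in for a learned edge encoder
encoded = map_features(graph,
                       edge_fn=lambda e: jnp.tanh(e @ w_edge),
                       node_fn=lambda x: jnp.tanh(x @ w_node),
                       global_fn=lambda g: g)
```

No information flows between nodes and edges in this step, which is why it suits encoding raw features into a latent space or decoding latent states back out.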
- class jax_md.nn.GraphMapFeatures(edge_fn, node_fn, global_fn)
Applies functions independently to the nodes, edges, and global states.
- class jax_md.nn.GraphNetwork(edge_fn, node_fn, global_fn)
Implementation of a Graph Network.
See https://arxiv.org/abs/1806.01261 for more details.
- class jax_md.nn.GraphsTuple(nodes, edges, globals, edge_idx)
A struct containing graph data.
- nodes
For a graph with N_nodes nodes, this is an [N_nodes, node_dimension] array containing the state of each node in the graph.
- Type
jax.Array
- edges
For a graph whose degree is bounded by max_degree, this is an [N_nodes, max_degree, edge_dimension] array. Here edges[i, j] is the state of the outgoing edge from node i to node edge_idx[i, j].
- Type
jax.Array
- globals
An array of shape [global_dimension].
- Type
jax.Array
- edge_idx
An integer array of shape [N_nodes, max_degree] where edge_idx[i, j] is the id of the node that receives the j-th outgoing edge from node i. Empty entries (that don’t contain an edge) are denoted by edge_idx[i, j] == N_nodes.
- Type
jax.Array
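The dense layout and the N_nodes padding convention are easiest to see on a tiny graph. The sketch below builds the arrays by hand and performs one toy message-passing step with sum aggregation. It is an illustration only, not the GraphNetwork implementation: the edge update (sender plus receiver state) is an arbitrary choice for the example.

```python
import jax.numpy as jnp

n_nodes, max_degree = 3, 2
nodes = jnp.array([[1.0], [2.0], [4.0]])          # [N_nodes, node_dimension]
# Node 0 -> nodes 1 and 2; node 1 -> node 0; node 2 has no outgoing edges.
# The value N_nodes (= 3) marks an empty slot.
edge_idx = jnp.array([[1, 2], [0, 3], [3, 3]])    # [N_nodes, max_degree]
mask = edge_idx < n_nodes                         # True where an edge exists

# Append a zero row so that the padding index N_nodes gathers zeros.
padded = jnp.concatenate([nodes, jnp.zeros((1, nodes.shape[1]))], axis=0)
receivers = padded[edge_idx]                      # [N_nodes, max_degree, node_dim]
senders = jnp.broadcast_to(nodes[:, None, :], receivers.shape)

# Toy edge update, with empty slots explicitly zeroed out.
edges = jnp.where(mask[..., None], senders + receivers, 0.0)
# Sum aggregation over each node's edge slots.
aggregated = jnp.sum(edges, axis=1)               # [N_nodes, node_dim]
```

Masking before the sum matters: without it, padded slots would still contribute the sender's own state to the aggregate.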