Neural Network Primitives

Behler-Parrinello Networks

JAX MD contains neural network primitives for a common class of fixed-feature neural network architectures known as Behler-Parrinello Neural Networks (BP-NN) [1, 2]. An energy function using this architecture can be found in energy.py.

The BP-NN architecture uses a relatively simple, fully connected neural network to predict the local energy of each atom; the total energy is then the sum of these local energies. Atoms of the same type share the same network. Each atomic network takes as input hand-crafted features called symmetry functions, which come in two kinds: radial and angular. Radial symmetry functions encode two-body interactions of the central atom, whereas angular symmetry functions encode three-body interactions. Below we implement radial and angular symmetry functions for an arbitrary number of atom types (note that most applications of BP-NN limit their systems to one to three atom types). We also provide a convenience wrapper that returns radial and angular symmetry functions with parameters that should work reasonably well for most systems (the symmetry functions are taken from reference [2]). Please see references [1, 2] for details about how the BP-NN works.
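To make the "one network per species, total energy is the sum of local energies" structure concrete, here is a minimal pure-JAX sketch. It is not JAX MD's implementation (that lives in energy.py and may use a different network library); the helpers `init_mlp`, `mlp` and `bp_energy`, the tanh activation, and the layer sizes are hypothetical choices made for illustration, and the symmetry-function features are replaced by random numbers.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: random weights for a small fully connected network."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp(params, x):
    """Maps one atom's feature vector to a scalar local energy."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return (x @ w + b)[0]


def bp_energy(per_species_params, features, species):
    """Total energy = sum of local energies, each predicted by its species' network.

    features: [N_atoms, N_features] symmetry-function values.
    species:  [N_atoms] integer labels in [0, len(per_species_params)).
    """
    energy = 0.0
    for s, params in enumerate(per_species_params):
        local = jax.vmap(lambda f: mlp(params, f))(features)  # every atom through net s
        energy = energy + jnp.sum(jnp.where(species == s, local, 0.0))  # keep species s only
    return energy


key = jax.random.PRNGKey(0)
k0, k1, k2 = jax.random.split(key, 3)
n_features = 4
params = [init_mlp(k0, [n_features, 8, 1]), init_mlp(k1, [n_features, 8, 1])]
features = jax.random.normal(k2, (5, n_features))  # stand-in for symmetry functions
species = jnp.array([0, 1, 0, 0, 1])
E = bp_energy(params, features, species)
```

Because the energy is a plain sum over atoms, reordering the atoms (together with their species labels) leaves it unchanged.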

References

[1] Behler, Jörg, and Michele Parrinello. “Generalized neural-network representation of high-dimensional potential-energy surfaces.” Physical Review Letters 98.14 (2007): 146401.

[2] Artrith, Nongnuch, Björn Hiller, and Jörg Behler. “Neural network potentials for metals and oxides – First applications to copper clusters at zinc oxide.” Physica Status Solidi (b) 250.6 (2013): 1191-1203.

jax_md._nn.behler_parrinello.radial_symmetry_functions(displacement_or_metric, species, etas, cutoff_distance)

Returns a function that computes radial symmetry functions.

Parameters
  • displacement_or_metric – A displacement function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as [N_atoms, spatial_dimension] and [M_atoms, spatial_dimension] arrays, respectively (as the parameter name indicates, a metric function may be supplied instead).

  • species (Optional[Array]) – An [N_atoms] array that contains the species of each particle.

  • etas (Array) – Radial symmetry function parameters (η) that control the spatial extension.

  • cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.

Return type

Callable[[Array], Array]

Returns

A function that computes the radial symmetry functions from an input array of shape [N_atoms, spatial_dimension] and returns an array of shape [N_atoms, N_etas * N_types], where N_etas is the number of eta parameters and N_types is the number of particle types in the system.
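The sketch below computes radial features of the common Behler form G_i(η) = Σ_j exp(-η r_ij²) f_c(r_ij), with the shift R_s set to zero and the cosine cutoff f_c(r) = ½[cos(πr/r_c) + 1] for r < r_c, summed separately over each neighbor species. It assumes free (non-periodic) boundaries and an all-pairs O(N²) computation; `cosine_cutoff`, `radial_sketch`, and the type-major column ordering are assumptions for illustration, not the library's code.

```python
import jax.numpy as jnp


def cosine_cutoff(r, r_c):
    """f_c(r) = 0.5 * (cos(pi r / r_c) + 1) for 0 < r < r_c, else 0.
    Zero at r == 0 so that an atom never counts itself as a neighbor."""
    inside = (r < r_c) & (r > 1e-7)
    return jnp.where(inside, 0.5 * (jnp.cos(jnp.pi * r / r_c) + 1.0), 0.0)


def radial_sketch(positions, species, etas, r_c, n_types):
    """Returns [N_atoms, n_types * N_etas] radial features (type-major columns)."""
    dR = positions[:, None, :] - positions[None, :, :]     # [N, N, dim], free space
    # sqrt at r == 0 (the diagonal) produces NaN gradients; a differentiable
    # version would use a safe norm.
    r = jnp.sqrt(jnp.sum(dR ** 2, axis=-1))                # [N, N]
    # g[i, j, e] = exp(-eta_e * r_ij^2) * f_c(r_ij)
    g = jnp.exp(-etas * r[..., None] ** 2) * cosine_cutoff(r, r_c)[..., None]
    onehot = jnp.eye(n_types)[species]                     # [N, n_types]
    per_type = jnp.einsum('ije,jt->ite', g, onehot)        # sum neighbors j of each type
    return per_type.reshape(positions.shape[0], -1)


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.5, 0.0]])
species = jnp.array([0, 1, 1])
etas = jnp.array([0.5, 1.0])
G = radial_sketch(positions, species, etas, r_c=3.0, n_types=2)
print(G.shape)  # (3, 4)
```

Here atom 0 has no type-0 neighbors besides itself, so its first two columns are zero.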

jax_md._nn.behler_parrinello.radial_symmetry_functions_neighbor_list(displacement_or_metric, species, etas, cutoff_distance)

Returns a function that computes radial symmetry functions.

Parameters
  • displacement_or_metric – A displacement function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as [N_atoms, spatial_dimension] and [M_atoms, spatial_dimension] arrays, respectively (as the parameter name indicates, a metric function may be supplied instead).

  • species (Array) – An [N_atoms] array that contains the species of each particle.

  • etas (Array) – Radial symmetry function parameters (η) that control the spatial extension.

  • cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.

Return type

Callable[[Array, NeighborList], Array]

Returns

A function that takes positions of shape [N_atoms, spatial_dimension] together with a neighbor list and returns the radial symmetry functions as an array of shape [N_atoms, N_etas * N_types], where N_etas is the number of eta parameters and N_types is the number of particle types in the system.
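With a neighbor list, each atom visits only a fixed number of candidate neighbor slots instead of all N atoms. The sketch below assumes a dense [N_atoms, max_neighbors] index array padded with the sentinel value N_atoms; `dense_neighbor_idx` is a hypothetical brute-force stand-in for a real neighbor list, and `radial_sketch_nl` is an illustration, not the library function. The radial form and cosine cutoff are the standard Behler ones with R_s = 0.

```python
import jax.numpy as jnp


def cosine_cutoff(r, r_c):
    """Cosine cutoff, zero outside r_c and at r == 0."""
    inside = (r < r_c) & (r > 1e-7)
    return jnp.where(inside, 0.5 * (jnp.cos(jnp.pi * r / r_c) + 1.0), 0.0)


def dense_neighbor_idx(positions, r_c, max_neighbors):
    """Hypothetical brute-force neighbor list: [N, max_neighbors] indices padded
    with N. max_neighbors must be at least the largest neighbor count, otherwise
    neighbors are silently dropped."""
    n = positions.shape[0]
    r = jnp.linalg.norm(positions[:, None] - positions[None, :], axis=-1)
    is_nbr = (r < r_c) & ~jnp.eye(n, dtype=bool)
    candidates = jnp.where(is_nbr, jnp.arange(n)[None, :], n)
    return jnp.sort(candidates, axis=1)[:, :max_neighbors]  # real indices first


def radial_sketch_nl(positions, species, etas, r_c, n_types, idx):
    """Radial features computed only over the K neighbor slots of each atom."""
    n = positions.shape[0]
    mask = idx < n                                       # [N, K] real neighbor slots
    safe = jnp.where(mask, idx, 0)                       # any in-range index for padding
    r = jnp.linalg.norm(positions[:, None, :] - positions[safe], axis=-1)  # [N, K]
    g = jnp.exp(-etas * r[..., None] ** 2) * cosine_cutoff(r, r_c)[..., None]
    g = jnp.where(mask[..., None], g, 0.0)               # padded slots contribute nothing
    onehot = jnp.eye(n_types)[species[safe]]             # [N, K, n_types]
    return jnp.einsum('ike,ikt->ite', g, onehot).reshape(n, -1)


positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.5], [5.0, 5.0]])
species = jnp.array([0, 1, 1, 0])
etas = jnp.array([0.5, 1.0])
idx = dense_neighbor_idx(positions, r_c=3.0, max_neighbors=3)
G_nl = radial_sketch_nl(positions, species, etas, 3.0, 2, idx)
```

The isolated atom 3 gets a row of padding indices and therefore all-zero features; the result matches an all-pairs computation as long as max_neighbors is large enough.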

jax_md._nn.behler_parrinello.angular_symmetry_functions(displacement, species, etas, lambdas, zetas, cutoff_distance)

Returns a function that computes angular symmetry functions.

Parameters
  • displacement (Callable[[Array, Array], Array]) – A function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as [N_atoms, spatial_dimension] and [M_atoms, spatial_dimension] arrays, respectively.

  • species (Array) – An [N_atoms] array that contains the species of each particle.

  • etas – Angular symmetry function parameters (η) that control the spatial extension.

  • lambdas – Angular symmetry function parameters (λ, typically ±1) that determine whether the angular term peaks at θ = 0 or at θ = π.

  • zetas – Angular symmetry function parameters (ζ) that control the angular resolution; larger values give narrower angular peaks.

  • cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.

Return type

Callable[[Array], Array]

Returns

A function that computes the angular symmetry functions from an input array of shape [N_atoms, spatial_dimension] and returns an array of shape [N_atoms, N_types * (N_types + 1) / 2], where N_types is the number of particle types in the system.
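The angular term can be sketched directly from the standard Behler form G_i = 2^(1-ζ) Σ_{j,k} (1 + λ cos θ_ijk)^ζ exp(-η(r_ij² + r_ik² + r_jk²)) f_c(r_ij) f_c(r_ik) f_c(r_jk), where θ_ijk is the angle at atom i. This sketch assumes a single (η, λ, ζ) triple, a single species, free boundaries, and a sum over ordered pairs (j, k), so each unordered pair is counted twice; the library's exact pair-counting convention and its split by neighbor-species pairs are not reproduced. `angular_sketch` and `cosine_cutoff` are hypothetical names.

```python
import jax.numpy as jnp


def cosine_cutoff(r, r_c):
    """Cosine cutoff, zero outside r_c and at r == 0."""
    inside = (r < r_c) & (r > 1e-7)
    return jnp.where(inside, 0.5 * (jnp.cos(jnp.pi * r / r_c) + 1.0), 0.0)


def angular_sketch(positions, eta, lam, zeta, r_c):
    """Single-species angular features for one (eta, lam, zeta) triple -> [N_atoms]."""
    dR = positions[None, :, :] - positions[:, None, :]    # dR[i, j] = R_j - R_i
    r = jnp.linalg.norm(dR, axis=-1)                      # [N, N]
    fc = cosine_cutoff(r, r_c)                            # zero on the diagonal
    # cos(theta_ijk) between R_j - R_i and R_k - R_i; guard the r = 0 entries.
    dots = jnp.einsum('ijd,ikd->ijk', dR, dR)
    norms = r[:, :, None] * r[:, None, :]
    cos = jnp.where(norms > 0, dots / jnp.where(norms > 0, norms, 1.0), 0.0)
    r2 = r ** 2
    # Axes are [i, j, k]: r2[:, :, None] = r_ij^2, r2[:, None, :] = r_ik^2, r2[None] = r_jk^2.
    radial = jnp.exp(-eta * (r2[:, :, None] + r2[:, None, :] + r2[None, :, :]))
    cut = fc[:, :, None] * fc[:, None, :] * fc[None, :, :]
    # Terms with j == i, k == i or j == k vanish through the cutoff product.
    base = jnp.maximum(1.0 + lam * cos, 0.0)              # clamp round-off when lam = -1
    terms = base ** zeta * radial * cut
    return 2.0 ** (1.0 - zeta) * jnp.sum(terms, axis=(1, 2))


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
G = angular_sketch(positions, eta=0.1, lam=1.0, zeta=2.0, r_c=3.0)
```

For this right-angled triangle, atom 0 sees cos θ = 0, so its feature reduces to exp(-4η) f_c(1)² f_c(√2) after the two ordered pairs and the 2^(1-ζ) prefactor combine.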

jax_md._nn.behler_parrinello.angular_symmetry_functions_neighbor_list(displacement, species, etas, lambdas, zetas, cutoff_distance)

Returns a function that computes angular symmetry functions.

Parameters
  • displacement (Callable[[Array, Array], Array]) – A function that produces an [N_atoms, M_atoms, spatial_dimension] array of particle displacements from particle positions specified as [N_atoms, spatial_dimension] and [M_atoms, spatial_dimension] arrays, respectively.

  • species (Array) – An [N_atoms] array that contains the species of each particle.

  • etas – Angular symmetry function parameters (η) that control the spatial extension.

  • lambdas – Angular symmetry function parameters (λ, typically ±1) that determine whether the angular term peaks at θ = 0 or at θ = π.

  • zetas – Angular symmetry function parameters (ζ) that control the angular resolution; larger values give narrower angular peaks.

  • cutoff_distance (float) – Neighbors whose distance is larger than cutoff_distance do not contribute to each other's symmetry functions. The contribution of a neighbor to the symmetry function and its derivative goes to zero at this distance.

Return type

Callable[[Array, NeighborList], Array]

Returns

A function that takes positions of shape [N_atoms, spatial_dimension] together with a neighbor list and returns the angular symmetry functions as an array of shape [N_atoms, N_types * (N_types + 1) / 2], where N_types is the number of particle types in the system.
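The N_types * (N_types + 1) / 2 factor counts unordered pairs of neighbor species: an angular term with a type-a and a type-b neighbor is the same as one with b and a. One way to give each unordered pair its own column is sketched below; `pair_column` is a hypothetical helper, and its upper-triangular ordering is an assumption, not necessarily the ordering the library uses.

```python
def pair_column(a, b, n_types):
    """Hypothetical column index for the unordered species pair {a, b}.

    Pairs are enumerated in upper-triangular, row-major order:
    (0,0), (0,1), ..., (0,n-1), (1,1), (1,2), ..., (n-1,n-1).
    """
    a, b = min(a, b), max(a, b)
    # Rows 0..a-1 hold n + (n - 1) + ... + (n - a + 1) = a*n - a*(a-1)/2 pairs.
    return a * n_types - a * (a - 1) // 2 + (b - a)


n_types = 3
columns = {(a, b): pair_column(a, b, n_types) for a in range(n_types) for b in range(n_types)}
print(sorted(set(columns.values())))  # [0, 1, 2, 3, 4, 5]
```

With three species this yields the six columns predicted by 3 * 4 / 2, and swapping a and b never changes the column.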

Graph Neural Networks

JAX MD also contains primitives for constructing graph neural networks. These primitives are based on (and are one-to-one with) the excellent Jraph library (www.github.com/deepmind/jraph). Unlike Jraph's, these primitives are adapted to work with Dense neighbor lists. Jraph's primitives can also be used directly in combination with Sparse neighbor lists.

Our implementation is based on the outstanding GraphNets library by DeepMind (www.github.com/deepmind/graph_nets). It was also heavily influenced by work done by Thomas Keck. We implement a subset of the functionality of the GraphNets library so that it is compatible with JAX MD states and neighbor lists, end-to-end jit compilation, and easy batching. Graphs are described by node states, edge states, a global state, and outgoing / incoming edges.

We provide two components:

  1. A GraphMapFeatures layer (the counterpart of GraphIndependent in GraphNets) that applies a neural network separately to the node states, the edge states, and the globals. This is often used as an encoding or decoding step.

  2. A GraphNetwork layer that transforms the nodes, edges, and globals using neural networks following Battaglia et al. (https://arxiv.org/abs/1806.01261). Messages are aggregated by summation.
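A single GraphNetwork-style update with sum aggregation can be sketched on the dense layout described by GraphsTuple. Assumptions: the `Graph` named tuple only mirrors the documented GraphsTuple fields, the `edge_fn` / `node_fn` / `global_fn` call signatures are hypothetical, and each node aggregates the edges stored in its own row (its outgoing edges), with padded slots (edge_idx == N_nodes) masked to zero. This is an illustration, not JAX MD's GraphNetwork.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Graph(NamedTuple):
    """Hypothetical stand-in with the same fields as GraphsTuple."""
    nodes: jax.Array     # [N_nodes, node_dim]
    edges: jax.Array     # [N_nodes, max_degree, edge_dim]
    globals: jax.Array   # [global_dim]
    edge_idx: jax.Array  # [N_nodes, max_degree], padded with N_nodes


def graph_network_step(g, edge_fn, node_fn, global_fn):
    n = g.nodes.shape[0]
    valid = g.edge_idx < n                                    # [N, max_degree]
    mask = valid[..., None].astype(g.edges.dtype)
    # Sender of edge (i, j) is node i; receiver is node edge_idx[i, j].
    senders = jnp.broadcast_to(g.nodes[:, None, :], g.edge_idx.shape + g.nodes.shape[1:])
    receivers = g.nodes[jnp.where(valid, g.edge_idx, 0)]      # padded slots read node 0
    # 1) Edge update; padded slots are zeroed so they never contribute.
    edges = edge_fn(g.edges, senders, receivers, g.globals) * mask
    # 2) Node update from the sum of the edges stored in each node's row.
    nodes = node_fn(g.nodes, jnp.sum(edges, axis=1), g.globals)
    # 3) Global update from summed nodes and summed edges.
    globals_ = global_fn(jnp.sum(nodes, axis=0), jnp.sum(edges, axis=(0, 1)), g.globals)
    return Graph(nodes, edges, globals_, g.edge_idx)


g = Graph(
    nodes=jnp.array([[1.0], [2.0], [4.0]]),
    edges=jnp.ones((3, 2, 1)),
    globals=jnp.zeros(1),
    edge_idx=jnp.array([[1, 3], [0, 2], [1, 3]]),  # path graph 0-1-2; 3 = padding
)
out = graph_network_step(
    g,
    edge_fn=lambda e, s, r, u: e + s - r,
    node_fn=lambda v, agg, u: v + agg,
    global_fn=lambda node_sum, edge_sum, u: u + node_sum + edge_sum,
)
# out.nodes[:, 0] holds 1, 3, 7 and out.globals holds 15.
```

Summation keeps the update invariant to the order of edge slots within a row, which is why the padded, fixed-width layout works with jit.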

class jax_md.nn.GraphMapFeatures(edge_fn, node_fn, global_fn)

Applies functions independently to the nodes, edges, and global states.

Return type

Callable[[GraphsTuple], GraphsTuple]
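Applying functions "independently" means no information flows between nodes, edges, and globals. A minimal sketch of that behavior, using a hypothetical `Graph` named tuple with the GraphsTuple field names, a hypothetical `map_features` helper, and a random linear node encoder, is:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Graph(NamedTuple):
    """Hypothetical stand-in with the same fields as GraphsTuple."""
    nodes: jax.Array
    edges: jax.Array
    globals: jax.Array
    edge_idx: jax.Array


def map_features(g, edge_fn=None, node_fn=None, global_fn=None):
    """Applies each function to its own component only; None leaves it unchanged."""
    return g._replace(
        edges=edge_fn(g.edges) if edge_fn is not None else g.edges,
        nodes=node_fn(g.nodes) if node_fn is not None else g.nodes,
        globals=global_fn(g.globals) if global_fn is not None else g.globals,
    )


# Example: encode 2-d node features into 8-d with a random linear map.
w = jax.random.normal(jax.random.PRNGKey(0), (2, 8))
g = Graph(nodes=jnp.ones((3, 2)), edges=jnp.ones((3, 2, 1)),
          globals=jnp.zeros(1), edge_idx=jnp.array([[1, 3], [0, 2], [1, 3]]))
encoded = map_features(g, node_fn=lambda x: x @ w)
```

In this sketch an edge function would also run on padded edge slots, so their values carry no meaning and later steps must keep masking them.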

class jax_md.nn.GraphNetwork(edge_fn, node_fn, global_fn)

Implementation of a Graph Network.

See https://arxiv.org/abs/1806.01261 for more details.

class jax_md.nn.GraphsTuple(nodes, edges, globals, edge_idx)

A struct containing graph data.

nodes

For a graph with N_nodes, this is an [N_nodes, node_dimension] array containing the state of each node in the graph.

Type

jax.Array

edges

For a graph whose degree is bounded by max_degree, this is an [N_nodes, max_degree, edge_dimension] array. Here edges[i, j] is the state of the outgoing edge from node i to node edge_idx[i, j].

Type

jax.Array

globals

An array of shape [global_dimension].

Type

jax.Array

edge_idx

An integer array of shape [N_nodes, max_degree] where edge_idx[i, j] is the id of the node that receives the j-th outgoing edge from node i. Empty entries (slots that contain no edge) are denoted by edge_idx[i, j] == N_nodes.

Type

jax.Array
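Because empty slots are marked with N_nodes, converting this dense layout to flat sender/receiver lists (the edge representation used by Jraph's GraphsTuple) amounts to dropping those slots. `dense_to_sparse` is a hypothetical helper written for illustration; boolean-mask indexing produces data-dependent shapes, so it runs eagerly but cannot be traced by jax.jit.

```python
import jax.numpy as jnp


def dense_to_sparse(edge_idx, edges):
    """Hypothetical helper: turns the dense [N_nodes, max_degree] layout into flat
    senders / receivers / edge arrays, dropping padded slots (edge_idx == N_nodes)."""
    n, k = edge_idx.shape
    senders = jnp.repeat(jnp.arange(n), k)   # edge (i, j) is sent by node i
    receivers = edge_idx.reshape(-1)         # ... and received by node edge_idx[i, j]
    keep = receivers < n                     # data-dependent: not jit-compatible
    return senders[keep], receivers[keep], edges.reshape(n * k, -1)[keep]


edge_idx = jnp.array([[1, 3], [0, 2], [1, 3]])   # path graph 0-1-2, 3 = padding
edges = jnp.arange(6.0).reshape(3, 2, 1)
senders, receivers, flat_edges = dense_to_sparse(edge_idx, edges)
```

Under jit one would instead keep fixed-size arrays and carry a mask, which is the trade-off the dense layout makes explicit.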