# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neural Network Primitives."""
from typing import Callable, Tuple
from flax import nnx
import jax
from jax import vmap
import jax.numpy as jnp
from jax_md import dataclasses, partition
from jax_md import util as jmd_util
from functools import partial
import jraph
from ._nn import behler_parrinello
from ._nn import nequip
from ._nn import gnome
from ._nn import util
# Typing
Array = jmd_util.Array
ActivationFn = Callable[[Array], Array]
DEFAULT_KERNEL_INIT = jax.nn.initializers.variance_scaling(
1.0, 'fan_avg', 'truncated_normal'
)
DEFAULT_BIAS_INIT = jax.nn.initializers.zeros
class MLP(nnx.Module):
"""Multi-layer perceptron with configurable activation."""
def __init__(
self,
in_features: int,
output_sizes: Tuple[int, ...],
*,
rngs: nnx.Rngs,
activation: ActivationFn = jax.nn.relu,
kernel_init: Callable = DEFAULT_KERNEL_INIT,
bias_init: Callable = DEFAULT_BIAS_INIT,
use_bias: bool = True,
activate_final: bool = True,
):
self.activation = activation
self.activate_final = activate_final
sizes = (in_features,) + tuple(output_sizes)
self.num_layers = len(sizes) - 1
for i in range(self.num_layers):
setattr(
self,
f'layers_{i}',
nnx.Linear(
sizes[i],
sizes[i + 1],
use_bias=use_bias,
kernel_init=kernel_init,
bias_init=bias_init,
rngs=rngs,
),
)
def __call__(self, x: Array) -> Array:
for i in range(self.num_layers):
x = getattr(self, f'layers_{i}')(x)
if self.activate_final or i < self.num_layers - 1:
x = self.activation(x)
return x
# TO BE DELETED BELOW:
# Graph neural network primitives
"""
Our implementation here is based off the outstanding GraphNets library by
DeepMind at, www.github.com/deepmind/graph_nets. This implementation was also
heavily influenced by work done by Thomas Keck. We implement a subset of the
functionality from the graph nets library to be compatible with jax-md
states and neighbor lists, end-to-end jit compilation, and easy batching.
Graphs are described by node states, edge states, a global state, and
outgoing / incoming edges.
We provide two components:
1) A GraphIndependent layer that applies a neural network separately to the
node states, the edge states, and the globals. This is often used as an
encoding or decoding step.
2) A GraphNetwork layer that transforms the nodes, edges, and globals using
neural networks following Battaglia et al. (). Here, we use
sum-message-aggregation.
The graphs network components implemented here implement identical functions
to the DeepMind library. However, to be compatible with jax-md, there are
significant differences in the graph layout used here to the reference
implementation. See `GraphsTuple` for details.
"""
[docs]
@dataclasses.dataclass
class GraphsTuple(object):
"""A struct containing graph data.
Attributes:
nodes: For a graph with `N_nodes`, this is an `[N_nodes, node_dimension]`
array containing the state of each node in the graph.
edges: For a graph whose degree is bounded by max_degree, this is an
`[N_nodes, max_degree, edge_dimension]`. Here `edges[i, j]` is the
state of the outgoing edge from node `i` to node `edge_idx[i, j]`.
globals: An array of shape `[global_dimension]`.
edge_idx: An integer array of shape `[N_nodes, max_degree]` where
`edge_idx[i, j]` is the id of the j-th outgoing edge from node `i`.
Empty entries (that don't contain an edge) are denoted by
`edge_idx[i, j] == N_nodes`.
"""
nodes: jnp.ndarray
edges: jnp.ndarray
globals: jnp.ndarray
edge_idx: jnp.ndarray
_replace = dataclasses.replace
def concatenate_graph_features(graphs: Tuple[GraphsTuple, ...]) -> GraphsTuple:
"""Given a list of GraphsTuple returns a new concatenated GraphsTuple.
Note that currently we do not check that the graphs have consistent edge
connectivity.
"""
graph = graphs[0]
return graph._replace(
nodes=jnp.concatenate([g.nodes for g in graphs], axis=-1),
edges=jnp.concatenate([g.edges for g in graphs], axis=-1),
globals=jnp.concatenate(
[g.globals for g in graphs], axis=-1
), # pytype: disable=missing-parameter
)
[docs]
def GraphMapFeatures(
edge_fn: Callable[[Array], Array],
node_fn: Callable[[Array], Array],
global_fn: Callable[[Array], Array],
) -> Callable[[GraphsTuple], GraphsTuple]:
"""Applies functions independently to the nodes, edges, and global states."""
identity = lambda x: x
_node_fn = vmap(node_fn) if node_fn is not None else identity
_edge_fn = vmap(vmap(edge_fn)) if edge_fn is not None else identity
_global_fn = global_fn if global_fn is not None else identity
def embed_fn(graph):
return graph._replace(
nodes=_node_fn(graph.nodes),
edges=_edge_fn(graph.edges),
globals=_global_fn(graph.globals),
)
return embed_fn
def apply_node_fn(
graph: GraphsTuple, node_fn: Callable[[Array, Array, Array, Array], Array]
) -> Array:
mask = graph.edge_idx < graph.nodes.shape[0]
mask = mask[:, :, jnp.newaxis]
if graph.edges is not None:
# TODO: Should we also have outgoing edges?
flat_edges = jnp.reshape(graph.edges, (-1, graph.edges.shape[-1]))
edge_idx = jnp.reshape(graph.edge_idx, (-1,))
incoming_edges = jax.ops.segment_sum(
flat_edges, edge_idx, graph.nodes.shape[0] + 1
)[:-1]
outgoing_edges = jnp.sum(graph.edges * mask, axis=1)
else:
incoming_edges = None
outgoing_edges = None
if graph.globals is not None:
_globals = jnp.broadcast_to(
graph.globals[jnp.newaxis, :], graph.nodes.shape[:1] + graph.globals.shape
)
else:
_globals = None
return node_fn(graph.nodes, incoming_edges, outgoing_edges, _globals)
def apply_edge_fn(
graph: GraphsTuple, edge_fn: Callable[[Array, Array, Array, Array], Array]
) -> Array:
if graph.nodes is not None:
incoming_nodes = graph.nodes[graph.edge_idx]
outgoing_nodes = jnp.broadcast_to(
graph.nodes[:, jnp.newaxis, :],
graph.edge_idx.shape + graph.nodes.shape[-1:],
)
else:
incoming_nodes = None
outgoing_nodes = None
if graph.globals is not None:
_globals = jnp.broadcast_to(
graph.globals[jnp.newaxis, jnp.newaxis, :],
graph.edge_idx.shape + graph.globals.shape,
)
else:
_globals = None
mask = graph.edge_idx < graph.nodes.shape[0]
mask = mask[:, :, jnp.newaxis]
return edge_fn(graph.edges, incoming_nodes, outgoing_nodes, _globals) * mask
def apply_global_fn(
graph: GraphsTuple, global_fn: Callable[[Array, Array, Array], Array]
) -> Array:
nodes = None if graph.nodes is None else jnp.sum(graph.nodes, axis=0)
if graph.edges is not None:
mask = graph.edge_idx < graph.nodes.shape[0]
mask = mask[:, :, jnp.newaxis]
edges = jnp.sum(graph.edges * mask, axis=(0, 1))
else:
edges = None
return global_fn(nodes, edges, graph.globals)
[docs]
class GraphNetwork:
"""Implementation of a Graph Network.
See https://arxiv.org/abs/1806.01261 for more details.
"""
def __init__(
self,
edge_fn: Callable[[Array, Array, Array, Array], Array],
node_fn: Callable[[Array, Array, Array, Array], Array],
global_fn: Callable[[Array, Array, Array], Array],
):
self.node_fn = (
None if node_fn is None else partial(apply_node_fn, node_fn=vmap(node_fn))
)
self.edge_fn = (
None
if edge_fn is None
else partial(apply_edge_fn, edge_fn=vmap(vmap(edge_fn)))
)
self.global_fn = (
None
if global_fn is None
else partial(apply_global_fn, global_fn=global_fn)
)
def __call__(self, graph: GraphsTuple) -> GraphsTuple:
if self.edge_fn is not None:
graph = graph._replace(edges=self.edge_fn(graph))
if self.node_fn is not None:
graph = graph._replace(nodes=self.node_fn(graph))
if self.global_fn is not None:
graph = graph._replace(globals=self.global_fn(graph))
return graph
# Prefab Networks
class GraphNetEncoder(nnx.Module):
"""Implements a Graph Neural Network for energy fitting.
Based on the network used in "Unveiling the predictive power of static
structure in glassy systems"; Bapst et al.
(https://www.nature.com/articles/s41567-020-0842-8). This network first
embeds edges, nodes, and global state. Then ``n_recurrences`` of GraphNetwork
layers are applied. Unlike in Bapst et al. this network does not include a
readout, which should be added separately depending on the application.
For example, when predicting particle mobilities, one would use a decoder
only on the node states while a model of energies would decode only the node
states.
"""
def __init__(
self,
in_node_features: int,
in_edge_features: int,
in_global_features: int,
n_recurrences: int,
mlp_sizes: Tuple[int, ...],
*,
rngs: nnx.Rngs,
activation: ActivationFn = jax.nn.relu,
kernel_init: Callable = DEFAULT_KERNEL_INIT,
bias_init: Callable = DEFAULT_BIAS_INIT,
format: partition.NeighborListFormat = partition.Dense,
):
self.n_recurrences = n_recurrences
self.format = format
kw = dict(
rngs=rngs,
activation=activation,
kernel_init=kernel_init,
bias_init=bias_init,
activate_final=True,
)
m = mlp_sizes[-1]
self.EdgeEncoder = MLP(in_edge_features, mlp_sizes, **kw)
self.NodeEncoder = MLP(in_node_features, mlp_sizes, **kw)
self.GlobalEncoder = MLP(in_global_features, mlp_sizes, **kw)
for i in range(n_recurrences):
setattr(self, f'edge_fns_{i}', MLP(8 * m, mlp_sizes, **kw))
setattr(self, f'node_fns_{i}', MLP(6 * m, mlp_sizes, **kw))
setattr(self, f'global_fns_{i}', MLP(4 * m, mlp_sizes, **kw))
def __call__(self, graph: GraphsTuple) -> GraphsTuple:
if self.format is partition.Dense:
graph_map_features = GraphMapFeatures
graph_network = GraphNetwork
elif self.format is partition.Sparse:
graph_map_features = jraph.GraphMapFeatures
graph_network = jraph.GraphNetwork
else:
raise ValueError()
encoded = graph_map_features(
self.EdgeEncoder, self.NodeEncoder, self.GlobalEncoder
)(graph)
outputs = encoded
for i in range(self.n_recurrences):
edge_mlp = getattr(self, f'edge_fns_{i}')
node_mlp = getattr(self, f'node_fns_{i}')
global_mlp = getattr(self, f'global_fns_{i}')
def edge_update(edges, sent, received, globals_, mlp=edge_mlp):
return mlp(jnp.concatenate((edges, sent, received, globals_), axis=-1))
def node_update(nodes, sent, received, globals_, mlp=node_mlp):
return mlp(jnp.concatenate((nodes, sent, received, globals_), axis=-1))
def global_update(nodes, edges, globals_, mlp=global_mlp):
return mlp(jnp.concatenate((nodes, edges, globals_), axis=-1))
inputs = concatenate_graph_features((outputs, encoded))
outputs = graph_network(edge_update, node_update, global_update)(inputs)
return outputs