Source code for jax_md.smap

# Copyright 2019 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code to transform functions on individual tuples of particles to sets."""

from functools import reduce, partial

from typing import Dict, Callable, List, Tuple, Union, Optional

import math
import enum
from operator import mul

import numpy as onp

from jax import lax, ops, vmap, eval_shape, tree_map
from jax.core import ShapedArray
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp

from jax_md import dataclasses
from jax_md import quantity
from jax_md import space
from jax_md import util
from jax_md import partition

# Summation helper re-exported from `util`.
high_precision_sum = util.high_precision_sum

# Typing

# Array and PyTree annotation aliases, re-exported from `util`.
Array, PyTree = util.Array, util.PyTree

# Scalar type aliases from `util`; used below both as constructors
# (e.g. `f32(0.0)`) and in dtype comparisons.
f32, f64 = util.f32, util.f64
i32, i64 = util.i32, util.i64

# Either a displacement function or a metric, as defined in `space`.
DisplacementOrMetricFn = space.DisplacementOrMetricFn

# Parameter Trees

class ParameterTreeMapping(enum.Enum):
  """Describes how a parameter PyTree is processed by mapped functions.

  Members:
    Global: The tree is handed to the mapped function unchanged.
    PerParticle: Every leaf has a leading axis whose length is the number of
      particles. For a pair of particles i and j the leaves are combined as
      `p_ij = combinator(p_i, p_j)`.
    PerBond: In the pairwise maps every leaf has two leading axes, each equal
      to the number of particles, and is looked up at `[i, j]`. When used
      with `bond`, leaves are instead indexed by the bond type.
    PerSpecies: Every leaf has two leading axes, each equal to the number of
      species. For particles of species `s_i` and `s_j` the entry
      `p[s_i, s_j]` is used.
  """
  # Passed through as-is.
  Global = 0
  # Indexed per particle, then combined pairwise.
  PerParticle = 1
  # Indexed per particle pair (or per bond type in `bond`).
  PerBond = 2
  # Indexed per species pair.
  PerSpecies = 3
@dataclasses.dataclass
class ParameterTree:
  """Pairs a PyTree of parameters with a description of how to map it.

  Attributes:
    tree: A JAX PyTree holding the parameters. Mapped functions process it
      according to `mapping` before handing it to the wrapped function.
    mapping: A `ParameterTreeMapping` that selects how `tree` is processed.
  """
  # The parameters themselves.
  tree: PyTree
  # Declared via `dataclasses.static_field()` (jax_md's dataclass helper).
  mapping: ParameterTreeMapping = dataclasses.static_field()
Parameter = Union[ParameterTree, Array, float]


# Mapping potential functional forms to bonds.


def _get_bond_type_parameters(params: Parameter,
                              bond_type: Array) -> Parameter:
  """Get parameters for interactions for bonds indexed by a bond-type.

  Args:
    params: One of: a scalar array; a 1d array indexed by bond type; a
      `ParameterTree` whose mapping is `Global` or `PerBond`; or a Python /
      NumPy int or float.
    bond_type: A 1d array holding the type of every bond.

  Returns:
    `params` gathered by `bond_type` when a per-type lookup applies, and
    `params` (or its tree) unchanged for global / scalar parameters.

  Raises:
    ValueError: If `bond_type` is not a 1d array, if a parameter array is
      neither a scalar nor 1d, or if a `ParameterTree` uses a mapping other
      than `Global` or `PerBond`.
    NotImplementedError: If `params` has an unsupported type.
  """
  # TODO(schsam): We should do better error checking here.
  # Explicit exception instead of `assert`, which is stripped under `-O`.
  if not util.is_array(bond_type) or len(bond_type.shape) != 1:
    raise ValueError('Bond types must be specified as a 1d array.')

  if util.is_array(params):
    if len(params.shape) == 1:
      return params[bond_type]
    elif len(params.shape) == 0:
      return params
    raise ValueError(
        'Params must be a scalar or a 1d array if using a bond-type lookup.')
  elif isinstance(params, ParameterTree):
    if params.mapping is ParameterTreeMapping.Global:
      return params.tree
    elif params.mapping is ParameterTreeMapping.PerBond:
      return tree_map(lambda p: p[bond_type], params.tree)
    raise ValueError('ParameterTreeMapping must be either Global or PerBond '
                     'if used with `smap.bond`. '
                     f'Found {params.mapping}.')
  elif (isinstance(params, (int, float)) or
        jnp.issubdtype(params, jnp.integer) or
        jnp.issubdtype(params, jnp.floating)):
    return params
  raise NotImplementedError(
      f'Unsupported bond parameter type {type(params)}.')


def _kwargs_to_bond_parameters(bond_type: Array,
                               kwargs: Dict[str, Parameter]
                               ) -> Dict[str, Parameter]:
  """Extract per-bond parameters from keyword arguments.

  Args:
    bond_type: A 1d array of bond types, or None when no bond-type lookup
      should be performed.
    kwargs: Parameters keyed by name.

  Returns:
    `kwargs` itself when `bond_type` is None; otherwise a NEW dictionary in
    which every value has been looked up by bond type. The input dictionary
    is never modified, so a dictionary shared with the caller (e.g. the
    parameters captured by `bond`) cannot be overwritten.
  """
  if bond_type is None:
    return kwargs
  return {k: _get_bond_type_parameters(v, bond_type)
          for k, v in kwargs.items()}
def bond(fn: Callable[..., Array],
         displacement_or_metric: DisplacementOrMetricFn,
         static_bonds: Optional[Array]=None,
         static_bond_types: Optional[Array]=None,
         ignore_unused_parameters: bool=False,
         **kwargs) -> Callable[..., Array]:
  """Promotes a function that acts on a single pair to one on a set of bonds.

  TODO(schsam): It seems like bonds might potentially have poor memory access.
  Should think about this a bit and potentially optimize.

  Args:
    fn: A function that takes an ndarray of per-bond distances or
      displacements of shape `[b]` or `[b, d_in]` respectively, as well as
      kwargs specifying parameters for the function. Its output is summed
      over all bonds.
    displacement_or_metric: A function that takes two ndarrays of positions
      of shape `[spatial_dimension]` and returns a distance of shape `[]` or
      a displacement of shape `[d_in]`. Keyword arguments given to the mapped
      function are forwarded to it.
    static_bonds: An ndarray of integer pairs with shape `[b, 2]` where each
      pair specifies a bond. `static_bonds` are baked into the returned
      compute function statically and cannot be changed after the fact.
    static_bond_types: An ndarray of integers of shape `[b]` specifying the
      type of each static bond. Only specify bond types if you want to
      specify bond parameters by type. One can also specify constant or
      per-bond parameters (see below).
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.bond(...)`.
    kwargs: Arguments providing parameters to the mapped function. Without
      bond type information these should be either
        1. a scalar
        2. an ndarray of shape `[b]`.
      With bond type information they should be either
        1. a scalar
        2. an ndarray of shape `[max_bond_type]`.
        3. a ParameterTree containing a PyTree of parameters and a mapping.
           See ParameterTree for details.

  Returns:
    A function `fn_mapped(R, bonds=None, bond_types=None, **dynamic_kwargs)`
    returning the sum of `fn` over the dynamic `bonds` (if given) plus the
    static bonds (if any). Dynamically specified bonds incur a recompilation
    when the number of bonds changes. Improving this state of affairs I will
    leave as a TODO until someone actually uses this feature and runs into
    speed issues.
  """
  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  def compute_fn(R, bonds, bond_types, static_kwargs, dynamic_kwargs):
    # Endpoints of every bond.
    first, second = R[bonds[:, 0]], R[bonds[:, 1]]
    params = _kwargs_to_bond_parameters(
        bond_types, merge_dicts(static_kwargs, dynamic_kwargs))
    # A single vmap lifts the metric from a pair of vectors to a pair of lists
    # of vectors (one entry per bond).
    # NOTE(schsam): This pattern is needed due to JAX issue #912.
    per_bond_metric = vmap(partial(displacement_or_metric, **dynamic_kwargs),
                           0, 0)
    return high_precision_sum(fn(per_bond_metric(first, second), **params))

  def mapped_fn(R: Array,
                bonds: Optional[Array]=None,
                bond_types: Optional[Array]=None,
                **dynamic_kwargs) -> Array:
    total = f32(0)
    # Dynamic bonds are accumulated first, then the static ones.
    for bond_list, type_list in ((bonds, bond_types),
                                 (static_bonds, static_bond_types)):
      if bond_list is not None:
        total = total + compute_fn(R, bond_list, type_list, kwargs,
                                   dynamic_kwargs)
    return total

  return mapped_fn
# Mapping potential functional forms to pairwise interactions.


def _get_species_parameters(params: Parameter, species: Array) -> Parameter:
  """Get parameters for interactions between species pairs.

  Args:
    params: A scalar array, a 2d array indexed by species pair, a
      `ParameterTree` with a `Global` or `PerSpecies` mapping, or any other
      value (returned unchanged).
    species: An index into the two leading parameter axes; `pair` passes a
      tuple `(i, j)` of species ids.

  Returns:
    The parameters for the requested species pair.

  Raises:
    ValueError: If a parameter array is neither a scalar nor 2d, or if a
      `ParameterTree` has an unsupported mapping.
  """
  # TODO(schsam): We should do better error checking here.
  if util.is_array(params):
    if len(params.shape) == 2:
      return params[species]
    elif len(params.shape) == 0:
      return params
    else:
      raise ValueError(
          'Params must be a scalar or a 2d array if using a species lookup.')
  elif isinstance(params, ParameterTree):
    p = params.tree
    if params.mapping is ParameterTreeMapping.Global:
      return p
    elif params.mapping is ParameterTreeMapping.PerSpecies:
      return tree_map(lambda x: x[species], p)
    else:
      raise ValueError('When species are present, ParameterTreeMapping must '
                       f'be Global or PerSpecies. Found {params.mapping}.')
  return params


def _get_matrix_parameters(params: Parameter,
                           combinator: Callable[[Array, Array], Array]
                           ) -> Parameter:
  """Get an NxN parameter matrix from per-particle parameters.

  Args:
    params: A 0d, 1d (per-particle) or 2d (per-pair) array; a
      `ParameterTree` with a `Global`, `PerBond` or `PerParticle` mapping; or
      a Python / NumPy int or float.
    combinator: Binary function used to combine two per-particle values.

  Returns:
    For 1d arrays and `PerParticle` trees, the combinator applied to every
    pair via broadcasting; other inputs are returned as-is (or the tree for a
    `ParameterTree`).

  Raises:
    ValueError: For arrays of other ranks, unsupported mappings, or
      unsupported types.
  """
  if util.is_array(params):
    if params.ndim == 1:
      return combinator(params[:, jnp.newaxis], params[jnp.newaxis, :])
    elif params.ndim == 0 or params.ndim == 2:
      return params
    else:
      raise ValueError('Without species information, parameters must be '
                       'an array of dimension 0, 1, or 2. '
                       f'Found {params.ndim}.')
  elif isinstance(params, ParameterTree):
    M = ParameterTreeMapping
    if params.mapping in (M.Global, M.PerBond):
      return params.tree
    elif params.mapping is M.PerParticle:
      return tree_map(
          lambda p: combinator(p[:, None, ...], p[None, :, ...]),
          params.tree)
    else:
      raise ValueError('Without species information, ParameterTreeMapping '
                       'must be Global, PerBond, or PerParticle. '
                       f'Found {params.mapping}.')
  elif (isinstance(params, int) or isinstance(params, float) or
        jnp.issubdtype(params, jnp.integer) or
        jnp.issubdtype(params, jnp.floating)):
    return params
  else:
    raise ValueError('Without species information, params must either be an '
                     'array, a ParameterTree, or a float. '
                     f'Found {type(params)}.')


def _kwargs_to_parameters(species: Array,
                          kwargs: Dict[str, Parameter],
                          combinators: Dict[str, Callable]
                          ) -> Dict[str, Array]:
  """Extract parameters from keyword arguments.

  Without species (`species is None`) each parameter is expanded to pairwise
  form using its combinator, defaulting to the mean of the two values. With
  species, each parameter is looked up for `species` and custom combinators
  are rejected.

  Returns:
    A new dictionary; `kwargs` is not modified.

  Raises:
    ValueError: If a custom combinator is given together with species.
  """
  # NOTE(schsam): We could pull out the species case from the generic case.
  s_kwargs = {}
  for k, v in kwargs.items():
    if species is None:
      combinator = combinators.get(k, lambda x, y: 0.5 * (x + y))
      s_kwargs[k] = _get_matrix_parameters(v, combinator)
    else:
      if k in combinators:
        raise ValueError('Cannot specify custom combinator with species.')
      s_kwargs[k] = _get_species_parameters(v, species)
  return s_kwargs


def _diagonal_mask(X: Array) -> Array:
  """Sets the diagonal of a matrix to zero.

  Args:
    X: A rank-2 `[N, N]` or rank-3 `[N, N, d]` array.

  Returns:
    `X` with non-finite values replaced by `jnp.nan_to_num` and every entry
    `X[i, i]` zeroed.

  Raises:
    ValueError: If `X` is not rank 2 or 3, or its first two axes differ.
  """
  # Validate the rank first so that reading `X.shape[1]` below is always safe
  # (previously rank-0/1 inputs failed with an IndexError).
  if len(X.shape) not in (2, 3):
    raise ValueError(
        'Diagonal mask can only mask rank-2 or rank-3 tensors. '
        'Found {}.'.format(len(X.shape)))
  if X.shape[0] != X.shape[1]:
    raise ValueError(
        'Diagonal mask can only mask square matrices. Found {}x{}.'.format(
            X.shape[0], X.shape[1]))
  N = X.shape[0]
  # NOTE(schsam): It seems potentially dangerous to set nans to 0 here.
  # However, masking nans also doesn't seem to work. So it also seems
  # necessary. At the very least we should do some error checking.
  X = jnp.nan_to_num(X)
  mask = f32(1.0) - jnp.eye(N, dtype=X.dtype)
  if len(X.shape) == 3:
    # Broadcast the [N, N] mask over the trailing output axis.
    mask = jnp.reshape(mask, (N, N, 1))
  return mask * X


def _check_species_dtype(species):
  """Raises a ValueError unless `species` has an integer dtype.

  Any signed or unsigned integer dtype is accepted, matching the error
  message (previously only int32 and int64 passed).
  """
  if jnp.issubdtype(species.dtype, jnp.integer):
    return
  msg = 'Species has wrong dtype. Expected integer but found {}.'.format(
      species.dtype)
  raise ValueError(msg)


def _split_params_and_combinators(kwargs):
  """Separates parameter values from custom combinator functions.

  Each keyword argument may be:
    * a callable, which becomes the combinator for that name;
    * a tuple `(combinator, value)` whose first element is callable, which
      supplies both;
    * anything else, which is a plain parameter value.

  Returns:
    A tuple `(params, combinators)` of dictionaries keyed by argument name.

  Raises:
    ValueError: If a tuple starting with a callable does not have exactly two
      elements.
  """
  combinators = {}
  params = {}
  for k, v in kwargs.items():
    if callable(v):
      combinators[k] = v
    elif isinstance(v, tuple) and v and callable(v[0]):
      # Explicit check instead of `assert`, which is stripped under `-O`.
      if len(v) != 2:
        raise ValueError(
            f'Parameter {k!r} must be given as (combinator, value). '
            f'Found a tuple of length {len(v)}.')
      combinators[k], params[k] = v
    else:
      params[k] = v
  return params, combinators
def pair(fn: Callable[..., Array],
         displacement_or_metric: DisplacementOrMetricFn,
         species: Optional[Array]=None,
         reduce_axis: Optional[Tuple[int, ...]]=None,
         keepdims: bool=False,
         ignore_unused_parameters: bool=False,
         **kwargs) -> Callable[..., Array]:
  """Promotes a function that acts on a pair of particles to one on a system.

  Args:
    fn: A function that takes an ndarray of pairwise distances or
      displacements of shape `[n, m]` or `[n, m, d_in]` respectively as well
      as kwargs specifying parameters for the function. `fn` returns an
      ndarray of evaluations of shape `[n, m]` or `[n, m, d_out]`.
    displacement_or_metric: A function that takes two ndarrays of positions
      of shape `[spatial_dimension]` and returns an ndarray of distances or
      displacements of shape `[]` or `[d_in]` respectively. Keyword arguments
      given to the mapped function are forwarded to it (e.g. a time).
    species: A list of species for the different particles. This should
      either be None (in which case it is assumed that all the particles
      have the same species), an integer ndarray of shape `[n]` with species
      data, or an integer in which case the species data will be specified
      dynamically with `species` giving the maximum number of types of
      particles. Note: that dynamic species specification is less efficient,
      because we cannot specialize shape information. With static species
      `reduce_axis` and `keepdims` are not supported.
    reduce_axis: A list of axes to reduce over. This is supplied to `jnp.sum`
      and so the same convention is used.
    keepdims: A boolean specifying whether the empty dimensions should be
      kept upon reduction. This is supplied to `jnp.sum` and so the same
      convention is used.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.pair(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either
        1) a scalar
        2) an ndarray of shape `[n]`
        3) an ndarray of shape `[n, n]`,
        4) a ParameterTree containing a PyTree of parameters and a mapping.
           See ParameterTree for details.
        5) a binary function that determines how per-particle parameters are
           to be combined
        6) a binary function as well as a default set of parameters as in 2)
           or 4).
      If unspecified then this is taken to be the average of the two
      per-particle parameters. If species information is provided then the
      parameters should be specified as either
        1) a scalar
        2) an ndarray of shape `[max_species, max_species]`
        3) a ParameterTree containing a PyTree of parameters and a mapping.
           See ParameterTree for details.

  Returns:
    A function `fn_mapped`. If species is `None` or statically specified then
    `fn_mapped(R, **dynamic_kwargs)` takes an ndarray of positions of shape
    `[n, spatial_dimension]`. If species is dynamic then
    `fn_mapped(R, species, **dynamic_kwargs)` additionally takes an integer
    ndarray of species of shape `[n]`; the maximum number of species is the
    integer given here. The mapped function can also optionally take keyword
    arguments that get threaded through the metric.

  Raises:
    ValueError: If `species` has an unsupported type or dtype, or if
      `reduce_axis` / `keepdims` are combined with static species.
  """
  kwargs, param_combinators = _split_params_and_combinators(kwargs)

  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  if species is None:
    def fn_mapped(R: Array, **dynamic_kwargs) -> Array:
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      _kwargs = _kwargs_to_parameters(None, _kwargs, param_combinators)
      dr = d(R, R)
      # NOTE(schsam): Currently we place a diagonal mask no matter what function
      # we are mapping. Should this be an option?
      # Every pair appears as (i, j) and (j, i), hence the factor of 0.5.
      return high_precision_sum(_diagonal_mask(fn(dr, **_kwargs)),
                                axis=reduce_axis,
                                keepdims=keepdims) * f32(0.5)
  elif util.is_array(species):
    species = onp.array(species)
    _check_species_dtype(species)
    species_count = int(onp.max(species))
    if reduce_axis is not None or keepdims:
      # TODO(schsam): Support reduce_axis with static species.
      raise ValueError('`reduce_axis` and `keepdims` are not supported when '
                       'species are specified statically.')
    def fn_mapped(R, **dynamic_kwargs):
      U = f32(0.0)
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      # The merged parameters do not depend on the species pair, so merge once.
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      for i in range(species_count + 1):
        for j in range(i, species_count + 1):
          s_kwargs = _kwargs_to_parameters((i, j), _kwargs, param_combinators)
          Ra = R[species == i]
          Rb = R[species == j]
          dr = d(Ra, Rb)
          if j == i:
            # Same-species block: drop self-interactions and halve the
            # double-counted pairs.
            dU = high_precision_sum(_diagonal_mask(fn(dr, **s_kwargs)))
            U = U + f32(0.5) * dU
          else:
            dU = high_precision_sum(fn(dr, **s_kwargs))
            U = U + dU
      return U
  elif isinstance(species, int):
    species_count = species
    def fn_mapped(R, species, **dynamic_kwargs):
      _check_species_dtype(species)
      U = f32(0.0)
      N = R.shape[0]
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      dr = d(R, R)
      for i in range(species_count):
        for j in range(species_count):
          s_kwargs = _kwargs_to_parameters((i, j), _kwargs, param_combinators)
          mask_a = jnp.array(jnp.reshape(species == i, (N,)), dtype=R.dtype)
          mask_b = jnp.array(jnp.reshape(species == j, (N,)), dtype=R.dtype)
          mask = mask_a[:, jnp.newaxis] * mask_b[jnp.newaxis, :]
          if i == j:
            mask = mask * _diagonal_mask(mask)
          out = fn(dr, **s_kwargs)
          # Broadcast the [N, N] pair mask over any trailing output axes so
          # vector-valued `fn` outputs ([N, N, d_out]) are masked correctly.
          extra_dims = max(jnp.ndim(out) - mask.ndim, 0)
          mask = jnp.reshape(mask, mask.shape + (1,) * extra_dims)
          dU = mask * out
          U = U + high_precision_sum(dU, axis=reduce_axis, keepdims=keepdims)
      # Both (i, j) and (j, i) species blocks are visited, so halve the sum.
      return U / f32(2.0)
  else:
    raise ValueError(
        'Species must be None, an ndarray, or an integer. Found {}.'.format(
            species))
  return fn_mapped
# Mapping pairwise functional forms to systems using neighbor lists.


def _get_neighborhood_matrix_params(format: partition.NeighborListFormat,
                                    idx: Array,
                                    params: Parameter,
                                    combinator: Callable[[Array, Array], Array]
                                    ) -> Parameter:
  """Gathers per-pair parameters for the pairs stored in a neighbor list.

  Args:
    format: The neighbor list format; sparse formats store pairs as
      `idx[0]`, `idx[1]`, dense formats store a neighbor row per particle.
    idx: The neighbor list indices.
    params: A 0d, 1d (per-particle) or 2d (per-pair) array; a
      `ParameterTree` with a `Global`, `PerParticle` or `PerBond` mapping; or
      a Python / NumPy int or float.
    combinator: Binary function combining two per-particle values.

  Returns:
    Parameters laid out like the neighbor list, or scalars / global trees
    unchanged.

  Raises:
    ValueError: For arrays of rank > 2, unsupported mappings or types.
  """
  if util.is_array(params):
    if params.ndim == 1:
      if partition.is_sparse(format):
        return space.map_bond(combinator)(params[idx[0]], params[idx[1]])
      # Dense: broadcast each particle's value against its neighbors' values.
      return combinator(params[:, None], params[idx])
    elif params.ndim == 2:
      def query(id_a, id_b):
        return params[id_a, id_b]
      if partition.is_sparse(format):
        return space.map_bond(query)(idx[0], idx[1])
      query = vmap(vmap(query, (None, 0)))
      return query(jnp.arange(idx.shape[0], dtype=jnp.int32), idx)
    elif params.ndim == 0:
      return params
    raise ValueError('Parameter array must be either a scalar, a vector, '
                     f'or a matrix. Found ndim={params.ndim}.')
  elif isinstance(params, ParameterTree):
    if params.mapping is ParameterTreeMapping.Global:
      return params.tree
    elif params.mapping is ParameterTreeMapping.PerParticle:
      if partition.is_sparse(format):
        c_fn = space.map_bond(combinator)
        return tree_map(lambda p: c_fn(p[idx[0]], p[idx[1]]), params.tree)
      c_fn = space.map_neighbor(combinator)
      return tree_map(lambda p: c_fn(p, p[idx]), params.tree)
    elif params.mapping is ParameterTreeMapping.PerBond:
      def query(p, id_a, id_b):
        return p[id_a, id_b]
      if partition.is_sparse(format):
        c_fn = lambda p: space.map_bond(partial(query, p))(idx[0], idx[1])
        return tree_map(c_fn, params.tree)
      r = jnp.arange(idx.shape[0], dtype=jnp.int32)
      c_fn = lambda p: vmap(vmap(partial(query, p), (None, 0)))(r, idx)
      return tree_map(c_fn, params.tree)
    raise ValueError('Without species information ParameterTreeMapping must '
                     'be Global, PerParticle, or PerBond. '
                     f'Found {params.mapping}.')
  elif (isinstance(params, int) or isinstance(params, float) or
        jnp.issubdtype(params, jnp.integer) or
        jnp.issubdtype(params, jnp.floating)):
    return params
  raise ValueError('Parameter must be an array, a ParameterTree, or a '
                   f'float. Found {type(params)}.')


def _get_neighborhood_species_params(format: partition.NeighborListFormat,
                                     idx: Array,
                                     species: Array,
                                     params: Parameter) -> Parameter:
  """Get parameters for interactions between species pairs.

  Looks up `params[species_i, species_j]` for each pair stored in the
  neighbor list. Scalars and `Global` trees are returned unchanged; values
  that are neither arrays nor `ParameterTree`s are returned as-is.

  Raises:
    ValueError: For arrays that are neither 0d nor 2d, or trees whose
      mapping is not `Global` or `PerSpecies`.
  """
  # TODO(schsam): We should do better error checking here.
  def lookup(p, species_a, species_b):
    return p[species_a, species_b]

  if util.is_array(params):
    lookup = partial(lookup, params)
    if len(params.shape) == 2:
      if partition.is_sparse(format):
        return space.map_bond(lookup)(species[idx[0]], species[idx[1]])
      lookup = vmap(vmap(lookup, (None, 0)))
      return lookup(species, species[idx])
    elif len(params.shape) == 0:
      return params
    raise ValueError(
        'Params must be a scalar or a 2d array if using a species lookup.')
  elif isinstance(params, ParameterTree):
    if params.mapping is ParameterTreeMapping.Global:
      return params.tree
    elif params.mapping is ParameterTreeMapping.PerSpecies:
      if partition.is_sparse(format):
        l_fn = lambda p: space.map_bond(partial(lookup, p))(species[idx[0]],
                                                            species[idx[1]])
        return tree_map(l_fn, params.tree)
      l_fn = lambda p: vmap(vmap(partial(lookup, p), (None, 0)))(
          species, species[idx])
      return tree_map(l_fn, params.tree)
    raise ValueError('Parameter tree mapping must be either Global or '
                     'PerSpecies if using a species lookup.')
  return params


def _neighborhood_kwargs_to_params(format: partition.NeighborListFormat,
                                   idx: Array,
                                   species: Array,
                                   kwargs: Dict[str, Array],
                                   combinators: Dict[str, Callable]
                                   ) -> Dict[str, Array]:
  """Converts keyword parameters into neighbor-list-shaped parameters.

  Per-particle (1d array) parameters, and all parameters when `species` is
  None, go through the combinator path (defaulting to the mean). Everything
  else is looked up by species pair.

  Raises:
    ValueError: If a custom combinator is given for a parameter that is
      looked up by species.
  """
  out_dict = {}
  for k in kwargs:
    if species is None or (util.is_array(kwargs[k]) and kwargs[k].ndim == 1):
      combinator = combinators.get(k, lambda x, y: 0.5 * (x + y))
      out_dict[k] = _get_neighborhood_matrix_params(format,
                                                    idx,
                                                    kwargs[k],
                                                    combinator)
    else:
      if k in combinators:
        raise ValueError(
            f'Cannot specify a custom combinator for {k!r} with species '
            'unless the parameter is a per-particle (1d) array.')
      out_dict[k] = _get_neighborhood_species_params(format,
                                                     idx,
                                                     species,
                                                     kwargs[k])
  return out_dict


def _vectorized_cond(pred: Array, fn: Callable[[Array], Array],
                     operand: Array) -> Array:
  """Elementwise `fn(operand)` where `pred` holds, and 0 elsewhere.

  Entries where `pred` is false are replaced by 1 before calling `fn`,
  presumably so `fn` never sees invalid inputs there (e.g. to avoid NaNs
  leaking through `jnp.where` under differentiation).
  """
  masked = jnp.where(pred, operand, 1)
  return jnp.where(pred, fn(masked), 0)
def pair_neighbor_list(fn: Callable[..., Array],
                       displacement_or_metric: DisplacementOrMetricFn,
                       species: Optional[Array]=None,
                       reduce_axis: Optional[Tuple[int, ...]]=None,
                       ignore_unused_parameters: bool=False,
                       **kwargs) -> Callable[..., Array]:
  """Promotes a function acting on pairs of particles to use neighbor lists.

  Args:
    fn: A function that takes an ndarray of pairwise distances or
      displacements of shape `[n, m]` or `[n, m, d_in]` respectively as well
      as kwargs specifying parameters for the function. `fn` returns an
      ndarray of evaluations of shape `[n, m]` or `[n, m, d_out]`.
    displacement_or_metric: A function that takes two ndarrays of positions
      of shape `[spatial_dimension]` and returns an ndarray of distances or
      displacements of shape `[]` or `[d_in]` respectively. Keyword arguments
      given to the mapped function are forwarded to it.
    species: Species information for the different particles. Should either
      be None (in which case it is assumed that all the particles have the
      same species), an integer array of shape `[n]` with species data. Note
      that species data can be specified dynamically by passing a `species`
      keyword argument to the mapped function.
    reduce_axis: A list of axes to reduce over. We use a convention where
      axis 0 corresponds to the particles, axis 1 corresponds to neighbors,
      and the remaining axes correspond to the output axes of `fn`. Note that
      it is not well-defined to sum over particles without summing over
      neighbors. One also cannot report per-particle values (excluding axis
      `0`) for neighbor lists whose format is `OrderedSparse`.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.pair_neighbor_list(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either
        1) a scalar
        2) an ndarray of shape `[n]`
        3) an ndarray of shape `[n, n]`,
        4) a ParameterTree containing a PyTree of parameters and a mapping.
           See ParameterTree for details.
        5) a binary function that determines how per-particle parameters are
           to be combined
      If unspecified then this is taken to be the average of the two
      per-particle parameters. If species information is provided then the
      parameters should be specified as either
        1) a scalar
        2) an ndarray of shape `[max_species, max_species]`
        3) a ParameterTree containing a PyTree of parameters and a mapping.
           See ParameterTree for details.

  Returns:
    A function `fn_mapped(R, neighbor, **dynamic_kwargs)` that takes an
    ndarray of floats of shape `[N, spatial_dimension]` of positions and a
    `partition.NeighborList` specifying neighbors. Full reductions give the
    same total for `Dense`, `Sparse` and `OrderedSparse` neighbor lists.

  Raises (from `fn_mapped`):
    ValueError: If `reduce_axis` contains axis 0 but not axis 1, or asks for
      per-particle values with an `OrderedSparse` neighbor list.
  """
  kwargs, param_combinators = _split_params_and_combinators(kwargs)

  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  def fn_mapped(R: Array, neighbor: partition.NeighborList, **dynamic_kwargs
                ) -> Array:
    d = partial(displacement_or_metric, **dynamic_kwargs)
    _species = dynamic_kwargs.get('species', species)

    # Totals are divided by 2 for Dense and Sparse lists and by 1 for
    # OrderedSparse (presumably because only the latter stores each pair once).
    normalization = 2.0

    if partition.is_sparse(neighbor.format):
      d = space.map_bond(d)
      dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
      # Padding entries have an index >= N and are masked out.
      mask = neighbor.idx[0] < R.shape[0]
      if neighbor.format is partition.OrderedSparse:
        normalization = 1.0
    else:
      d = space.map_neighbor(d)
      R_neigh = R[neighbor.idx]
      dR = d(R, R_neigh)
      mask = neighbor.idx < R.shape[0]

    merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
    merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                   neighbor.idx,
                                                   _species,
                                                   merged_kwargs,
                                                   param_combinators)
    out = fn(dR, **merged_kwargs)
    if out.ndim > mask.ndim:
      # Broadcast the mask over trailing output axes of `fn`.
      ddim = out.ndim - mask.ndim
      mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
    out *= mask

    if reduce_axis is None:
      return util.high_precision_sum(out) / normalization

    if 0 in reduce_axis and 1 not in reduce_axis:
      raise ValueError('`reduce_axis` cannot contain the particle axis 0 '
                       'without also containing the neighbor axis 1.')

    if not partition.is_sparse(neighbor.format):
      return util.high_precision_sum(out, reduce_axis) / normalization

    # Sparse lists store (particle, neighbor) pairs along a single axis, so
    # the output axes of `fn` are shifted down by one.
    _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

    if 0 in reduce_axis:
      # Normalize like the dense path so both formats give the same total.
      return (util.high_precision_sum(out, (0,) + _reduce_axis) /
              normalization)

    if neighbor.format is partition.OrderedSparse:
      raise ValueError('Cannot report per-particle values with a neighbor '
                       'list whose format is `OrderedSparse`. Please use '
                       'either `Dense` or `Sparse`.')

    out = util.high_precision_sum(out, _reduce_axis)
    return ops.segment_sum(out, neighbor.idx[0], R.shape[0]) / normalization

  return fn_mapped
def triplet(fn: Callable[..., Array],
            displacement_or_metric: DisplacementOrMetricFn,
            species: Optional[Array]=None,
            reduce_axis: Optional[Tuple[int, ...]]=None,
            keepdims: bool=False,
            ignore_unused_parameters: bool=False,
            **kwargs) -> Callable[..., Array]:
  """Promotes a function that acts on triples of particles to one on a system.

  Many empirical potentials in jax_md include three-body angular terms (e.g.
  Stillinger Weber). This utility function simplifies the loss computation
  in such cases by converting a function that takes in two pairwise
  displacements or distances to one that only requires the system as input.

  Args:
    fn: A function that takes two ndarrays of distances or displacements
      from a central atom, both of shape `[n, m]` or `[n, m, d_in]`
      respectively, as well as kwargs specifying parameters for the function.
    displacement_or_metric: A function that takes two ndarrays of positions
      of shape `[spatial_dimensions]` and returns an ndarray of distances or
      displacements of shape `[]` or `[d_in]` respectively. Keyword arguments
      given to the mapped function are forwarded to it.
    species: Either None (all particles have the same species) or an integer
      ndarray of shape `[n]` with species data. Dynamic species (an integer)
      are not implemented yet.
    reduce_axis: A list of axes to reduce over. This is supplied to `jnp.sum`
      and the same convention is used.
    keepdims: A boolean specifying whether the empty dimensions should be
      kept upon reduction. This is supplied to `jnp.sum` and so the same
      convention is used.
    ignore_unused_parameters: A boolean that denotes whether dynamically
      specified keyword arguments passed to the mapped function get ignored
      if they were not first specified as keyword arguments when calling
      `smap.triplet(...)`.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided, these should either be
        1) a scalar
        2) an ndarray of shape `[n]`
        3) an ndarray of shape `[n, n, n]` defining triplet interactions.
      If species information is provided, then the parameters should be
      specified as either
        1) a scalar
        2) an ndarray of shape `[max_species, max_species, max_species]`
           defining triplet interactions.
      Rank-1 and rank-2 parameters are rejected when species are given.

  Returns:
    A function `fn_mapped(R, **dynamic_kwargs)` taking an ndarray of
    positions of shape `[n, spatial_dimension]`. The mapped function can
    also optionally take keyword arguments that get threaded through the
    metric.

  Raises:
    NotImplementedError: If `species` is an integer (dynamic species).
    ValueError: If `species` has an unsupported type, or (when calling
      `fn_mapped` with species) if a parameter has rank 1 or 2.
  """
  merge_dicts = partial(util.merge_dicts,
                        ignore_unused_parameters=ignore_unused_parameters)

  def extract_parameters_by_dim(kwargs, dim: Union[int, List[int]] = 0):
    """Extract parameters from a dictionary via dimension.

    `jnp.ndim` is used so plain Python scalars count as rank 0 instead of
    raising an AttributeError.
    """
    if isinstance(dim, int):
      dim = [dim]
    return {name: value for name, value in kwargs.items()
            if jnp.ndim(value) in dim}

  if species is None:
    def fn_mapped(R, **dynamic_kwargs) -> Array:
      d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
      _kwargs = merge_dicts(kwargs, dynamic_kwargs)
      _kwargs = _kwargs_to_parameters(species, _kwargs, {})
      dR = d(R, R)
      compute_triplet = partial(fn, **_kwargs)
      output = vmap(vmap(vmap(compute_triplet, (None, 0)), (0, None)), 0)(
          dR, dR)
      return high_precision_sum(output,
                                axis=reduce_axis,
                                keepdims=keepdims) / 2.
  elif util.is_array(species):
    def fn_mapped(R, **dynamic_kwargs):
      d = partial(displacement_or_metric, **dynamic_kwargs)
      # Dense index in which every particle lists all particles.
      idx = onp.tile(onp.arange(R.shape[0]), [R.shape[0], 1])
      dR = vmap(vmap(d, (None, 0)))(R, R[idx])

      _kwargs = merge_dicts(kwargs, dynamic_kwargs)

      # Rank-3 parameters are indexed by the central atom's species.
      mapped_args = extract_parameters_by_dim(_kwargs, [3])
      mapped_args = {arg_name: arg_value[species]
                     for arg_name, arg_value in mapped_args.items()}
      # Only scalar (rank-0) and per-species-triplet (rank-3) parameters are
      # supported; anything else would otherwise be silently dropped.
      unmapped_args = extract_parameters_by_dim(_kwargs, [0])

      if extract_parameters_by_dim(_kwargs, [1, 2]):
        raise ValueError('Improper argument dimensions (1 or 2) not well '
                         'defined for triplets.')

      def compute_triplet(dR, mapped_args, unmapped_args):
        paired_args = extract_parameters_by_dim(mapped_args, 2)
        paired_args.update(extract_parameters_by_dim(unmapped_args, 2))

        unpaired_args = extract_parameters_by_dim(mapped_args, 0)
        unpaired_args.update(extract_parameters_by_dim(unmapped_args, 0))

        output_fn = lambda dR1, dR2, paired_args: fn(dR1, dR2,
                                                     **unpaired_args,
                                                     **paired_args)
        neighbor_args = _neighborhood_kwargs_to_params(partition.Dense,
                                                       idx,
                                                       species,
                                                       paired_args,
                                                       {})
        output_fn = vmap(vmap(output_fn, (None, 0, 0)), (0, None, 0))
        return output_fn(dR, dR, neighbor_args)

      output_fn = partial(compute_triplet, unmapped_args=unmapped_args)
      output = vmap(output_fn)(dR, mapped_args)
      return high_precision_sum(output,
                                axis=reduce_axis,
                                keepdims=keepdims) / 2.
  elif isinstance(species, int):
    raise NotImplementedError(
        'Dynamic species are not yet supported by `smap.triplet`.')
  else:
    raise ValueError(
        'Species must be None, an ndarray, or an integer. Found {}.'.format(
            species))
  return fn_mapped