Higher Order Functions
Code to transform functions that act on individual tuples of particles into functions that act on sets of particles.
- class jax_md.smap.ParameterTree(tree, mapping)
A container denoting that parameters are in the form of a PyTree.
- tree
A JAX PyTree containing a tree of parameters. Before being fed into mapped functions, these parameters are processed according to the mapping.
- Type
Any
- mapping
A ParameterTreeMapping object that specifies how the parameters are processed.
- class jax_md.smap.ParameterTreeMapping(value)
An enum specifying how parameters are processed in mapped functions.
- Global
Global parameters are passed directly to the mapped function.
- PerParticle
PerParticle parameters are combined in pairs based on the particle index, e.g. p_ij = combinator(p_i, p_j) for particles i and j. These parameters are expected to have a leading axis whose length is the number of particles.
- PerBond
PerBond parameters are expected to have two leading dimensions equal to the number of particles in the system.
- PerSpecies
PerSpecies parameters are expected to have two leading dimensions equal to the number of species. For particles of species s_i and s_j, parameters are combined according to p_ij = combinator(p[s_i], p[s_j]).
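The PerParticle rule can be illustrated with a small NumPy sketch. This is not jax_md code: the helper name combine_per_particle is hypothetical, and the averaging combinator is only one possible choice, used here as an assumption.

```python
import numpy as np

def combine_per_particle(p, combinator=lambda a, b: 0.5 * (a + b)):
    """Hypothetical helper: build an [n, n] table with p_ij = combinator(p_i, p_j)."""
    p = np.asarray(p, dtype=float)
    # Broadcasting p[:, None] against p[None, :] pairs every particle i with every j.
    return combinator(p[:, None], p[None, :])

# One parameter value per particle (leading axis of length n = 3).
sigma = np.array([1.0, 2.0, 4.0])
sigma_ij = combine_per_particle(sigma)
print(sigma_ij.shape)  # (3, 3)
```

A Global parameter, by contrast, would be handed to the mapped function unchanged.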
- jax_md.smap.bond(fn, displacement_or_metric, static_bonds=None, static_bond_types=None, ignore_unused_parameters=False, **kwargs)
Promotes a function that acts on a single pair to one on a set of bonds.
TODO(schsam): It seems like bonds might potentially have poor memory access. Should think about this a bit and potentially optimize.
- Parameters
fn (Callable[…, Array]) – A function that takes an ndarray of pairwise distances or displacements of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function. fn returns an ndarray of evaluations of shape [n, m, d_out].
displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively. The metric can optionally take a floating point time as a third argument.
static_bonds (Optional[Array]) – An ndarray of integer pairs with shape [b, 2] where each pair specifies a bond. static_bonds are baked into the returned compute function statically and cannot be changed after the fact.
static_bond_types (Optional[Array]) – An ndarray of integers of shape [b] specifying the type of each bond. Only specify bond types if you want to specify bond parameters by type. One can also specify constant or per-bond parameters (see below).
ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.bond(...).
kwargs – Arguments providing parameters to the mapped function. In cases where no bond type information is provided these should be either
1) a scalar
2) an ndarray of shape [b].
If bond type information is provided then the parameters should be specified as either
1) a scalar
2) an ndarray of shape [max_bond_type]
3) a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.
- Return type
Callable[…, Array]
- Returns
A function fn_mapped. Note that fn_mapped can take arguments bonds and bond_types, which will be bonds that are specified dynamically. This will incur a recompilation when the number of bonds changes. Improving this state of affairs I will leave as a TODO until someone actually uses this feature and runs into speed issues.
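The promotion from a single-pair function to a function over bonds can be sketched in plain NumPy. This is an illustration of the idea only, not the jax_md implementation: bond_sketch, metric, and harmonic are hypothetical names, and summing the per-bond values into one scalar is an assumption of the sketch.

```python
import numpy as np

def metric(Ra, Rb):
    """Euclidean distance between two positions of shape [spatial_dimension]."""
    return np.linalg.norm(Ra - Rb)

def harmonic(dr, length=1.0, epsilon=1.0):
    """Hypothetical bond energy acting on an array of distances."""
    return 0.5 * epsilon * (dr - length) ** 2

def bond_sketch(fn, metric, bonds, bond_types=None, **kwargs):
    """Hypothetical stand-in for the mapping step: evaluate fn on every bond."""
    bonds = np.asarray(bonds)

    def fn_mapped(R):
        # One distance per bond, shape [b].
        dr = np.array([metric(R[i], R[j]) for i, j in bonds])
        params = {}
        for name, value in kwargs.items():
            value = np.asarray(value)
            if bond_types is not None and value.ndim > 0:
                # Per-type parameters of shape [max_bond_type] become per-bond, shape [b].
                value = value[np.asarray(bond_types)]
            params[name] = value
        # Assumption: the per-bond values are summed into a single number.
        return np.sum(fn(dr, **params))

    return fn_mapped

R = np.array([[0.0, 0.0], [1.5, 0.0], [1.5, 2.0]])
bonds = np.array([[0, 1], [1, 2]])
energy_fn = bond_sketch(harmonic, metric, bonds,
                        bond_types=np.array([0, 1]),
                        length=np.array([1.0, 2.0]))
energy = energy_fn(R)
print(energy)
```

Here the first bond has rest length 1.0 and the second 2.0, selected through the bond types; a scalar keyword argument would apply to every bond.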
- jax_md.smap.pair(fn, displacement_or_metric, species=None, reduce_axis=None, keepdims=False, ignore_unused_parameters=False, **kwargs)
Promotes a function that acts on a pair of particles to one on a system.
- Parameters
fn (Callable[…, Array]) – A function that takes an ndarray of pairwise distances or displacements of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function. fn returns an ndarray of evaluations of shape [n, m, d_out].
displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively. The metric can optionally take a floating point time as a third argument.
species (Optional[Array]) – A list of species for the different particles. This should be either None (in which case it is assumed that all the particles have the same species), an integer ndarray of shape [n] with species data, or an integer, in which case the species data will be specified dynamically with species giving the maximum number of types of particles. Note that dynamic species specification is less efficient, because we cannot specialize shape information.
reduce_axis (Optional[Tuple[int, …]]) – A list of axes to reduce over. This is supplied to jnp.sum and so the same convention is used.
keepdims (bool) – A boolean specifying whether the empty dimensions should be kept upon reduction. This is supplied to jnp.sum and so the same convention is used.
ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.pair(...).
kwargs – Arguments providing parameters to the mapped function. In cases where no species information is provided these should be either
1) a scalar
2) an ndarray of shape [n]
3) an ndarray of shape [n, n]
4) a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.
5) a binary function that determines how per-particle parameters are to be combined
6) a binary function as well as a default set of parameters as in 2) or 4).
If the binary function is unspecified then it is taken to be the average of the two per-particle parameters. If species information is provided then the parameters should be specified as either
1) a scalar
2) an ndarray of shape [max_species, max_species]
3) a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.
- Return type
Callable[…, Array]
- Returns
A function fn_mapped.
If species is None or statically specified, then fn_mapped takes as arguments an ndarray of positions of shape [n, spatial_dimension].
If species is dynamic, then fn_mapped takes as input an ndarray of shape [n, spatial_dimension], an integer ndarray of species of shape [n], and an integer specifying the maximum species.
The mapped function can also optionally take keyword arguments that get threaded through the metric.
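The all-pairs promotion and the role of reduce_axis can be sketched with NumPy. This is a hedged illustration, not the library's code: pair_sketch and soft_sphere are hypothetical names, and masking the diagonal and halving the sum (so each pair is counted once) are assumptions of the sketch.

```python
import numpy as np

def metric(Ra, Rb):
    """Euclidean distance between two positions of shape [spatial_dimension]."""
    return np.linalg.norm(Ra - Rb)

def soft_sphere(dr, sigma=1.0):
    """Hypothetical pair energy: nonzero only when particles overlap."""
    return np.where(dr < sigma, (1.0 - dr / sigma) ** 2, 0.0)

def pair_sketch(fn, metric, reduce_axis=None, keepdims=False, **kwargs):
    """Hypothetical stand-in: evaluate fn on all pairs of a system."""
    def fn_mapped(R):
        n = R.shape[0]
        # Distances between every pair of particles, shape [n, n].
        dr = np.array([[metric(R[i], R[j]) for j in range(n)] for i in range(n)])
        out = fn(dr, **kwargs)
        # Assumption: self-interactions on the diagonal are discarded.
        out = np.where(np.eye(n, dtype=bool), 0.0, out)
        # Assumption: halve the sum because (i, j) and (j, i) both appear.
        return 0.5 * np.sum(out, axis=reduce_axis, keepdims=keepdims)
    return fn_mapped

R = np.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
total = pair_sketch(soft_sphere, metric, sigma=1.0)(R)
per_particle = pair_sketch(soft_sphere, metric, reduce_axis=(1,), sigma=1.0)(R)
print(total, per_particle)
```

With reduce_axis=None the result is a single number for the system; with reduce_axis=(1,) only the partner axis is summed, leaving one value per particle.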
- jax_md.smap.pair_neighbor_list(fn, displacement_or_metric, species=None, reduce_axis=None, ignore_unused_parameters=False, **kwargs)
Promotes a function acting on pairs of particles to use neighbor lists.
- Parameters
fn (Callable[…, Array]) – A function that takes an ndarray of pairwise distances or displacements of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function. fn returns an ndarray of evaluations of shape [n, m, d_out].
displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively. The metric can optionally take a floating point time as a third argument.
species (Optional[Array]) – Species information for the different particles. Should be either None (in which case it is assumed that all the particles have the same species) or an integer array of shape [n] with species data. Note that species data can be specified dynamically by passing a species keyword argument to the mapped function.
reduce_axis (Optional[Tuple[int, …]]) – A list of axes to reduce over. We use a convention where axis 0 corresponds to the particles, axis 1 corresponds to neighbors, and the remaining axes correspond to the output axes of fn. Note that it is not well-defined to sum over particles without summing over neighbors. One also cannot report per-particle values (excluding axis 0) for neighbor lists whose format is OrderedSparse.
ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.pair_neighbor_list(...).
kwargs – Arguments providing parameters to the mapped function. In cases where no species information is provided these should be either
1) a scalar
2) an ndarray of shape [n]
3) an ndarray of shape [n, n]
4) a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.
5) a binary function that determines how per-particle parameters are to be combined
If the binary function is unspecified then it is taken to be the average of the two per-particle parameters. If species information is provided then the parameters should be specified as either
1) a scalar
2) an ndarray of shape [max_species, max_species]
3) a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.
- Return type
Callable[…, Array]
- Returns
A function fn_mapped that takes an ndarray of floats of shape [N, d_in] of positions and an ndarray of integers of shape [N, max_neighbors] specifying neighbors.
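Evaluating a pair function over an [N, max_neighbors] index array can be sketched as follows. This is an assumption-laden illustration rather than jax_md's implementation: pair_neighbor_list_sketch and soft_sphere are hypothetical names, and the sketch assumes that unused neighbor slots are padded with the value N and that each pair appears in both particles' rows.

```python
import numpy as np

def metric(Ra, Rb):
    """Euclidean distance between two positions of shape [spatial_dimension]."""
    return np.linalg.norm(Ra - Rb)

def soft_sphere(dr, sigma=1.0):
    """Hypothetical pair energy: nonzero only when particles overlap."""
    return np.where(dr < sigma, (1.0 - dr / sigma) ** 2, 0.0)

def pair_neighbor_list_sketch(fn, metric, reduce_axis=None, **kwargs):
    """Hypothetical stand-in: evaluate fn only on listed neighbors."""
    def fn_mapped(R, idx):
        N = R.shape[0]
        # Assumption: entries equal to N mark empty (padding) slots.
        mask = idx < N
        safe_idx = np.where(mask, idx, 0)
        # Axis 0 runs over particles, axis 1 over their neighbors.
        dr = np.array([[metric(R[i], R[j]) for j in safe_idx[i]]
                       for i in range(N)])
        out = np.where(mask, fn(dr, **kwargs), 0.0)
        # Assumption: halve the sum because each pair is listed twice.
        return 0.5 * np.sum(out, axis=reduce_axis)
    return fn_mapped

R = np.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
# Particles 0 and 1 are neighbors; particle 2 has none. N = 3 is padding.
idx = np.array([[1, 3], [0, 3], [3, 3]])
total = pair_neighbor_list_sketch(soft_sphere, metric, sigma=1.0)(R, idx)
per_particle = pair_neighbor_list_sketch(
    soft_sphere, metric, reduce_axis=(1,), sigma=1.0)(R, idx)
print(total, per_particle)
```

Summing over axis 1 alone gives per-particle values, which matches the convention that axis 0 indexes particles and axis 1 indexes neighbors.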
- jax_md.smap.triplet(fn, displacement_or_metric, species=None, reduce_axis=None, keepdims=False, ignore_unused_parameters=False, **kwargs)
Promotes a function that acts on triples of particles to one on a system.
Many empirical potentials in jax_md include three-body angular terms (e.g. Stillinger-Weber). This utility function simplifies the loss computation in such cases by converting a function that takes in two pairwise displacements or distances to one that only requires the system as input.
- Parameters
fn (Callable[…, Array]) – A function that takes an ndarray of two distances or displacements from a central atom, both of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function.
displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively.
species (Optional[Array]) – A list of species for the different particles. This should be either None (in which case it is assumed that all the particles have the same species), an integer ndarray of shape [n] with species data, or an integer, in which case the species data will be specified dynamically with species giving the maximum number of types of particles. Note that dynamic species specification is less efficient, because we cannot specialize shape information.
reduce_axis (Optional[Tuple[int, …]]) – A list of axes to reduce over. This is supplied to np.sum and so the same convention is used.
keepdims (bool) – A boolean specifying whether the empty dimensions should be kept upon reduction. This is supplied to np.sum and so the same convention is used.
ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.triplet(...).
kwargs – Arguments providing parameters to the mapped function. In cases where no species information is provided, these should be either
1) a scalar
2) an ndarray of shape [n] based on the central atom
3) an ndarray of shape [n, n, n] defining triplet interactions.
If species information is provided, then the parameters should be specified as either
1) a scalar
2) an ndarray of shape [max_species]
3) an ndarray of shape [max_species, max_species, max_species] defining triplet interactions.
- Return type
Callable[…, Array]
- Returns
A function fn_mapped.
If species is None or statically specified, then fn_mapped takes as arguments an ndarray of positions of shape [n, spatial_dimension].
If species is dynamic, then fn_mapped takes as input an ndarray of shape [n, spatial_dimension], an integer ndarray of species of shape [n], and an integer specifying the maximum species.
The mapped function can also optionally take keyword arguments that get threaded through the metric.
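The idea of turning a function of two displacements from a central atom into a function of the whole system can be sketched with explicit loops. This is a deliberately unvectorized illustration, not jax_md's implementation: triplet_sketch and angle_energy are hypothetical names, and skipping repeated indices and halving the sum (since (j, k) and (k, j) describe the same angle) are assumptions of the sketch.

```python
import numpy as np

def displacement(Ra, Rb):
    """Displacement vector between two positions of shape [spatial_dimension]."""
    return Ra - Rb

def angle_energy(dR_ij, dR_ik):
    """Hypothetical three-body term: squared cosine of the angle at the central atom."""
    cos_theta = np.dot(dR_ij, dR_ik) / (
        np.linalg.norm(dR_ij) * np.linalg.norm(dR_ik))
    return cos_theta ** 2

def triplet_sketch(fn, displacement):
    """Hypothetical stand-in: sum fn over all triples with central atom i."""
    def fn_mapped(R):
        n = R.shape[0]
        total = 0.0
        for i in range(n):          # central atom
            for j in range(n):      # first partner
                for k in range(n):  # second partner
                    # Assumption: a triple needs three distinct particles.
                    if len({i, j, k}) < 3:
                        continue
                    total += fn(displacement(R[j], R[i]),
                                displacement(R[k], R[i]))
        # Assumption: each angle is visited twice, as (j, k) and (k, j).
        return 0.5 * total
    return fn_mapped

# Right triangle: the angle at particle 0 is 90 degrees, the others 45 degrees.
R = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
energy = triplet_sketch(angle_energy, displacement)(R)
print(energy)
```

The caller supplies only positions; the wrapper builds the displacements that the three-body function consumes.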