Higher Order Functions

Code to transform functions that act on individual tuples of particles into functions that act on sets of particles.

class jax_md.smap.ParameterTree(tree, mapping)

A container denoting that parameters are in the form of a PyTree.

tree

A JAX PyTree containing a tree of parameters. Before being fed into mapped functions, these parameters are processed according to the mapping.

Type

Any

mapping

A ParameterTreeMapping object that specifies how the parameters are processed.

Type

jax_md.smap.ParameterTreeMapping

class jax_md.smap.ParameterTreeMapping(value)

An enum specifying how parameters are processed in mapped functions.

Global

Global parameters are passed directly to the mapped function.

PerParticle

PerParticle parameters are combined in pairs based on the particle index, e.g. p_ij = combinator(p_i, p_j) for particles i and j. These parameters are expected to have a leading axis whose length equals the number of particles.

PerBond

PerBond parameters are expected to have two leading dimensions, each equal to the number of particles in the system.

PerSpecies

PerSpecies parameters are expected to have two leading dimensions equal to the number of species. For particles of species s_i and s_j, parameters are combined according to p_ij = combinator(p[s_i], p[s_j]).
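As a rough illustration of the PerParticle rule, the sketch below applies p_ij = combinator(p_i, p_j) to every leaf of a parameter PyTree using jax.tree_util.tree_map. The helper name combine_per_particle and the averaging default are assumptions made for illustration (averaging matches the default combinator described for smap.pair); this is not the library's internal code.

```python
import jax
import jax.numpy as jnp


def combine_per_particle(tree, combinator=lambda a, b: 0.5 * (a + b)):
    """Hypothetical helper: map per-particle leaves to pairwise leaves.

    Every leaf of `tree` is assumed to have a leading axis of length n
    (one entry per particle). Each output leaf has leading axes [n, n],
    with entry [i, j] equal to combinator(p_i, p_j).
    """
    def combine_leaf(p):
        # p[:, None] varies along axis 0 (particle i), p[None, :] along axis 1 (j).
        return combinator(p[:, None, ...], p[None, :, ...])

    return jax.tree_util.tree_map(combine_leaf, tree)


# A PyTree (here a dict) of per-particle parameters for n = 3 particles.
params = {"sigma": jnp.array([1.0, 2.0, 3.0]),
          "epsilon": jnp.array([1.0, 1.0, 4.0])}
pairwise = combine_per_particle(params)
sigma_02 = pairwise["sigma"][0, 2]  # 2.0, the average of 1.0 and 3.0
```

Because the combination is applied leaf by leaf, any PyTree structure (nested dicts, tuples) is handled the same way, and leaves with trailing dimensions keep them after the two new leading axes.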

jax_md.smap.bond(fn, displacement_or_metric, static_bonds=None, static_bond_types=None, ignore_unused_parameters=False, **kwargs)

Promotes a function that acts on a single pair to one on a set of bonds.

TODO(schsam): Bonds may have poor memory-access patterns; this is worth revisiting and potentially optimizing.

Parameters
  • fn (Callable[…, Array]) – A function that takes an ndarray of pairwise distances or displacements, of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function. fn returns an ndarray of evaluations of shape [n, m, d_out].

  • displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively. The metric can optionally take a floating point time as a third argument.

  • static_bonds (Optional[Array]) – An ndarray of integer pairs with shape [b, 2] where each pair specifies a bond. static_bonds are baked into the returned compute function statically and cannot be changed after the fact.

  • static_bond_types (Optional[Array]) – An ndarray of integers of shape [b] specifying the type of each bond. Only specify bond types if you want to specify bond parameters by type. One can also specify constant or per-bond parameters (see below).

  • ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.bond(...).

  • kwargs

    Arguments providing parameters to the mapped function. In cases where no bond type information is provided, these should be either

    1. a scalar

    2. an ndarray of shape [b].

    If bond type information is provided, then the parameters should be specified as either

    1. a scalar

    2. an ndarray of shape [max_bond_type].

    3. a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.

Return type

Callable[…, Array]

Returns

A function fn_mapped. Note that fn_mapped can optionally take bonds and bond_types arguments, which specify bonds dynamically. This incurs a recompilation whenever the number of bonds changes; improving this is left as a TODO until someone uses this feature and runs into speed issues.
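To make the bond mapping concrete, here is a minimal sketch of what promoting a single-pair function to a set of bonds amounts to: gather the endpoint positions of each bond, apply the displacement function, evaluate the pair function with per-bond parameters, and sum. The names bond_energy and harmonic_bond, the free-space displacement, and the per-type lookup k_by_type[bond_types] are illustrative assumptions rather than jax_md's implementation.

```python
import jax
import jax.numpy as jnp


def displacement(Ra, Rb):
    # Free-space displacement; assumed stand-in for a jax_md displacement function.
    return Ra - Rb


def harmonic_bond(dr, k=1.0, r0=1.0):
    # Simple spring energy evaluated on bond lengths.
    return 0.5 * k * (dr - r0) ** 2


def bond_energy(R, bonds, bond_types, k_by_type, r0=1.0):
    """Hypothetical sketch of mapping a pair function over a set of bonds.

    R:          [n, spatial_dimension] positions.
    bonds:      [b, 2] integer particle indices.
    bond_types: [b] integer bond types.
    k_by_type:  [max_bond_type] stiffness values, looked up per bond.
    """
    Ra = R[bonds[:, 0]]
    Rb = R[bonds[:, 1]]
    dR = jax.vmap(displacement)(Ra, Rb)   # [b, spatial_dimension]
    dr = jnp.linalg.norm(dR, axis=-1)     # [b] bond lengths
    k = k_by_type[bond_types]             # per-type -> per-bond parameters
    return jnp.sum(harmonic_bond(dr, k=k, r0=r0))


R = jnp.array([[0.0, 0.0], [2.0, 0.0], [2.0, 1.0]])
bonds = jnp.array([[0, 1], [1, 2]])
bond_types = jnp.array([0, 1])
k_by_type = jnp.array([1.0, 10.0])
E = bond_energy(R, bonds, bond_types, k_by_type)  # 0.5 * 1.0 * (2 - 1)**2 + 0.0
```

Passing bonds as a traced argument, as in this sketch, means the shapes of bonds and bond_types fix the compiled program, which is consistent with the recompilation behavior noted above when the number of bonds changes.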

jax_md.smap.pair(fn, displacement_or_metric, species=None, reduce_axis=None, keepdims=False, ignore_unused_parameters=False, **kwargs)

Promotes a function that acts on a pair of particles to one on a system.

Parameters
  • fn (Callable[…, Array]) – A function that takes an ndarray of pairwise distances or displacements, of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function. fn returns an ndarray of evaluations of shape [n, m, d_out].

  • displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively. The metric can optionally take a floating point time as a third argument.

  • species (Optional[Array]) – A list of species for the different particles. This should be either None (in which case all particles are assumed to have the same species), an integer ndarray of shape [n] with species data, or an integer, in which case the species data is specified dynamically and species gives the maximum number of particle types. Note that dynamic species specification is less efficient, because shape information cannot be specialized.

  • reduce_axis (Optional[Tuple[int, …]]) – A list of axes to reduce over. This is supplied to jnp.sum and so the same convention is used.

  • keepdims (bool) – A boolean specifying whether reduced axes are kept as dimensions of size one. This is supplied to jnp.sum and so the same convention is used.

  • ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.pair(...).

  • kwargs

    Arguments providing parameters to the mapped function. In cases where no species information is provided, these should be either

    1. a scalar

    2. an ndarray of shape [n]

    3. an ndarray of shape [n, n]

    4. a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.

    5. a binary function that determines how per-particle parameters are to be combined

    6. a binary function as well as a default set of parameters as in item 2 or 4.

    If no combinator is specified, per-particle parameters are combined by averaging the two per-particle values.

    If species information is provided, then the parameters should be specified as either

    1. a scalar

    2. an ndarray of shape [max_species, max_species]

    3. a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.
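The parameter forms listed above can be summarized by a small sketch that expands each accepted form into an [n, n] matrix of pairwise values. The helper pair_parameter is hypothetical and omits the ParameterTree and combinator-with-defaults forms; it only illustrates the shape conventions and the default averaging rule.

```python
import jax.numpy as jnp


def pair_parameter(p, n, species=None, combinator=lambda a, b: 0.5 * (a + b)):
    """Hypothetical helper: expand a parameter to an [n, n] pairwise matrix.

    Without species: p may be a scalar, an [n] per-particle array (combined
    with `combinator`, averaging by default), or an [n, n] array used as-is.
    With species: p may be a scalar or a [max_species, max_species] table.
    """
    p = jnp.asarray(p)
    if p.ndim == 0:
        # A scalar applies to every pair.
        return jnp.full((n, n), p)
    if species is not None:
        # Look up the table entry for every (species_i, species_j) pair.
        return p[species[:, None], species[None, :]]
    if p.ndim == 1:
        # Per-particle values are combined pairwise.
        return combinator(p[:, None], p[None, :])
    # Already an [n, n] pairwise array.
    return p


table = jnp.array([[1.0, 1.5], [1.5, 2.0]])  # [max_species, max_species]
species = jnp.array([0, 0, 1])
by_species = pair_parameter(table, 3, species=species)
by_particle = pair_parameter(jnp.array([1.0, 3.0]), 2)
```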

Return type

Callable[…, Array]

Returns

A function fn_mapped.

If species is None or statically specified, then fn_mapped takes as its argument an ndarray of positions of shape [n, spatial_dimension].

If species is dynamic, then fn_mapped takes as input an ndarray of positions of shape [n, spatial_dimension], an integer ndarray of species of shape [n], and an integer specifying the maximum number of species.

The mapped function can also optionally take keyword arguments that are passed through to the metric.
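A rough sketch of what the promoted function computes, assuming free space, no species, and a scalar parameter: build all n × n displacements with nested jax.vmap, evaluate the pair function, mask out self-interactions, and reduce with jnp.sum using reduce_axis. The names pair_energy and soft_sphere are hypothetical, and halving the sum so that each unordered pair is counted once is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def displacement(Ra, Rb):
    # Free-space displacement; stand-in for a jax_md displacement function.
    return Ra - Rb


def soft_sphere(dr, sigma=1.0, epsilon=1.0):
    # Repulsive only when particles overlap (dr < sigma).
    return jnp.where(dr < sigma, 0.5 * epsilon * (1.0 - dr / sigma) ** 2, 0.0)


def pair_energy(R, sigma=1.0, reduce_axis=None):
    """Hypothetical sketch: evaluate a pair function on all n x n pairs."""
    n = R.shape[0]
    # Nested vmap builds dR[i, j] = R[i] - R[j], shape [n, n, spatial_dimension].
    dR = jax.vmap(jax.vmap(displacement, (None, 0)), (0, None))(R, R)
    # Add 1 on the diagonal so sqrt stays differentiable at i == j; masked below.
    dr = jnp.sqrt(jnp.sum(dR ** 2, axis=-1) + jnp.eye(n))
    e = soft_sphere(dr, sigma=sigma) * (1.0 - jnp.eye(n))
    # Each unordered pair appears twice in the n x n matrix, hence the 0.5.
    return 0.5 * jnp.sum(e, axis=reduce_axis)


R = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
E = pair_energy(R)                    # only particles 0 and 1 overlap
E_i = pair_energy(R, reduce_axis=1)   # per-particle energies, shape [3]
```

With reduce_axis=None the whole matrix is summed to a scalar; reducing over axis 1 only yields one value per particle, mirroring how reduce_axis is handed to jnp.sum.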

jax_md.smap.pair_neighbor_list(fn, displacement_or_metric, species=None, reduce_axis=None, ignore_unused_parameters=False, **kwargs)

Promotes a function acting on pairs of particles to use neighbor lists.

Parameters
  • fn (Callable[…, Array]) – A function that takes an ndarray of pairwise distances or displacements, of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function. fn returns an ndarray of evaluations of shape [n, m, d_out].

  • displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively. The metric can optionally take a floating point time as a third argument.

  • species (Optional[Array]) – Species information for the different particles. This should be either None (in which case all particles are assumed to have the same species) or an integer array of shape [n] with species data. Note that species data can also be specified dynamically by passing a species keyword argument to the mapped function.

  • reduce_axis (Optional[Tuple[int, …]]) – A list of axes to reduce over. We use a convention where axis 0 corresponds to particles, axis 1 corresponds to neighbors, and the remaining axes correspond to the output axes of fn. Note that summing over particles without also summing over neighbors is not well-defined. Per-particle values (i.e. reductions that exclude axis 0) also cannot be reported for neighbor lists whose format is OrderedSparse.

  • ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.pair_neighbor_list(...).

  • kwargs

    Arguments providing parameters to the mapped function. In cases where no species information is provided, these should be either

    1. a scalar

    2. an ndarray of shape [n]

    3. an ndarray of shape [n, n]

    4. a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.

    5. a binary function that determines how per-particle parameters are to be combined

    If no combinator is specified, per-particle parameters are combined by averaging the two per-particle values.

    If species information is provided, then the parameters should be specified as either

    1. a scalar

    2. an ndarray of shape [max_species, max_species]

    3. a ParameterTree containing a PyTree of parameters and a mapping. See ParameterTree for details.

Return type

Callable[…, Array]

Returns

A function fn_mapped that takes an ndarray of positions (floats of shape [N, spatial_dimension]) and an ndarray of integers of shape [N, max_neighbors] specifying neighbors.
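The following sketch mirrors the reduction convention described above (axis 0 for particles, axis 1 for neighbors) for a neighbor list stored densely as an [N, max_neighbors] index array. It assumes free space, that unused neighbor slots are padded with the value N, and that each pair is listed from both sides (hence the factor 0.5); neighbor_energy is a hypothetical name, not part of jax_md.

```python
import jax
import jax.numpy as jnp


def displacement(Ra, Rb):
    # Free-space displacement; stand-in for a jax_md displacement function.
    return Ra - Rb


def soft_sphere(dr, sigma=1.0):
    # Repulsive only when particles overlap (dr < sigma).
    return jnp.where(dr < sigma, 0.5 * (1.0 - dr / sigma) ** 2, 0.0)


def neighbor_energy(R, idx, sigma=1.0, reduce_axis=None):
    """Hypothetical sketch of a pair function evaluated on a dense neighbor list.

    idx has shape [N, max_neighbors]; empty slots are assumed to hold N.
    """
    N = R.shape[0]
    mask = idx < N                          # True for real neighbors
    safe_idx = jnp.where(mask, idx, 0)      # keep the gather in bounds
    R_neigh = R[safe_idx]                   # [N, max_neighbors, dim]
    # Inner vmap runs over neighbors, outer vmap over (particle, its neighbors).
    dR = jax.vmap(jax.vmap(displacement, (None, 0)), (0, 0))(R, R_neigh)
    # Padded slots get +1 so sqrt is never taken at zero; they are masked below.
    dr = jnp.sqrt(jnp.sum(dR ** 2, axis=-1) + jnp.where(mask, 0.0, 1.0))
    e = jnp.where(mask, soft_sphere(dr, sigma=sigma), 0.0)  # [N, max_neighbors]
    return 0.5 * jnp.sum(e, axis=reduce_axis)


R = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
idx = jnp.array([[1, 3], [0, 3], [3, 3]])      # N = 3 marks an empty slot
E = neighbor_energy(R, idx)                    # total energy
E_i = neighbor_energy(R, idx, reduce_axis=1)   # per-particle, sums neighbors
```

Summing only over axis 1 gives per-particle values, while summing over axis 0 alone would mix neighbor slots of different particles, which is why that reduction is not well-defined.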

jax_md.smap.triplet(fn, displacement_or_metric, species=None, reduce_axis=None, keepdims=False, ignore_unused_parameters=False, **kwargs)

Promotes a function that acts on triples of particles to one on a system.

Many empirical potentials in jax_md include three-body angular terms (e.g. Stillinger-Weber). This utility function simplifies the computation of such terms by converting a function that takes two pairwise displacements or distances into one that requires only the system as input.
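To illustrate the idea, the sketch below promotes a function of two displacements from a central atom, fn(dR_ij, dR_ik), to a function of the whole system by nesting jax.vmap over the central atom i and its partners j and k, then masking out terms where any two indices coincide. The angular term (a Stillinger-Weber-style penalty on cos θ + 1/3), the name triplet_energy, free-space displacements, and the factor 0.5 that counts each (j, k) pair once are assumptions of this sketch, not jax_md's implementation.

```python
import jax
import jax.numpy as jnp


def displacement(Ra, Rb):
    # Free-space displacement; stand-in for a jax_md displacement function.
    return Ra - Rb


def angle_term(dR_ij, dR_ik, lam=1.0):
    # Stillinger-Weber-style penalty on the angle between two bonds at atom i.
    cos_theta = jnp.dot(dR_ij, dR_ik) / (
        jnp.linalg.norm(dR_ij) * jnp.linalg.norm(dR_ik))
    return lam * (cos_theta + 1.0 / 3.0) ** 2


def triplet_energy(R, lam=1.0):
    """Hypothetical sketch: sum angle_term(dR_ij, dR_ik) over distinct i, j, k."""
    n = R.shape[0]
    eye = jnp.eye(n, dtype=bool)
    # dR[i, j] = displacement(R[i], R[j]), shape [n, n, spatial_dimension].
    dR = jax.vmap(jax.vmap(displacement, (None, 0)), (0, None))(R, R)
    # Replace zero self-displacements so the norms are nonzero; masked below.
    dR = jnp.where(eye[:, :, None], 1.0, dR)
    # Nest vmap over k, then j, then the central atom i -> shape [n, n, n].
    over_k = jax.vmap(angle_term, (None, 0, None))
    over_jk = jax.vmap(over_k, (0, None, None))
    over_ijk = jax.vmap(over_jk, (0, 0, None))
    e = over_ijk(dR, dR, lam)
    # Keep only terms with j != i, k != i and j != k.
    valid = ~eye[:, :, None] & ~eye[:, None, :] & ~eye[None, :, :]
    # (j, k) and (k, j) describe the same angle, so halve the sum.
    return 0.5 * jnp.sum(jnp.where(valid, e, 0.0))


# Right angle at particle 0, 45-degree angles at particles 1 and 2.
R = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
E = triplet_energy(R)
```

The dense [n, n, n] tensor keeps the sketch short but scales cubically with the number of particles; it is meant only to show how a two-displacement function is lifted to the whole system.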

Parameters
  • fn (Callable[…, Array]) – A function that takes two ndarrays of distances or displacements from a central atom, each of shape [n, m] or [n, m, d_in] respectively, as well as kwargs specifying parameters for the function.

  • displacement_or_metric – A function that takes two ndarrays of positions, each of shape [spatial_dimension], and returns an ndarray of distances or displacements of shape [] or [d_in] respectively.

  • species (Optional[Array]) – A list of species for the different particles. This should be either None (in which case all particles are assumed to have the same species), an integer ndarray of shape [n] with species data, or an integer, in which case the species data is specified dynamically and species gives the maximum number of particle types. Note that dynamic species specification is less efficient, because shape information cannot be specialized.

  • reduce_axis (Optional[Tuple[int, …]]) – A list of axes to reduce over. This is supplied to jnp.sum and so the same convention is used.

  • keepdims (bool) – A boolean specifying whether reduced axes are kept as dimensions of size one. This is supplied to jnp.sum and so the same convention is used.

  • ignore_unused_parameters (bool) – A boolean that denotes whether dynamically specified keyword arguments passed to the mapped function get ignored if they were not first specified as keyword arguments when calling smap.triplet(...).

  • kwargs

    Arguments providing parameters to the mapped function. In cases where no species information is provided, these should be either

    1. a scalar

    2. an ndarray of shape [n] based on the central atom

    3. an ndarray of shape [n, n, n] defining triplet interactions.

    If species information is provided, then the parameters should be specified as either

    1. a scalar

    2. an ndarray of shape [max_species]

    3. an ndarray of shape [max_species, max_species, max_species] defining triplet interactions.

Return type

Callable[…, Array]

Returns

A function fn_mapped.

If species is None or statically specified, then fn_mapped takes as arguments an ndarray of positions of shape [n, spatial_dimension].

If species is dynamic, then fn_mapped takes as input an ndarray of positions of shape [n, spatial_dimension], an integer ndarray of species of shape [n], and an integer specifying the maximum number of species.

The mapped function can also optionally take keyword arguments that are passed through to the metric.