# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spaces in which particles are simulated.
Spaces are pairs of functions containing:
`displacement_fn(Ra, Rb, **kwargs)`:
Computes displacements between pairs of particles. `Ra` and `Rb` should
be ndarrays of shape `[spatial_dim]`. Returns an ndarray of shape `[spatial_dim]`.
To compute the displacement over more than one particle at a time see the
:meth:`map_product`, :meth:`map_bond`, and :meth:`map_neighbor` functions.
`shift_fn(R, dR, **kwargs)`:
Moves points at position `R` by an amount `dR`.
Spaces can accept keyword arguments allowing the space to be changed over the
course of a simulation. For an example of this use see :meth:`periodic_general`.
Although displacement functions compute the displacement between two
points, it is often useful to compute displacements between multiple particles
in a vectorized fashion. To do this we provide three functions: `map_product`,
`map_bond`, and `map_neighbor`:
map_product:
Computes displacements between all pairs of points such that if
`Ra` has shape `[n, spatial_dim]` and `Rb` has shape `[m, spatial_dim]` then the
output has shape `[m, n, spatial_dim]`, with entry `[j, i]` holding the
displacement between `Ra[i]` and `Rb[j]`.
map_bond:
Computes displacements between corresponding points in two lists such that
if `Ra` and `Rb` both have shape `[n, spatial_dim]` then the output has shape
`[n, spatial_dim]`.
map_neighbor:
Computes displacements between points and all of their
neighbors such that if `Ra` has shape `[n, spatial_dim]` and `Rb` has shape
`[n, neighbors, spatial_dim]` then the output has shape
`[n, neighbors, spatial_dim]`.
"""
from typing import Callable, Union, Tuple, Any, Optional
from jax.core import ShapedArray
from jax import eval_shape
from jax import vmap
from jax import custom_jvp
import jax
import jax.numpy as jnp
from jax_md.util import Array
from jax_md.util import f32
from jax_md.util import f64
from jax_md.util import safe_mask
# Types
# Callable taking two position arrays and returning a displacement array.
DisplacementFn = Callable[[Array, Array], Array]
# Callable taking two position arrays and returning a scalar distance.
MetricFn = Callable[[Array, Array], float]
# Accepted wherever either a displacement or a metric may be supplied.
DisplacementOrMetricFn = Union[DisplacementFn, MetricFn]
# Callable taking positions and an offset and returning the moved positions.
ShiftFn = Callable[[Array, Array], Array]
# A space is the pair `(displacement_fn, shift_fn)`.
Space = Tuple[DisplacementFn, ShiftFn]
# Affine transformation describing a simulation cell: the code in this module
# accepts a scalar, a vector, or a matrix.
Box = Array
# Exceptions
class UnexpectedBoxException(Exception):
  """Raised when a `box` keyword is passed to a space that does not take one."""
# Primitive Spatial Transforms
def inverse(box: Box) -> Box:
  """Invert an affine transformation given as a scalar, vector, or matrix.

  Args:
    box: The transformation to invert. A scalar, any single-element array, or
      a 1-D array is inverted elementwise; a 2-D array is inverted as a matrix.

  Returns:
    The inverse transformation.

  Raises:
    ValueError: If `box` has more than one element and is neither 1-D nor 2-D.
  """
  # Scalars and per-axis scale vectors share the same elementwise reciprocal.
  elementwise = jnp.isscalar(box) or box.size == 1 or box.ndim == 1
  if elementwise:
    return 1 / box
  if box.ndim == 2:
    return jnp.linalg.inv(box)
  raise ValueError(('Box must be either: a scalar, a vector, or a matrix. '
                    f'Found {box}.'))
def _get_free_indices(n: int) -> str:
return ''.join([chr(ord('a') + i) for i in range(n)])
def raw_transform(box: Box, R: Array) -> Array:
  """Apply an affine transformation to positions.

  See `periodic_general` for a description of the semantics of `box`.

  Args:
    box: An affine transformation described in `periodic_general`. A scalar
      (or single-element array) scales every coordinate, a 1-D array scales
      each coordinate axis separately, and a 2-D array is applied as a matrix
      to the last axis of `R`.
    R: Array of positions. Should have shape `(..., spatial_dimension)`.

  Returns:
    A transformed array positions of shape `(..., spatial_dimension)`.

  Raises:
    ValueError: If `box` has more than one element and is neither 1-D nor 2-D.
  """
  if jnp.isscalar(box) or box.size == 1:
    return R * box
  elif box.ndim == 1:
    # Ellipsis subscripts cover any number of leading axes, so the batch
    # labels can never clash with the explicit 'i' subscript.
    return jnp.einsum('i,...i->...i', box, R)
  elif box.ndim == 2:
    # out[..., i] = sum_j box[i, j] * R[..., j]
    return jnp.einsum('ij,...j->...i', box, R)
  raise ValueError(('Box must be either: a scalar, a vector, or a matrix. '
                    f'Found {box}.'))
@custom_jvp
def transform(box: Box, R: Array) -> Array:
  """Apply an affine transformation to positions.

  Computes the same value as `raw_transform`, but carries the custom JVP rule
  registered by `transform_jvp`.

  See `periodic_general` for a description of the semantics of `box`.

  Args:
    box: An affine transformation described in `periodic_general`.
    R: Array of positions. Should have shape `(..., spatial_dimension)`.

  Returns:
    A transformed array positions of shape `(..., spatial_dimension)`.
  """
  return raw_transform(box, R)


@transform.defjvp
def transform_jvp(primals, tangents):
  """Custom JVP rule for `transform`.

  Args:
    primals: The pair `(box, R)` at which `transform` is evaluated.
    tangents: The pair `(dbox, dR)` of tangents for `box` and `R`.

  Returns:
    A `(primal_out, tangent_out)` pair. The position tangent `dR` is passed
    through unchanged (it is not multiplied by `box`), and the box tangent
    contributes `transform(dbox, R)`.
  """
  box, R = primals
  dbox, dR = tangents
  return (transform(box, R), dR + transform(dbox, R))
def pairwise_displacement(Ra: Array, Rb: Array) -> Array:
  """Compute the displacement `Ra - Rb` between two position vectors.

  Args:
    Ra: Vector of positions; `ndarray(shape=[spatial_dim])`.
    Rb: Vector of positions; `ndarray(shape=[spatial_dim])`.

  Returns:
    Vector of displacements; `ndarray(shape=[spatial_dim])`.

  Raises:
    ValueError: If `Ra` is not one-dimensional, or if the shapes of `Ra` and
      `Rb` differ.
  """
  rank = len(Ra.shape)
  if rank != 1:
    raise ValueError(
        'Can only compute displacements between vectors. To compute '
        'displacements between sets of vectors use vmap or TODO.')

  if Ra.shape != Rb.shape:
    raise ValueError(
        'Can only compute displacement between vectors of equal dimension.')

  return Ra - Rb
def periodic_displacement(side: Box, dR: Array) -> Array:
  """Wraps displacement vectors into a hypercube.

  Each component is shifted by half a side length, reduced modulo `side`, and
  shifted back.

  Args:
    side: Specification of hypercube size. Either,
      (a) float if all sides have equal length.
      (b) ndarray(spatial_dim) if sides have different lengths.
    dR: Matrix of displacements; `ndarray(shape=[..., spatial_dim])`.

  Returns:
    Matrix of wrapped displacements; `ndarray(shape=[..., spatial_dim])`.
  """
  half_side = side * f32(0.5)
  wrapped = jnp.mod(dR + half_side, side)
  return wrapped - half_side
def square_distance(dR: Array) -> Array:
  """Computes square distances.

  Args:
    dR: Matrix of displacements; `ndarray(shape=[..., spatial_dim])`.

  Returns:
    Matrix of squared distances; `ndarray(shape=[...])`, obtained by summing
    the squared components over the last axis.
  """
  squared_components = jnp.square(dR)
  return jnp.sum(squared_components, axis=-1)
def distance(dR: Array) -> Array:
  """Computes distances.

  Args:
    dR: Matrix of displacements; `ndarray(shape=[..., spatial_dim])`.

  Returns:
    Matrix of distances; `ndarray(shape=[...])`.
  """
  squared = square_distance(dR)
  # The square root is applied through `safe_mask` with the mask `squared > 0`.
  # NOTE(review): presumably this guards the sqrt at zero separation —
  # confirm against `jax_md.util.safe_mask`.
  is_positive = squared > 0
  return safe_mask(is_positive, jnp.sqrt, squared)
def periodic_shift(side: Box, R: Array, dR: Array) -> Array:
  """Shifts positions, wrapping them back within a periodic hypercube.

  Args:
    side: Side length(s) of the hypercube used as the modulus.
    R: Positions to move.
    dR: Offsets added to `R` before wrapping.

  Returns:
    `R + dR` reduced modulo `side`.
  """
  moved = R + dR
  return jnp.mod(moved, side)
""" Spaces """
def free() -> Space:
  """Free boundary conditions.

  Returns:
    `(displacement_fn, shift_fn)` tuple. The displacement is the plain
    difference `Ra - Rb`, optionally passed through `raw_transform` with the
    `perturbation` keyword; the shift is the plain sum `R + dR`. Any other
    keyword arguments are ignored.
  """
  def displacement_fn(Ra: Array, Rb: Array, perturbation: Optional[Array]=None,
                      **unused_kwargs) -> Array:
    displacement = pairwise_displacement(Ra, Rb)
    if perturbation is None:
      return displacement
    return raw_transform(perturbation, displacement)

  def shift_fn(R: Array, dR: Array, **unused_kwargs) -> Array:
    return R + dR

  return displacement_fn, shift_fn
def periodic(side: Box, wrapped: bool=True) -> Space:
  """Periodic boundary conditions on a hypercube of sidelength side.

  Args:
    side: Either a float or an ndarray of shape [spatial_dimension] specifying
      the size of each side of the periodic box.
    wrapped: A boolean specifying whether or not particle positions are
      remapped back into the box after each step

  Returns:
    `(displacement_fn, shift_fn)` tuple. Both functions raise
    `UnexpectedBoxException` if they are called with a `box` keyword argument.
  """
  def _reject_box(kwargs) -> None:
    # A `box` keyword is only meaningful for `periodic_general`; fail loudly
    # rather than silently ignoring it.
    if 'box' in kwargs:
      raise UnexpectedBoxException(('`space.periodic` does not accept a box '
                                    'argument. Perhaps you meant to use '
                                    '`space.periodic_general`?'))

  def displacement_fn(Ra: Array, Rb: Array,
                      perturbation: Optional[Array] = None,
                      **unused_kwargs) -> Array:
    _reject_box(unused_kwargs)
    displacement = periodic_displacement(side, pairwise_displacement(Ra, Rb))
    if perturbation is None:
      return displacement
    return raw_transform(perturbation, displacement)

  def shift_fn(R: Array, dR: Array, **unused_kwargs) -> Array:
    _reject_box(unused_kwargs)
    if wrapped:
      return periodic_shift(side, R, dR)
    return R + dR

  return displacement_fn, shift_fn
def periodic_general(box: Box,
                     fractional_coordinates: bool=True,
                     wrapped: bool=True) -> Space:
  r"""Periodic boundary conditions on a parallelepiped.

  This function defines a simulation on a parallelepiped, :math:`X`, formed by
  applying an affine transformation, :math:`T`, to the unit hypercube
  :math:`U = [0, 1]^d` along with periodic boundary conditions across all
  of the faces.

  Formally, the space is defined such that
  :math:`X = \{Tu : u \in [0, 1]^d\}`.

  The affine transformation, :math:`T`, can be specified in a number of
  different ways. For a parallelepiped that is: 1) a cube of side length
  :math:`L`, the affine transformation can simply be a scalar; 2) an
  orthorhombic unit cell can be specified by a vector `[Lx, Ly, Lz]` of lengths
  for each axis; 3) a general triclinic cell can be specified by an upper
  triangular matrix.

  There are a number of ways to parameterize a simulation on :math:`X`.
  `periodic_general` supports two parametrizations of :math:`X` that can be
  selected using the `fractional_coordinates` keyword argument.

  1) When `fractional_coordinates=True`, particle positions are stored in the
     unit cube, :math:`u\in U`. Here, the displacement function computes the
     displacement between :math:`x, y \in X` as
     :math:`d_X(x, y) = Td_U(u, v)` where :math:`d_U` is the displacement
     function on the unit cube, :math:`U`, :math:`x = Tu`, and :math:`y = Tv`
     with :math:`u, v \in U`. The derivative of the displacement function
     is defined so that derivatives live in :math:`X` (as opposed to being
     backpropagated to :math:`U`). The shift function, `shift_fn(R, dR)` is
     defined so that :math:`R` is expected to lie in :math:`U` while
     :math:`dR` should lie in :math:`X`. This combination enables code such as
     `shift_fn(R, force_fn(R))` to work as intended.

  2) When `fractional_coordinates=False`, particle positions are stored in
     the parallelepiped :math:`X`. Here, for :math:`x, y \in X`, the
     displacement function is defined as
     :math:`d_X(x, y) = Td_U(T^{-1}x, T^{-1}y)`. Since there is an
     extra multiplication by :math:`T^{-1}`, this parameterization is typically
     slower than `fractional_coordinates=True`. As in 1), the displacement
     function is defined to compute derivatives in :math:`X`. The shift
     function is defined so that :math:`R` and :math:`dR` should both lie in
     :math:`X`.

  Example:

  .. code-block:: python

     from jax import random
     side_length = 10.0
     disp_frac, shift_frac = periodic_general(side_length,
                                              fractional_coordinates=True)
     disp_real, shift_real = periodic_general(side_length,
                                              fractional_coordinates=False)

     # Instantiate random positions in both parameterizations.
     R_frac = random.uniform(random.PRNGKey(0), (4, 3))
     R_real = side_length * R_frac

     # Make some shift vectors.
     dR = random.normal(random.PRNGKey(0), (4, 3))

     disp_real(R_real[0], R_real[1]) == disp_frac(R_frac[0], R_frac[1])
     transform(side_length, shift_frac(R_frac, 1.0)) == shift_real(R_real, 1.0)

  It is often desirable to deform a simulation cell either: using a finite
  deformation during a simulation, or using an infinitesimal deformation while
  computing elastic constants. To do this using fractional coordinates, we can
  supply a new affine transformation as `displacement_fn(Ra, Rb, box=new_box)`.
  When using real coordinates, we can specify positions in a space :math:`X`
  defined by an affine transformation :math:`T` and compute displacements in a
  deformed space :math:`X'` defined by an affine transformation :math:`T'`.
  This is done by writing `displacement_fn(Ra, Rb, new_box=new_box)`.

  There are a few caveats when using `periodic_general`. `periodic_general`
  uses the minimum image convention, and so it will fail for potentials whose
  cutoff is longer than the half of the side-length of the box. It will also
  fail to find the correct image when the box is too deformed. We hope to add a
  more robust box for small simulations soon (TODO) along with better error
  checking. In the meantime caution is recommended.

  Args:
    box: The affine transformation :math:`T`: a scalar, a `(spatial_dim,)`
      vector, or a `(spatial_dim, spatial_dim)` matrix.
    fractional_coordinates: A boolean specifying whether positions are stored
      in the parallelepiped or the unit cube.
    wrapped: A boolean specifying whether or not particle positions are
      remapped back into the box after each step

  Returns:
    `(displacement_fn, shift_fn)` tuple.
  """
  inv_box = inverse(box)

  def displacement_fn(Ra, Rb, perturbation=None, **kwargs):
    _box, _inv_box = box, inv_box

    if 'box' in kwargs:
      _box = kwargs['box']

      # The inverse is only needed to map real coordinates into the unit cube.
      if not fractional_coordinates:
        _inv_box = inverse(_box)

    # `new_box` replaces only the forward transform, so positions are still
    # mapped into the unit cube with the inverse selected above.
    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    if not fractional_coordinates:
      Ra = transform(_inv_box, Ra)
      Rb = transform(_inv_box, Rb)

    # Wrap in the unit cube, then map the displacement out with the box.
    dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
    dR = transform(_box, dR)

    if perturbation is not None:
      dR = raw_transform(perturbation, dR)

    return dR

  def u(R, dR):
    # Shift inside the unit cube, wrapping only when requested.
    if wrapped:
      return periodic_shift(f32(1.0), R, dR)
    return R + dR

  def shift_fn(R, dR, **kwargs):
    # Real coordinates without wrapping need no change of coordinates.
    if not fractional_coordinates and not wrapped:
      return R + dR

    _box, _inv_box = box, inv_box
    if 'box' in kwargs:
      _box = kwargs['box']
      _inv_box = inverse(_box)

    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    # `dR` is given in real space; bring it into the unit cube.
    dR = transform(_inv_box, dR)
    if not fractional_coordinates:
      R = transform(_inv_box, R)

    R = u(R, dR)

    if not fractional_coordinates:
      R = transform(_box, R)
    return R

  return displacement_fn, shift_fn
def metric(displacement: DisplacementFn) -> MetricFn:
  """Takes a displacement function and creates a metric.

  Args:
    displacement: Function computing the displacement between two positions.

  Returns:
    A function `(Ra, Rb, **kwargs)` that forwards all arguments to
    `displacement` and returns the `distance` of the result.
  """
  def metric_fn(Ra, Rb, **kwargs):
    dR = displacement(Ra, Rb, **kwargs)
    return distance(dR)
  return metric_fn
def map_product(metric_or_displacement: DisplacementOrMetricFn
                ) -> DisplacementOrMetricFn:
  """Vectorizes a metric or displacement function over all pairs.

  Args:
    metric_or_displacement: Function of two single-particle positions.

  Returns:
    A function of `(Ra, Rb)` whose result at index `[j, i]` is
    `metric_or_displacement(Ra[i], Rb[j])`: the inner map runs over the
    leading axis of `Ra` and the outer map over the leading axis of `Rb`.
  """
  over_first = vmap(metric_or_displacement, in_axes=(0, None), out_axes=0)
  over_both = vmap(over_first, in_axes=(None, 0), out_axes=0)
  return over_both
def map_bond(metric_or_displacement: DisplacementOrMetricFn
             ) -> DisplacementOrMetricFn:
  """Vectorizes a metric or displacement function over bonds.

  Args:
    metric_or_displacement: Function of two single-particle positions.

  Returns:
    A function of `(Ra, Rb)` that maps over the leading axis of both arguments
    in lockstep, so entry `[k]` is `metric_or_displacement(Ra[k], Rb[k])`.
  """
  paired = vmap(metric_or_displacement, in_axes=(0, 0), out_axes=0)
  return paired
def map_neighbor(metric_or_displacement: DisplacementOrMetricFn
                 ) -> DisplacementOrMetricFn:
  """Vectorizes a metric or displacement function over neighborhoods.

  Args:
    metric_or_displacement: Function of two single-particle positions.

  Returns:
    A function of `(Ra, Rb, **kwargs)` whose result at index `[i, j]` is
    `metric_or_displacement(Rb[i, j], Ra[i])`; note that the neighbor position
    is passed as the first argument.
  """
  def wrapped_fn(Ra, Rb, **kwargs):
    # Inner map: over the neighbors of one particle, holding the particle
    # fixed. Outer map: over particles, along the leading axis of both inputs.
    over_neighbors = vmap(metric_or_displacement, (0, None))
    over_particles = vmap(over_neighbors)
    return over_particles(Rb, Ra, **kwargs)
  return wrapped_fn
def canonicalize_displacement_or_metric(displacement_or_metric):
  """Checks whether or not a displacement or metric was provided.

  The function is shape-evaluated on abstract position vectors of dimension
  1, 2 and 3 in turn. The first dimension that evaluates without a
  `TypeError` or `ValueError` decides the result.

  Args:
    displacement_or_metric: A displacement or metric function. It is probed
      with two position arguments and the keyword argument `t=0`.

  Returns:
    `displacement_or_metric` itself if its output is a scalar (it is already a
    metric); otherwise `metric(displacement_or_metric)`.

  Raises:
    ValueError: If none of the probed dimensions could be evaluated.
  """
  for dim in range(1, 4):
    try:
      R = ShapedArray((dim,), f32)
      dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
      # A scalar output means the function already returns a distance.
      if len(dR_or_dr.shape) == 0:
        return displacement_or_metric
      else:
        return metric(displacement_or_metric)
    except (TypeError, ValueError):
      # The function rejected this spatial dimension; try the next one.
      continue
  # NOTE(review): only dimensions 1-3 are probed although the message says
  # "larger than 4" — confirm the intended bound.
  raise ValueError(
      'Canonicalize displacement not implemented for spatial dimension larger '
      'than 4.')