# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to calculate the elastic modulus tensor for athermal systems.
The elastic modulus tensor describes a material's response to different
boundary deformations. Specifically, for a small deformation given by a
symmetric strain tensor e, the change in energy is
.. math::
U / V^0 = U^0/V^0 + s^0_{ij} e_{ji} + (1/2) C_{ijkl} e_{ij} e_{kl} + \ldots{}
where :math:`V^0` is the volume, :math:`U^0` is the initial energy, :math:`s^0` is the residual stress
tensor of the undeformed system, and :math:`C` is the elastic modulus tensor. :math:`C` is
a fourth-rank tensor of shape (dimension,dimension,dimension,dimension), with
the following symmetries.
Minor symmetries:
.. math::
C_{ijkl} = C_{jikl} = C_{ijlk}
Major symmetries:
.. math::
C_{ijkl} = C_{lkij}
The minor symmetries are also reflected in the symmetric nature
of stress and strain tensors:
.. math::
s_{ij} = s_{ji}
e_{ij} = e_{ji}
In general, there are 21 independent elastic constants in 3 dimension (6 in 2
dimensions). While systems with additional symmetries (e.g. isotropic,
orthotropic, etc.) can be expressed with fewer constants, we do not assume any
such additional symmetries.
At zero temperature, the response of every particle to a deformation can be
calculated explicitly to linear order, enabling the exact calculation of
the elastic modulus tensor without the need of any finite differences or any
approximations.
"""
from functools import partial
from typing import Dict, Callable, Union
from absl import logging
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import jit, lax, grad, vmap, jacfwd, jacrev, jvp
from jax_md import quantity
from jax_md.util import Array
from jax_md.util import f32
from jax_md.util import f64
def _get_strain_tensor_list(dim, dtype) -> Array:
if dim == 2:
strain_tensors = jnp.array([[[1, 0],[0, 0]],
[[0, 0],[0, 1]],
[[0, 1],[1, 0]],
[[1, 0],[0, 1]],
[[1, 1],[1, 0]],
[[0, 1],[1, 1]]], dtype=dtype)
elif dim == 3:
strain_tensors = jnp.array([[[1, 0, 0], [0, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 1]],
[[0, 0, 0], [0, 0, 1], [0, 1, 0]],
[[0, 0, 1], [0, 0, 0], [1, 0, 0]],
[[0, 1, 0], [1, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 1, 0], [0, 0, 1]],
[[1, 0, 0], [0, 0, 0], [0, 0, 1]],
[[1, 0, 0], [0, 1, 0], [0, 0, 0]],
[[1, 0, 0], [0, 0, 1], [0, 1, 0]],
[[1, 0, 1], [0, 0, 0], [1, 0, 0]],
[[1, 1, 0], [1, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 1, 1], [0, 1, 0]],
[[0, 0, 1], [0, 1, 0], [1, 0, 0]],
[[0, 1, 0], [1, 1, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 1], [0, 1, 1]],
[[0, 0, 1], [0, 0, 0], [1, 0, 1]],
[[0, 1, 0], [1, 0, 0], [0, 0, 1]],
[[0, 0, 1], [0, 0, 1], [1, 1, 0]],
[[0, 1, 0], [1, 0, 1], [0, 1, 0]],
[[0, 1, 1], [1, 0, 0], [1, 0, 0]]], dtype=dtype)
else:
raise AssertionError('not implemented for {} dimensions'.format(dim))
return strain_tensors
def _convert_responses_to_elastic_constants(response_all: Array) -> Array:
if response_all.shape[0] == 6:
cxxxx = response_all[0]
cyyyy = response_all[1]
cxyxy = 0.25 * response_all[2]
cxxyy = 0.5 * (response_all[3] - cxxxx - cyyyy)
cxxxy = 0.25 * (response_all[4] - 4 * cxyxy - cxxxx)
cyyxy = 0.25 * (response_all[5] - 4 * cxyxy - cyyyy)
C = jnp.array(
[[[[cxxxx, cxxxy], [cxxxy, cxxyy]],
[[cxxxy, cxyxy], [cxyxy, cyyxy]]],
[[[cxxxy, cxyxy], [cxyxy, cyyxy]],
[[cxxyy, cyyxy], [cyyxy, cyyyy]]]])
elif response_all.shape[0] == 21:
cxxxx = response_all[0];
cyyyy = response_all[1];
czzzz = response_all[2];
cyzyz = response_all[3] / 4.;
cxzxz = response_all[4] / 4.;
cxyxy = response_all[5] / 4.;
cyyzz = (response_all[6] - cyyyy - czzzz) / 2.;
cxxzz = (response_all[7] - cxxxx - czzzz) / 2.;
cxxyy = (response_all[8] - cxxxx - cyyyy) / 2.;
cxxyz = (response_all[9] - cxxxx - 4. * cyzyz) / 4.;
cxxxz = (response_all[10] - cxxxx - 4. * cxzxz) / 4.;
cxxxy = (response_all[11] - cxxxx - 4. * cxyxy) / 4.;
cyyyz = (response_all[12] - cyyyy - 4. * cyzyz) / 4.;
cyyxz = (response_all[13] - cyyyy - 4. * cxzxz) / 4.;
cyyxy = (response_all[14] - cyyyy - 4. * cxyxy) / 4.;
czzyz = (response_all[15] - czzzz - 4. * cyzyz) / 4.;
czzxz = (response_all[16] - czzzz - 4. * cxzxz) / 4.;
czzxy = (response_all[17] - czzzz - 4. * cxyxy) / 4.;
cyzxz = (response_all[18] - 4. * cyzyz - 4. * cxzxz) / 8.;
cyzxy = (response_all[19] - 4. * cyzyz - 4. * cxyxy) / 8.;
cxzxy = (response_all[20] - 4. * cxzxz - 4. * cxyxy) / 8.;
C = jnp.array(
[[[[cxxxx, cxxxy, cxxxz],
[cxxxy, cxxyy, cxxyz],
[cxxxz, cxxyz, cxxzz]],
[[cxxxy, cxyxy, cxzxy],
[cxyxy, cyyxy, cyzxy],
[cxzxy, cyzxy, czzxy]],
[[cxxxz, cxzxy, cxzxz],
[cxzxy, cyyxz, cyzxz],
[cxzxz, cyzxz, czzxz]]],
[[[cxxxy, cxyxy, cxzxy],
[cxyxy, cyyxy, cyzxy],
[cxzxy, cyzxy, czzxy]],
[[cxxyy, cyyxy, cyyxz],
[cyyxy, cyyyy, cyyyz],
[cyyxz, cyyyz, cyyzz]],
[[cxxyz, cyzxy, cyzxz],
[cyzxy, cyyyz, cyzyz],
[cyzxz, cyzyz, czzyz]]],
[[[cxxxz, cxzxy, cxzxz],
[cxzxy, cyyxz, cyzxz],
[cxzxz, cyzxz, czzxz]],
[[cxxyz, cyzxy, cyzxz],
[cyzxy, cyyyz, cyzyz],
[cyzxz, cyzyz, czzyz]],
[[cxxzz, czzxy, czzxz],
[czzxy, cyyzz, czzyz],
[czzxz, czzyz, czzzz]]]])
else:
raise AssertionError('response_all has incorrect shape')
return C
[docs]def athermal_moduli(energy_fn: Callable[..., Array],
tether_strength: float=1e-10,
gradient_check: Array=None,
cg_tol: float=1e-7,
check_convergence: bool=False
) -> Callable[..., Array]:
""" Setup calculation of elastic modulus tensor.
Args:
energy_fn: A function that computes the energy of the system. This
function must take as an argument `perturbation` which perturbs the
box shape. Any energy function constructed using `smap` or in `energy.py`
with a standard space will satisfy this property.
tether_strength: Scalar. Strength of the "tether" applied to each particle,
which can be necessary to make the Hessian matrix non-singular. Solving
for the non-affine response of each particle requires that the Hessian is
positive definite. However, there can often be zero modes (eigenvectors of
the Hessian with zero eigenvalue) that do not couple to the boundary, and
therefore do not affect the elastic constants despite the zero eigenvalue.
The most common example is the global translational modes. To solve for
the non-affine response, we consider the "tethered Hessian"
`H + tether_strength * I`,
where `I` is the identity matrix and `tether_strength` is a small constant.
gradient_check: `None` or scalar. If not `None`, a check will be performed
to guarantee that the maximum component of the gradient is less than
`gradient_check`. In other words, that
`jnp.amax(jnp.abs(grad(energy_fn)(R, box=box))) < gradient_check == True`
NOTE: JAX currently does not support proper runtime error handling.
Therefore, if this check fails, the calculation will return an array
of `jnp.nan`'s. It is the users responsibility, if they want to use this
check, to then ensure that the returned array is not full of nans.
cg_tol: scalar. Tolerance used when solving for the non-affine response.
check_convergence: bool. If true, `calculate_EMT` will return a boolean
flag specifying if the cg solve routine converged to the desired
tolerance. The default is `False`, but convergence checking is highly
recommended especially when using 32-bit precision data.
Return: A function to calculate the elastic modulus tensor
"""
def calculate_emt(R: Array,
box: Array,
**kwargs) -> Array:
"""Calculate the elastic modulus tensor.
`energy_fn(R)` corresponds to the state around which we are expanding
Args:
R: array of shape `[N,dimension]` of particle positions. This does not
generalize to arbitrary dimensions and is only implemented for
`dimension == 2` and `dimension == 3`.
box: A box specifying the shape of the simulation volume. Used to infer
the volume of the unit cell.
Return: C or the tuple (C,converged)
where C is the Elastic modulus tensor as an array of shape `[dimension,
dimension,dimension,dimension]` that respects the major and minor
symmetries, and converged is a boolean flag (see above).
"""
if not (R.shape[-1] == 2 or R.shape[-1] == 3):
raise AssertionError('Only implemented for 2d and 3d systems.')
if R.dtype is not jnp.dtype('float64'):
logging.warning('Elastic modulus calculations can sometimes lose '
'precision when not using 64-bit precision.')
dim = R.shape[-1]
def setup_energy_fn_general(strain_tensor):
I = jnp.eye(dim, dtype=R.dtype)
@jit
def energy_fn_general(R, gamma):
perturbation = I + gamma * strain_tensor
return energy_fn(R, perturbation=perturbation, **kwargs)
return energy_fn_general
def get_affine_response(strain_tensor):
energy_fn_general = setup_energy_fn_general(strain_tensor)
d2U_dRdgamma = jacfwd(jacrev(energy_fn_general,argnums=0),argnums=1)(R,0.)
d2U_dgamma2 = jacfwd(jacrev(energy_fn_general,argnums=1),argnums=1)(R,0.)
return d2U_dRdgamma, d2U_dgamma2
strain_tensors = _get_strain_tensor_list(dim, R.dtype)
d2U_dRdgamma_all,d2U_dgamma2_all = vmap(get_affine_response)(strain_tensors)
#Solve the system of equations.
energy_fn_Ronly = partial(energy_fn, **kwargs)
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
def hvp_specific_with_tether(v):
return hvp(energy_fn_Ronly, (R,), (v,)) + tether_strength * v
non_affine_response_all = jsp.sparse.linalg.cg(
vmap(hvp_specific_with_tether),
d2U_dRdgamma_all,
tol=cg_tol
)[0]
#The above line should be functionally equivalent to:
#H0=hessian(energy_fn)(R, box=box, **kwargs).reshape(R.size,R.size) \
# + tether_strength * jnp.identity(R.size)
#non_affine_response_all = jnp.transpose(jnp.linalg.solve(
# H0,
# jnp.transpose(d2U_dRdgamma_all))
# )
residual = jnp.linalg.norm(vmap(hvp_specific_with_tether)(
non_affine_response_all) - d2U_dRdgamma_all
)
converged = residual / jnp.linalg.norm(d2U_dRdgamma_all) < cg_tol
response_all = d2U_dgamma2_all - jnp.einsum("nij,nij->n",
d2U_dRdgamma_all,
non_affine_response_all)
vol_0 = quantity.volume(dim, box)
response_all = response_all / vol_0
C = _convert_responses_to_elastic_constants(response_all)
# JAX does not allow proper runtime error handling in jitted function.
# Instead, if the user requests a gradient check and the check fails,
# we convert C into jnp.nan's. While this doesn't raise an exception,
# it at least is very "loud".
if gradient_check is not None:
maxgrad = jnp.amax(jnp.abs(grad(energy_fn)(R, **kwargs)))
C = lax.cond(maxgrad > gradient_check,
lambda _: jnp.nan * C,
lambda _: C,
None)
if check_convergence:
return C, converged
else:
return C
return calculate_emt
def _get_mandel_mapping_weight(dim, dtype):
if dim == 2:
m_map = jnp.array([[0,0],[1,1],[0,1]], dtype=jnp.int8)
weight = jnp.array([1,1,jnp.sqrt(2)], dtype=dtype)
return m_map, weight
elif dim == 3:
m_map = jnp.array([[0,0],[1,1],[2,2],[1,2],[0,2],[0,1]], dtype=jnp.int8)
weight = jnp.array([1,1,1,jnp.sqrt(2),jnp.sqrt(2),jnp.sqrt(2)], dtype=dtype)
return m_map, weight
else:
raise AssertionError('dim must be 2 or 3')
[docs]def tensor_to_mandel(T: Array) -> Array:
""" Convert a tensor to Mandel notation.
Mandel notation is a way to represent symmetric second-rank tensors and
fourth-rank tensors with minor symmetries in a reduced form. Pairs of indices
are combined as follows:
For tensors of shape (2,2) or (2,2,2,2):
tensor indices Mandel indices
0,0 --------------> 0
1,1 --------------> 1
0,1 or 1,0 -------> 2
For tensors of shape (3,3) or (3,3,3,3):
tensor indices Mandel indices
0,0 --------------> 0
1,1 --------------> 1
2,2 --------------> 2
1,2 or 2,1 -------> 3
0,2 or 2,0 -------> 4
0,1 or 1,0 -------> 5
If mandel_index(i,j) performs the above index mapping, then the input T and
output M satisfy
M[mandel_index(i,j)] = T[i,j] * w
or
M[mandel_index(i,j), mandel_index(k,l)] = T[i,j,k,l] * w(i,j) * w(k,l)
where
w(i,j) = 1 if i==j
= sqrt(2) if i!=j
is a weight that is used to ensure proper contraction rules. Here (and only
here) we do not assume major symmetries in fourth-rank tensors.
Args:
T: Array with 4 possible shapes:
#. T.shape == (2,2)
Convert a symmetric array of shape (2,2) to an array of shape (3,)
#. T.shape == (3,3)
Convert a symmetric array of shape (3,3) to an array of shape (6,)
#. T.shape == (2,2,2,2)
Convert a tensor of shape (2,2,2,2) with minor symmetries to an array
of shape (3,3)
#. T.shape == (3,3,3,3)
Convert a tensor of shape (3,3,3,3) with minor symmetries to an array
of shape (6,6)
Output: Array of shape (3,), (6,), (3,3), or (6,6)
see: https://sbrisard.github.io/janus/mandel.html (accessed 21 April, 2021)
"""
dim = T.shape[0]
if not (dim ==2 or dim == 3):
raise AssertionError('dim must be 2 or 3')
rank = len(T.shape)
if not (rank == 2 or rank == 4):
raise AssertionError('T must have rank 2 or 4')
m_map, weight = _get_mandel_mapping_weight(dim, T.dtype)
if rank == 2:
extract = lambda idx, w: T[idx[0], idx[1]] * w
M = vmap(extract, in_axes=(0,0))(m_map, weight)
else:
extract = lambda idx0, idx1, w0, w1: T[
idx0[0], idx0[1], idx1[0], idx1[1]
] * w0 * w1
M = vmap(
vmap(extract,
in_axes=(0,None,0,None)),
in_axes=(None,0,None,0))(m_map, m_map, weight, weight)
return M
[docs]def mandel_to_tensor(M: Array) -> Array:
""" Perform the inverse of M = tensor_to_mandel(T).
Args:
M: Array of shape (3,), (6,), (3,3), or (6,6)
Output: Array of shape (2,2), (3,3), (2,2,2,2), or (3,3,3,3)
"""
mandel_dim = M.shape[0]
if not (mandel_dim == 3 or mandel_dim == 6):
raise AssertionError('M.shape[0] must be 3 or 6')
rank = len(M.shape)
if not (rank == 1 or rank == 2):
raise AssertionError('T must have rank 1 or 2')
def mandel_index(i, j):
return lax.cond(i==j,
lambda ij: ij[0],
lambda ij: mandel_dim - ij[0] - ij[1],
(i,j))
if mandel_dim == 3:
dimension = 2
else:
dimension = 3
tensor_range = jnp.arange(dimension)
_, weight = _get_mandel_mapping_weight(dimension, M.dtype)
if rank == 1:
def extract(i, j):
idx = mandel_index(i, j)
return M[idx] / weight[idx]
T = vmap(
vmap(extract,
in_axes=(0,None)),
in_axes=(None,0))(tensor_range, tensor_range)
else:
def extract(i, j, k, l):
idx0 = mandel_index(i, j)
idx1 = mandel_index(k, l)
return M[idx0, idx1] / (weight[idx0] * weight[idx1])
T = vmap(
vmap(
vmap(
vmap(extract,
in_axes=(0,None,None,None)),
in_axes=(None,0,None,None)),
in_axes=(None,None,0,None)),
in_axes=(None,None,None,0))(tensor_range,
tensor_range,
tensor_range,
tensor_range)
return T
@partial(jit,static_argnums=(1,))
def _extract_elements(C, as_dict):
if C.shape[0] == 2:
indices = ( [0, 1, 0, 0, 0, 0],
[0, 1, 1, 0, 0, 1],
[0, 1, 0, 1, 0, 1],
[0, 1, 1, 1, 1, 1])
clist = C[ indices ]
if as_dict:
names = ['cxxxx','cyyyy','cxyxy','cxxyy','cxxxy','cyyxy']
return dict(zip(names, clist))
else:
return clist
elif C.shape[0] == 3:
indices = ( [0, 1, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 1, 2, 2, 2, 1, 1, 0, 0, 0, 0, 0, 1, 2, 1, 2, 2, 1, 2, 1, 1],
[0, 1, 2, 1, 0, 0, 2, 2, 1, 1, 0, 0, 1, 1, 1, 2, 2, 2, 1, 1, 0],
[0, 1, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2])
clist = C[ indices ]
if as_dict:
names = ['cxxxx','cyyyy','czzzz','cyzyz','cxzxz','cxyxy','cyyzz','cxxzz',
'cxxyy','cxxyz','cxxxz','cxxxy','cyyyz','cyyxz','cyyxy','czzyz',
'czzxz','czzxy','cyzxz','cyzxy','cxzxy']
return dict(zip(names, clist))
else:
return clist
else:
raise AssertionError('C has wrong shape')