Source code for jax_md.minimize

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code to minimize the energy of a system.

This file contains a number of different methods that can be used to find the
nearest minimum (inherent structure) to some initial system described by a
position R.

Minimization code follows the same overall structure as optimizers in JAX.
Optimizers return two functions:

init_fn:
  Function that initializes the state of an optimizer. Should take
  positions as an ndarray of shape `[n, output_dimension]`. Returns a state
  which will be a dataclass.

apply_fn:
  Function that takes a state and produces a new state after one
  step of optimization.
"""

from typing import TypeVar, Callable, Tuple, Union, Any

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg
from jax.tree_util import tree_leaves, tree_map, tree_reduce

from jax_md import quantity
from jax_md import dataclasses
from jax_md import partition
from jax_md import util
from jax_md import space
from jax_md import simulate

# Types

PyTree = Any
Array = util.Array
f32 = util.f32
f64 = util.f64

ShiftFn = space.ShiftFn

T = TypeVar('T')
InitFn = Callable[..., T]
ApplyFn = Callable[..., T]
# Parameterize the inner aliases so `Minimizer[X]` stays generic in T;
# bare `InitFn` would collapse to `Callable[..., Unknown]`.
Minimizer = Tuple[InitFn[T], ApplyFn[T]]


def gradient_descent(
  energy_or_force: Callable[..., Array],
  shift_fn: ShiftFn,
  step_size: float,
) -> Minimizer[Array]:
  """Builds a plain gradient descent minimizer.

  Every step displaces the positions along the force (i.e. down the energy
  gradient) by `step_size` times the force. It is the simplest strategy on
  offer; it is typically slower than the alternatives and is provided mainly
  because it is easy to reason about.

  Args:
    energy_or_force: A function that produces either an energy or a force from
      a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`. It is normalized to a force function with
      `quantity.canonicalize_force`.
    shift_fn: A function that displaces positions, `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    step_size: A floating point number scaling the force to obtain the
      displacement applied on each step.

  Returns:
    An `(init_fn, apply_fn)` pair as described in the module docstring. The
    optimizer state is the position array itself.
  """
  force_fn = quantity.canonicalize_force(energy_or_force)

  def init_fn(R: Array, **unused_kwargs) -> Array:
    # There is no auxiliary state: the positions are the optimizer state.
    return R

  def apply_fn(R: Array, **kwargs) -> Array:
    # Extra keyword arguments are forwarded to both the force and the shift.
    displacement = step_size * force_fn(R, **kwargs)
    return shift_fn(R, displacement, **kwargs)

  return init_fn, apply_fn
@dataclasses.dataclass
class FireDescentState:
  """Optimizer state carried between steps of the FIRE descent minimizer.

  Attributes:
    position: Current particle positions; an ndarray of floats with shape
      `[n, spatial_dimension]`.
    momentum: Current particle momenta; an ndarray of floats with shape
      `[n, spatial_dimension]`.
    force: Current forces acting on the particles; an ndarray of floats with
      shape `[n, spatial_dimension]`.
    mass: Particle masses, either a single float or an ndarray of floats with
      shape `[n]`.
    dt: The step size currently in use, as a float.
    alpha: The FIRE mixing parameter currently in use, as a float.
    n_pos: How many consecutive steps have had positive power.
  """

  position: Array
  momentum: Array
  force: Array
  mass: Array | float
  dt: Array | float
  alpha: Array | float
  n_pos: Array | int
def fire_descent(
  energy_or_force: Callable[..., Array],
  shift_fn: ShiftFn,
  dt_start: float = 0.1,
  dt_max: float = 0.4,
  n_min: float = 5,
  f_inc: float = 1.1,
  f_dec: float = 0.5,
  alpha_start: float = 0.1,
  f_alpha: float = 0.99,
) -> Minimizer[FireDescentState]:
  """Builds the FIRE minimizer.

  This code implements the "Fast Inertial Relaxation Engine" from Bitzek et
  al. [#bitzek]_ Each call to `apply_fn` advances the state with the step
  function returned by `simulate.nve` (using the state's current `dt`), mixes
  the momentum toward the force direction with weight `alpha`, and then adapts
  `dt`, `alpha` and the momentum according to the sign of `F . P`.

  Args:
    energy_or_force: A function that produces either an energy or a force from
      a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt_start: The initial step size during minimization as a float.
    dt_max: The maximum step size during minimization as a float.
    n_min: The number of consecutive steps with non-negative `F . P` that must
      be exceeded before `dt` is increased and `alpha` is decayed.
    f_inc: A float; the factor by which the step size is multiplied when it is
      increased (the result is capped at `dt_max`).
    f_dec: A float; the factor by which the step size is multiplied when
      `F . P` is negative.
    alpha_start: A float; the initial value of the mixing parameter `alpha`,
      which is also the value `alpha` is reset to when `F . P` is negative.
    f_alpha: A float; the factor by which `alpha` is multiplied whenever the
      step size is allowed to increase.

  Returns:
    An `(init_fn, apply_fn)` pair as described in the module docstring.
    `init_fn(R, mass=1.0, **kwargs)` returns a `FireDescentState` and
    `apply_fn(state, **kwargs)` returns the state after one step.

  .. rubric:: References

  .. [#bitzek] Bitzek, Erik, Pekka Koskinen, Franz Gahler, Michael Moseler,
    and Peter Gumbsch. "Structural relaxation made simple."
    Physical review letters 97, no. 17 (2006): 170201.
  """
  dt_start, dt_max, n_min, f_inc, f_dec, alpha_start, f_alpha = (
    util.static_cast(
      dt_start, dt_max, n_min, f_inc, f_dec, alpha_start, f_alpha
    )
  )

  # Only the step function of the NVE integrator is needed here.
  _, nve_step_fn = simulate.nve(energy_or_force, shift_fn, dt_start)
  force = quantity.canonicalize_force(energy_or_force)

  def init_fn(
    R: PyTree, mass: Array | float = 1.0, **kwargs
  ) -> FireDescentState:
    # Start at rest, with a zeroed counter of positive-power steps.
    zero_momentum = tree_map(jnp.zeros_like, R)
    positive_steps = jnp.zeros((), jnp.int32)
    initial_force = force(R, **kwargs)
    initial_state = FireDescentState(
      R,
      zero_momentum,
      initial_force,
      mass,
      dt_start,
      alpha_start,
      positive_steps,
    )  # pytype: disable=wrong-arg-count
    return simulate.canonicalize_mass(initial_state)

  def apply_fn(state: FireDescentState, **kwargs) -> FireDescentState:
    state = nve_step_fn(state, dt=state.dt, **kwargs)
    position, momentum, forces, mass, dt, alpha, n_pos = dataclasses.unpack(
      state
    )

    # NOTE(schsam): This will be wrong if F_norm ~< 1e-8.
    # TODO(schsam): We should check for forces below 1e-6. @ErrorChecking
    # The 1e-6 offset is added once per leaf of the force pytree.
    force_norm = jnp.sqrt(
      tree_reduce(
        lambda total, f: total + jnp.sum(f**2) + 1e-6, forces, 0.0
      )
    )
    momentum_norm = jnp.sqrt(
      tree_reduce(lambda total, p: total + jnp.sum(p**2), momentum, 0.0)
    )

    # NOTE: In the original FIRE algorithm, the quantity that determines when
    # to reset the momenta is F.V rather than F.P. However, all of the JAX MD
    # simulations are in momentum space for easier agreement with prior work /
    # rigid body physics. We only use the sign of F.P here, which shouldn't
    # differ from F.V, however if there are regressions then we should
    # reconsider this choice.
    per_leaf_power = tree_map(
      lambda f, p: jnp.sum(f * p), forces, momentum
    )
    power = tree_reduce(lambda total, term: total + term, per_leaf_power)

    # Mix the momentum toward the force direction with weight `alpha`.
    momentum = tree_map(
      lambda p, f: p + alpha * (f * momentum_norm / force_norm - p),
      momentum,
      forces,
    )

    # NOTE(schsam): Can we clean this up at all?
    moving_downhill = power > 0
    moving_uphill = power < 0
    not_uphill = power >= 0

    # The counter is updated first; the acceleration test below reads the
    # updated value.
    n_pos = jnp.where(not_uphill, n_pos + 1, 0)
    may_accelerate = jnp.logical_and(moving_downhill, n_pos > n_min)

    dt_candidates = jnp.array([dt * f_inc, dt_max])
    dt = jnp.where(may_accelerate, jnp.min(dt_candidates), dt)
    dt = jnp.where(moving_uphill, dt * f_dec, dt)

    alpha = jnp.where(may_accelerate, alpha * f_alpha, alpha)
    alpha = jnp.where(moving_uphill, alpha_start, alpha)

    # Zero the momentum whenever the power is negative.
    momentum = tree_map(lambda p: not_uphill * p, momentum)

    return FireDescentState(
      position, momentum, forces, mass, dt, alpha, n_pos
    )  # pytype: disable=wrong-arg-count

  return init_fn, apply_fn
@dataclasses.dataclass
class PreconFireDescentState:
  """State for the preconditioned FIRE minimizer.

  Attributes:
    position: The current position of particles. An ndarray of floats
      with shape `[n, spatial_dimension]`.
    velocity: The current optimizer velocity. An ndarray of floats
      with shape `[n, spatial_dimension]`.
    force: The current raw force on particles. An ndarray of floats
      with shape `[n, spatial_dimension]`.
    dt: A float specifying the current step size.
    alpha: A float specifying the current FIRE mixing parameter.
    n_pos: The number of consecutive steps with positive power.
    initialized: Whether the first-step velocity initialization has happened.
    preconditioner_position: Reference positions for cached graph
      preconditioners.
    preconditioner_previous_position: Previous positions used for
      preconditioner rebuild checks.
    preconditioner_previous_initialized: Whether
      ``preconditioner_previous_position`` has been initialized.
    momentum: Alias for ``velocity`` for compatibility with momentum-based
      minimizer states.
  """

  position: Array
  velocity: Array
  force: Array
  dt: Array | float
  alpha: Array | float
  n_pos: Array | int
  initialized: Array | bool
  preconditioner_position: PyTree
  preconditioner_previous_position: PyTree
  preconditioner_previous_initialized: Array | bool

  @property
  def momentum(self) -> Array:
    return self.velocity


def exp_preconditioner(
  displacement_fn: space.DisplacementFn,
  r_cut: Array | float | None = None,
  r_NN: Array | float | None = None,
  A: float = 3.0,
  mu: float = 1.0,
  c_stab: float = 0.1,
  solve_tol: float = 1e-5,
  maxiter: int | None = None,
  reference_position: Array | None = None,
  solver: str = 'cg',
) -> Tuple[Callable[..., Array], Callable[..., Array]]:
  """Builds exponential graph preconditioner callables.

  This implements the position-only universal preconditioner from
  Packwood et al., J. Chem. Phys. 144, 164109 (2016), in a matrix-free,
  JAX-friendly form. The scalar atom graph is expanded over Cartesian
  components as ``P = L kron I_d``.

  The returned callables read two optional keyword arguments:
  ``preconditioner_position`` (positions used to build the graph; takes
  precedence over ``reference_position``) and ``neighbor`` (a neighbor list
  whose edges replace the all-pairs cutoff graph).

  Args:
    displacement_fn: Function returning pair displacements.
    r_cut: Neighbor cutoff. If ``None``, uses ``2 * r_NN``. Only applied on
      the all-pairs path; when a ``neighbor`` list is supplied its edges are
      used as given.
    r_NN: Nearest-neighbor distance. If ``None``, estimated from ``R``.
    A: Exponential decay parameter. ``A=0`` gives constant neighbor weights.
    mu: Energy scale multiplying the preconditioner.
    c_stab: Diagonal stabilization coefficient.
    solve_tol: Conjugate-gradient solve tolerance.
    maxiter: Optional maximum CG iterations.
    reference_position: Optional positions used to build a fixed
      preconditioner. If ``None``, the graph is rebuilt from the current ``R``
      on every call.
    solver: Linear solver to use. ``'cg'`` is matrix-free; ``'dense'`` builds
      the scalar atom matrix and uses a direct dense solve.

  Returns:
    ``(preconditioner, preconditioner_dot)`` callables suitable for
    :func:`precon_fire_descent`.

  Raises:
    ValueError: If ``solver`` is neither ``'cg'`` nor ``'dense'``.
  """
  # Fail loudly on a misspelled solver instead of silently running CG.
  if solver not in ('cg', 'dense'):
    raise ValueError(
      f"Unknown solver {solver!r}; expected 'cg' or 'dense'."
    )

  def pair_distances(R: Array, **kwargs) -> Array:
    # All-pairs displacement table of shape [N, N, dim], reduced to distances.
    d = jax.vmap(
      jax.vmap(
        lambda Ra, Rb: displacement_fn(Ra, Rb, **kwargs), in_axes=(None, 0)
      ),
      in_axes=(0, None),
    )(R, R)
    return jnp.sqrt(jnp.sum(d**2, axis=-1))

  def estimate_r_NN(dist: Array) -> Array:
    # Largest nearest-neighbor distance; the diagonal is masked with inf so a
    # particle is never its own nearest neighbor.
    N = dist.shape[0]
    dist_no_self = jnp.where(jnp.eye(N, dtype=bool), jnp.inf, dist)
    return jnp.max(jnp.min(dist_no_self, axis=1))

  def graph_weights(R: Array, **kwargs) -> Array:
    # Dense [N, N] matrix of edge weights exp(-A * (r_ij / r_NN - 1)).
    R_graph = kwargs.get('preconditioner_position', reference_position)
    R_graph = R if R_graph is None else R_graph
    N = R_graph.shape[0]
    A_value = jnp.asarray(A, dtype=R_graph.dtype)
    if r_NN is None:
      r_NN_value = estimate_r_NN(pair_distances(R_graph, **kwargs))
    else:
      r_NN_value = jnp.asarray(r_NN, dtype=R_graph.dtype)
    r_cut_value = (
      2.0 * r_NN_value
      if r_cut is None
      else jnp.asarray(r_cut, dtype=R_graph.dtype)
    )

    neighbor = kwargs.get('neighbor')
    if neighbor is not None:
      if hasattr(neighbor, 'shifts'):
        # Neighbor list carrying per-edge lattice shifts: displacements are
        # built from the shifts and ``neighbor.box`` rather than from
        # ``displacement_fn``.
        if neighbor.format is partition.Dense:
          idx = neighbor.idx
          senders = idx.reshape(-1)
          receivers = jnp.repeat(jnp.arange(N), idx.shape[1])
          shifts = neighbor.shifts.reshape((-1, R_graph.shape[-1]))
        else:
          receivers, senders = neighbor.idx
          shifts = neighbor.shifts
        # Indices >= N are masked out (treated as padding); clip them so the
        # gathers below stay in bounds.
        valid = jnp.logical_and(receivers < N, senders < N)
        send_safe = jnp.clip(senders, 0, N - 1)
        recv_safe = jnp.clip(receivers, 0, N - 1)
        shift_cart = shifts.astype(R_graph.dtype) @ neighbor.box.T
        dR = R_graph[send_safe] + shift_cart - R_graph[recv_safe]
        dist = jnp.sqrt(jnp.sum(dR**2, axis=-1))
        weights_edge = jnp.exp(-A_value * (dist / r_NN_value - 1.0))
        weights_edge = jnp.where(valid, weights_edge, 0.0)
        weights = (
          jnp.zeros((N, N), dtype=R_graph.dtype)
          .at[recv_safe, send_safe]
          .add(weights_edge)
        )
        # NOTE(review): presumably OrderedSparse stores each pair once, so the
        # transposed entries are added to keep the matrix symmetric -- confirm.
        if neighbor.format is partition.OrderedSparse:
          weights = weights.at[send_safe, recv_safe].add(weights_edge)
        return weights

      # Neighbor list without shifts: distances come from ``displacement_fn``.
      if neighbor.format is partition.Dense:
        senders = neighbor.idx.reshape(-1)
        receivers = jnp.repeat(jnp.arange(N), neighbor.idx.shape[1])
      else:
        receivers, senders = neighbor.idx
      valid = jnp.logical_and(receivers < N, senders < N)
      send_safe = jnp.clip(senders, 0, N - 1)
      recv_safe = jnp.clip(receivers, 0, N - 1)
      dR = jax.vmap(
        lambda r_recv, r_send: displacement_fn(r_recv, r_send, **kwargs)
      )(R_graph[recv_safe], R_graph[send_safe])
      dist = jnp.sqrt(jnp.sum(dR**2, axis=-1))
      weights_edge = jnp.exp(-A_value * (dist / r_NN_value - 1.0))
      weights_edge = jnp.where(valid, weights_edge, 0.0)
      weights = (
        jnp.zeros((N, N), dtype=R_graph.dtype)
        .at[recv_safe, send_safe]
        .add(weights_edge)
      )
      if neighbor.format is partition.OrderedSparse:
        weights = weights.at[send_safe, recv_safe].add(weights_edge)
      return weights

    # No neighbor list: all pairs within the cutoff, excluding self-pairs.
    dist = pair_distances(R_graph, **kwargs)
    mask = jnp.logical_and(dist < r_cut_value, ~jnp.eye(N, dtype=bool))
    weights = jnp.exp(-A_value * (dist / r_NN_value - 1.0))
    return jnp.where(mask, weights, 0.0)

  def graph_matrix(R: Array, **kwargs) -> Array:
    # Explicit scalar operator mu * (diag(degree + c_stab) - W).
    weights = graph_weights(R, **kwargs)
    degree = jnp.sum(weights, axis=1)
    R_graph = kwargs.get('preconditioner_position', reference_position)
    R_graph = R if R_graph is None else R_graph
    mu_value = jnp.asarray(mu, dtype=R_graph.dtype)
    c_stab_value = jnp.asarray(c_stab, dtype=R_graph.dtype)
    return mu_value * (jnp.diag(degree + c_stab_value) - weights)

  def apply_graph_operator(weights: Array, X: Array) -> Array:
    # Applies mu * (diag(degree + c_stab) - W) to X for precomputed weights.
    degree = jnp.sum(weights, axis=1)
    mu_value = jnp.asarray(mu, dtype=X.dtype)
    c_stab_value = jnp.asarray(c_stab, dtype=X.dtype)
    return mu_value * ((degree[:, None] + c_stab_value) * X - weights @ X)

  def graph_matvec(R: Array, X: Array, **kwargs) -> Array:
    return apply_graph_operator(graph_weights(R, **kwargs), X)

  def preconditioner(R: Array, F: Array, **kwargs) -> Array:
    if solver == 'dense':
      return jnp.linalg.solve(graph_matrix(R, **kwargs), F)
    # The weights do not depend on the CG iterate, so build them once per
    # solve instead of once per operator application.
    weights = graph_weights(R, **kwargs)
    solve = lambda X: apply_graph_operator(weights, X)
    return cg(solve, F, tol=solve_tol, maxiter=maxiter)[0]

  def preconditioner_dot(R: Array, X: Array, Y: Array, **kwargs) -> Array:
    return jnp.sum(X * graph_matvec(R, Y, **kwargs))

  return preconditioner, preconditioner_dot


def c1_preconditioner(
  displacement_fn: space.DisplacementFn,
  r_cut: Array | float | None = None,
  r_NN: Array | float | None = None,
  mu: float = 1.0,
  c_stab: float = 0.1,
  solve_tol: float = 1e-5,
  maxiter: int | None = None,
  reference_position: Array | None = None,
  solver: str = 'cg',
) -> Tuple[Callable[..., Array], Callable[..., Array]]:
  """Builds C1 graph preconditioner callables.

  This is the constant-weight special case of :func:`exp_preconditioner`
  (``A=0``); all other arguments are forwarded unchanged.
  """
  return exp_preconditioner(
    displacement_fn,
    r_cut=r_cut,
    r_NN=r_NN,
    A=0.0,
    mu=mu,
    c_stab=c_stab,
    solve_tol=solve_tol,
    maxiter=maxiter,
    reference_position=reference_position,
    solver=solver,
  )


def estimate_exp_mu(
  energy_or_force: Callable[..., Array],
  displacement_fn: space.DisplacementFn,
  R: Array,
  r_cut: Array | float | None = None,
  r_NN: Array | float | None = None,
  A: float = 3.0,
  c_stab: float = 0.1,
  min_mu: float = 1.0,
  **kwargs,
) -> Array:
  """Estimate the Exp preconditioner energy scale.

  The estimate matches the finite-difference curvature equation from
  Packwood et al., J. Chem. Phys. 144, 164109 (2016):
  ``(grad E(R + V) - grad E(R)) . V = mu * V^T P_{mu=1} V``.

  Args:
    energy_or_force: Function producing either an energy or a force.
    displacement_fn: Function returning pair displacements.
    R: Reference positions.
    r_cut: Neighbor cutoff. If ``None``, uses ``2 * r_NN``.
    r_NN: Nearest-neighbor distance. If ``None``, estimated from ``R``.
    A: Exponential decay parameter for the unit preconditioner.
    c_stab: Diagonal stabilization coefficient.
    min_mu: Lower bound applied to the estimate, matching ASE's cap at 1.
    **kwargs: Extra arguments forwarded to ``energy_or_force`` and
      ``displacement_fn``.

  Returns:
    Scalar estimate for ``mu``.
  """
  force = quantity.canonicalize_force(energy_or_force)

  def pair_distances(R: Array) -> Array:
    d = jax.vmap(
      jax.vmap(
        lambda Ra, Rb: displacement_fn(Ra, Rb, **kwargs), in_axes=(None, 0)
      ),
      in_axes=(0, None),
    )(R, R)
    return jnp.sqrt(jnp.sum(d**2, axis=-1))

  def estimate_r_NN_from_R(R: Array) -> Array:
    dist = pair_distances(R)
    N = dist.shape[0]
    dist_no_self = jnp.where(jnp.eye(N, dtype=bool), jnp.inf, dist)
    return jnp.max(jnp.min(dist_no_self, axis=1))

  r_NN_value = estimate_r_NN_from_R(R) if r_NN is None else r_NN
  amplitude = 1e-2 * r_NN_value

  # Sinusoidal test displacement, one wave per coordinate axis. Axes along
  # which the system has zero extent get no displacement; ``safe_L`` only
  # avoids a division by zero in the discarded branch.
  L = jnp.max(R, axis=0) - jnp.min(R, axis=0)
  safe_L = jnp.where(L == 0, 1.0, L)
  mode = jnp.where(L == 0, 0.0, jnp.sin(R / safe_L))
  V = amplitude * mode

  _, unit_dot = exp_preconditioner(
    displacement_fn,
    r_cut=r_cut,
    r_NN=r_NN_value,
    A=A,
    mu=1.0,
    c_stab=c_stab,
  )

  # F = -grad E, so (grad E(R + V) - grad E(R)) . V == (F - F_plus) . V.
  F = force(R, **kwargs)
  F_plus = force(R + V, **kwargs)
  lhs = jnp.sum((F - F_plus) * V)
  rhs = unit_dot(R, V, V, **kwargs)
  # NOTE(review): if every extent in ``L`` is zero then ``V`` and ``rhs`` are
  # zero and the ratio is not finite -- confirm callers never hit this case.
  return jnp.maximum(lhs / rhs, min_mu)


def precon_fire_descent(
  energy_or_force: Callable[..., Array],
  shift_fn: ShiftFn,
  dt_start: float = 0.1,
  dt_max: float = 1.0,
  max_move: float = 0.2,
  n_min: float = 5,
  f_inc: float = 1.1,
  f_dec: float = 0.5,
  alpha_start: float = 0.1,
  f_alpha: float = 0.99,
  theta: float = 0.1,
  armijo_tol: float = 0.0,
  use_armijo: bool = True,
  preconditioner_update_threshold: float | None = None,
  preconditioner: Callable[..., PyTree] | None = None,
  preconditioner_dot: Callable[..., Array] | None = None,
) -> Minimizer[PreconFireDescentState]:
  """Defines preconditioned FIRE minimization.

  This optimizer uses the graph-metric preconditioning idea of
  Packwood et al., J. Chem. Phys. 144, 164109 (2016). Unlike
  :func:`fire_descent`, it stores velocity rather than momentum and applies
  the preconditioned force direction directly.

  Args:
    energy_or_force: A function that produces either an energy or a force from
      a set of particle positions specified as an ndarray of shape
      `[n, spatial_dimension]`.
    shift_fn: A function that displaces positions `R`, by an amount `dR`.
      Both `R` and `dR` should be ndarrays of shape `[n, spatial_dimension]`.
    dt_start: The initial step size during minimization as a float.
    dt_max: The maximum step size during minimization as a float.
    max_move: The maximum Euclidean norm of the displacement in one step.
    n_min: An integer specifying the minimum number of positive-power steps
      before ``dt`` and ``alpha`` should be updated.
    f_inc: A float specifying the fractional rate by which the step size
      should be increased.
    f_dec: A float specifying the fractional rate by which the step size
      should be decreased.
    alpha_start: A float specifying the initial FIRE mixing parameter.
    f_alpha: A float specifying the fractional change in ``alpha``.
    theta: Armijo sufficient-decrease parameter.
    armijo_tol: Numerical tolerance for the Armijo comparison. Trials within
      this tolerance of the sufficient-decrease boundary are rejected. The
      default ``0.0`` matches ASE's strict comparison.
    use_armijo: Whether to use ASE's Armijo trial-step rejection. If ``True``,
      ``energy_or_force`` must be an energy function.
    preconditioner_update_threshold: Optional maximum absolute displacement
      threshold for updating cached preconditioner reference positions. ASE
      uses ``0.5 * r_NN``.
    preconditioner: Optional callable that returns the preconditioned force
      direction ``H^{-1} F``. It should take ``(R, F, **kwargs)`` and return a
      PyTree with the same structure as ``F``.
    preconditioner_dot: Optional callable for the preconditioner metric.
      When ``preconditioner`` is supplied, this must take
      ``(R, X, Y, **kwargs)`` and return ``X^T H Y``.

  Returns:
    See above.

  Raises:
    ValueError: If ``preconditioner`` is supplied without
      ``preconditioner_dot``.
  """
  if preconditioner is not None and preconditioner_dot is None:
    raise ValueError(
      'preconditioner_dot must be supplied when preconditioner is supplied.'
    )

  force = quantity.canonicalize_force(energy_or_force)

  def tree_dot(X: PyTree, Y: PyTree) -> Array | float:
    # Euclidean inner product summed over all leaves.
    return tree_reduce(
      lambda accum, x_dot_y: accum + x_dot_y,
      tree_map(lambda x, y: jnp.sum(x * y), X, Y),
      0.0,
    )

  def preconditioned_force(R: PyTree, F: PyTree, **kwargs) -> PyTree:
    if preconditioner is None:
      return F
    return preconditioner(R, F, **kwargs)

  def velocity_metric_dot(
    R: PyTree, X: PyTree, Y: PyTree, **kwargs
  ) -> Array | float:
    if preconditioner is None:
      return tree_dot(X, Y)
    metric = preconditioner_dot
    # Already validated above; repeated here so ``metric`` is known non-None.
    if metric is None:
      raise ValueError(
        'preconditioner_dot must be supplied when preconditioner is supplied.'
      )
    return metric(R, X, Y, **kwargs)

  def shift_position(R: PyTree, dR: PyTree, **kwargs) -> PyTree:
    # A single callable shift is broadcast to every leaf of ``R``; otherwise
    # ``shift_fn`` is used as a PyTree of per-leaf shift functions.
    s_fn = shift_fn
    if isinstance(s_fn, Callable):
      s_fn = tree_map(lambda r: shift_fn, R)
    return tree_map(lambda s, r, dr: s(r, dr, **kwargs), s_fn, R, dR)

  def init_fn(R: PyTree, **kwargs) -> PreconFireDescentState:
    dtype = tree_leaves(R)[0].dtype
    V = tree_map(lambda x: jnp.zeros_like(x), R)
    F = force(R, **kwargs)
    return PreconFireDescentState(
      R,
      V,
      F,
      jnp.asarray(dt_start, dtype=dtype),
      jnp.asarray(alpha_start, dtype=dtype),
      jnp.zeros((), jnp.int32),
      jnp.array(False),
      R,
      R,
      jnp.array(False),
    )  # pytype: disable=wrong-arg-count

  def apply_fn(
    state: PreconFireDescentState, **kwargs
  ) -> PreconFireDescentState:
    (
      R,
      V,
      F,
      dt,
      alpha,
      n_pos,
      initialized,
      precon_R,
      precon_previous_R,
      precon_previous_initialized,
    ) = dataclasses.unpack(state)

    def tree_max_abs(X: PyTree) -> Array | float:
      return tree_reduce(
        lambda accum, x: jnp.maximum(accum, jnp.max(jnp.abs(x))),
        X,
        0.0,
      )

    if preconditioner_update_threshold is not None:
      # Refresh the cached reference positions when the largest coordinate
      # change since the previously recorded positions reaches the threshold.
      max_old_disp = tree_max_abs(
        tree_map(lambda r, r0: r - r0, R, precon_previous_R)
      )
      update_precon = jnp.logical_and(
        initialized,
        jnp.logical_and(
          precon_previous_initialized,
          max_old_disp >= preconditioner_update_threshold,
        ),
      )
      precon_R = tree_map(
        lambda r, r0: jnp.where(update_precon, r, r0), R, precon_R
      )
      set_old = initialized
      precon_previous_R = tree_map(
        lambda r, r0: jnp.where(set_old, r, r0), R, precon_previous_R
      )
      precon_previous_initialized = jnp.logical_or(
        precon_previous_initialized, initialized
      )

    precon_kwargs = dict(kwargs)
    precon_kwargs['preconditioner_position'] = precon_R
    G = preconditioned_force(R, F, **precon_kwargs)

    def finish_step(V_bar, dt_bar, alpha_bar, n_pos_bar):
      # Velocity kick, displacement capped at ``max_move``, then new forces.
      V_new = tree_map(lambda v, g: v + dt_bar * g, V_bar, G)
      dR_raw = tree_map(lambda v: dt_bar * v, V_new)
      dR_norm = jnp.sqrt(tree_dot(dR_raw, dR_raw))
      dR_scale = jnp.where(dR_norm > max_move, max_move / dR_norm, 1.0)
      dR = tree_map(lambda dr: dR_scale * dr, dR_raw)
      R_new = shift_position(R, dR, **kwargs)
      F_new = force(R_new, **kwargs)
      return PreconFireDescentState(
        R_new,
        V_new,
        F_new,
        dt_bar,
        alpha_bar,
        n_pos_bar,
        jnp.array(True),
        precon_R,
        precon_previous_R,
        precon_previous_initialized,
      )  # pytype: disable=wrong-arg-count

    F_dot_V = tree_dot(F, V)
    is_first = jnp.logical_not(initialized)
    is_positive = jnp.logical_and(initialized, F_dot_V > 0)

    V_zero = tree_map(jnp.zeros_like, V)
    # Norms are taken in the preconditioner metric: V^T H V and F^T H^{-1} F.
    V_norm = jnp.sqrt(velocity_metric_dot(R, V, V, **precon_kwargs))
    F_norm = jnp.sqrt(tree_dot(F, G))
    V_positive = tree_map(
      lambda v, g: (1.0 - alpha) * v + alpha * (V_norm / F_norm) * g,
      V,
      G,
    )

    grow = n_pos > n_min
    dt_choice = jnp.array([dt * f_inc, dt_max])
    dt_positive = jnp.where(grow, jnp.min(dt_choice), dt)
    alpha_positive = jnp.where(grow, alpha * f_alpha, alpha)
    n_pos_positive = n_pos + 1

    # Positive power: mixed velocity and possibly larger dt. Otherwise the
    # velocity is zeroed, dt shrinks and alpha / n_pos are reset.
    V_bar = tree_map(
      lambda v_pos, v_zero: jnp.where(is_positive, v_pos, v_zero),
      V_positive,
      V_zero,
    )
    dt_bar = jnp.where(is_positive, dt_positive, dt * f_dec)
    alpha_bar = jnp.where(is_positive, alpha_positive, alpha_start)
    n_pos_bar = jnp.where(is_positive, n_pos_positive, jnp.zeros((), jnp.int32))

    # First step: start from zero velocity and leave dt / alpha / n_pos alone.
    V_bar = tree_map(
      lambda v_first, v_later: jnp.where(is_first, v_first, v_later),
      V_zero,
      V_bar,
    )
    dt_bar = jnp.where(is_first, dt, dt_bar)
    alpha_bar = jnp.where(is_first, alpha, alpha_bar)
    n_pos_bar = jnp.where(is_first, n_pos, n_pos_bar)

    armijo_fail = jnp.array(False)
    if use_armijo:
      # Trial step with the unmodified velocity; a failed sufficient-decrease
      # test (or a non-finite trial energy) overrides the FIRE update above.
      V_test = tree_map(lambda v, g: v + dt * g, V, G)
      dR_test = tree_map(lambda v: dt * v, V_test)
      R_test = shift_position(R, dR_test, **kwargs)
      E = energy_or_force(R, **kwargs)
      E_test = energy_or_force(R_test, **kwargs)
      armijo_rhs = E - theta * dt * tree_dot(V_test, F)
      failed_decrease = E_test > armijo_rhs - armijo_tol
      armijo_fail = jnp.logical_and(
        initialized, jnp.logical_or(failed_decrease, ~jnp.isfinite(E_test))
      )

    V_bar = tree_map(
      lambda v: jnp.where(armijo_fail, jnp.zeros_like(v), v), V_bar
    )
    dt_bar = jnp.where(armijo_fail, dt * f_dec, dt_bar)
    alpha_bar = jnp.where(armijo_fail, alpha_start, alpha_bar)
    n_pos_bar = jnp.where(armijo_fail, jnp.zeros((), jnp.int32), n_pos_bar)

    return finish_step(V_bar, dt_bar, alpha_bar, n_pos_bar)

  return init_fn, apply_fn


# Box optimization


@dataclasses.dataclass
class FireBoxDescentState:
  """State for the combined atom + box FIRE minimizer.

  Box degrees of freedom are parameterized by the deformation gradient ``F``
  relative to a reference box, following Tadmor et al. (1999). The current box
  is reconstructed as ``F @ reference_box``.

  The ``box_factor`` (default ``N``) scales the deformation-gradient DOFs to
  be comparable to atomic positions, acting as a preconditioner.

  Attributes:
    position: Atomic fractional positions, shape ``(N, dim)``.
    momentum: Atomic momenta, shape ``(N, dim)``.
    force: Atomic forces, shape ``(N, dim)``.
    mass: Atomic masses.
    box: Current simulation box (``F @ reference_box``), shape ``(dim, dim)``.
    reference_box: Initial box (constant), shape ``(dim, dim)``.
    box_position: ``box_factor * F``, shape ``(dim, dim)``.
    box_momentum: Momentum in deformation-gradient space, ``(dim, dim)``.
    box_force: Force in deformation-gradient space, ``(dim, dim)``.
    box_mass: Box mass (scalar).
    box_factor: Scaling preconditioner (scalar, default ``N``).
    dt: Current FIRE step size.
    alpha: Current FIRE momentum-mixing parameter.
    n_pos: Number of consecutive steps with positive power.
  """

  position: Array
  momentum: Array
  force: Array
  mass: Array | float
  box: Array
  reference_box: Array
  box_position: Array
  box_momentum: Array
  box_force: Array
  box_mass: Array
  box_factor: Array
  dt: Array | float
  alpha: Array | float
  n_pos: Array | int


def fire_descent_box(
  energy_fn: Callable[..., Array],
  shift_fn: ShiftFn,
  dt_start: float = 0.1,
  dt_max: float = 0.4,
  n_min: float = 5,
  f_inc: float = 1.1,
  f_dec: float = 0.5,
  alpha_start: float = 0.1,
  f_alpha: float = 0.99,
  scalar_pressure: float = 0.0,
  hydrostatic_strain: bool = False,
  constant_volume: bool = False,
  mask: Array | None = None,
) -> Minimizer[FireBoxDescentState]:
  """FIRE minimization of both atomic positions and the simulation box.

  Args:
    energy_fn: Energy function taking ``(R, box=box, **kwargs)``.
    shift_fn: Shift function from :func:`~jax_md.space.periodic_general`
      (``fractional_coordinates=True``).
    dt_start: Initial FIRE step size.
    dt_max: Maximum FIRE step size.
    n_min: Minimum positive-power steps before increasing dt.
    f_inc: Factor to increase dt.
    f_dec: Factor to decrease dt on overshoot.
    alpha_start: Initial FIRE mixing parameter.
    f_alpha: Factor to decrease alpha.
    scalar_pressure: Target external pressure.
    hydrostatic_strain: Constrain box to isotropic deformation.
    constant_volume: Project out volume changes.
    mask: Strain-component mask, ``(dim, dim)``.

  Returns:
    ``(init_fn, apply_fn)`` pair.
  """
  dt_start, dt_max, n_min, f_inc, f_dec, alpha_start, f_alpha = (
    util.static_cast(
      dt_start, dt_max, n_min, f_inc, f_dec, alpha_start, f_alpha
    )
  )

  force_fn = quantity.canonicalize_force(energy_fn)

  def virial_fn(R, box, **kwargs):
    """Compute the constrained virial for the box force.

    The virial is ``W = -dE/deps - P_ext * V * I`` where ``dE/deps`` is
    the strain derivative obtained via autodiff. The sign follows JAX-MD's
    stress convention (``quantity.stress`` returns ``(1/V)(-dU/deps)`` with
    positive = tension). ``-dE/deps`` points in the direction of energy
    decrease, matching the simple ``cell_step`` approach
    (``box += dt * stress``).

    Optional constraint projections (hydrostatic, mask, constant-volume)
    are applied before returning.
    """
    dim = R.shape[1]
    I_d = jnp.eye(dim, dtype=box.dtype)
    zero = jnp.zeros((dim, dim), dtype=box.dtype)

    def U(eps):
      return energy_fn(R, box=box, perturbation=(I_d + eps), **kwargs)

    dUdeps = jax.grad(U)(zero)
    v = -dUdeps

    if scalar_pressure != 0.0:
      vol = quantity.volume(dim, box)
      v = v - scalar_pressure * vol * I_d

    if hydrostatic_strain:
      tr = jnp.trace(v)
      v = jnp.eye(dim, dtype=v.dtype) * (tr / dim)

    if mask is not None:
      v = v * jnp.asarray(mask, dtype=v.dtype)

    if constant_volume:
      tr = jnp.trace(v)
      v = v - jnp.eye(dim, dtype=v.dtype) * (tr / dim)

    return v

  def box_force_fn(virial, F_deform, bf):
    """Transform virial to deformation-gradient space and scale.

    Maps the virial from box-space to F-space via the Jacobian ``F^{-T}``:
    ``virial_F = solve(F, virial^T)^T``. This matches ASE
    ``UnitCellFilter.get_forces`` and TorchSim ``compute_cell_forces``.
    At ``F = I`` the transformation is identity.
    """
    W_F = jnp.linalg.solve(F_deform, virial.T).T
    return W_F / bf

  def init_fn(
    R: Array,
    box: Array,
    mass: Array | float = 1.0,
    box_mass: Array | None = None,
    box_factor: Array | None = None,
    **kwargs,
  ) -> FireBoxDescentState:
    N = R.shape[0]
    dim = R.shape[1]

    if box_mass is None:
      box_mass = f32(1.0)
    if box_factor is None:
      box_factor = f32(N)

    F_atom = force_fn(R, box=box, **kwargs)
    # The deformation gradient starts at the identity, so the initial box is
    # also the reference box.
    I_d = jnp.eye(dim, dtype=box.dtype)
    virial = virial_fn(R, box, **kwargs)
    F_b = box_force_fn(virial, I_d, box_factor)

    state = FireBoxDescentState(
      R,
      jnp.zeros_like(R),
      F_atom,
      mass,
      box,
      box,
      box_factor * I_d,
      jnp.zeros_like(box),
      F_b,
      box_mass,
      box_factor,
      dt_start,
      alpha_start,
      jnp.zeros((), jnp.int32),
    )  # pytype: disable=wrong-arg-count
    return simulate.canonicalize_mass(state)

  def apply_fn(
    state: FireBoxDescentState,
    **kwargs,
  ) -> FireBoxDescentState:
    (R, P, F, M, _, ref_box, X_b, P_b, F_b, M_b, bf, dt, alpha, n_pos) = (
      dataclasses.unpack(state)
    )

    dt_2 = f32(dt / 2)

    # TODO(ag): Add different integration schemes following:
    # https://doi.org/10.1016/j.commatsci.2020.109584

    # velocity Verlet: half-step momenta
    P = P + dt_2 * F
    P_b = P_b + dt_2 * F_b

    # position update (atoms are shifted in the box before it is updated)
    F_deform = X_b / bf
    box = F_deform @ ref_box
    R = shift_fn(R, dt * P / M, box=box, **kwargs)

    X_b = X_b + dt * P_b / M_b
    F_deform = X_b / bf
    box = F_deform @ ref_box

    # recompute forces at new (R, box)
    F = force_fn(R, box=box, **kwargs)
    virial = virial_fn(R, box, **kwargs)
    F_b = box_force_fn(virial, F_deform, bf)

    # velocity Verlet: half-step momenta
    P = P + dt_2 * F
    P_b = P_b + dt_2 * F_b

    # FIRE: combined power check
    F_dot_P = jnp.sum(F * P) + jnp.sum(F_b * P_b)
    F_norm = jnp.sqrt(jnp.sum(F**2) + jnp.sum(F_b**2) + 1e-6)
    P_norm = jnp.sqrt(jnp.sum(P**2) + jnp.sum(P_b**2))

    # FIRE: momentum mixing
    P = P + alpha * (F * P_norm / F_norm - P)
    P_b = P_b + alpha * (F_b * P_norm / F_norm - P_b)

    # FIRE: adaptive dt / alpha
    n_pos = jnp.where(F_dot_P >= 0, n_pos + 1, 0)
    dt_choice = jnp.array([dt * f_inc, dt_max])
    dt = jnp.where(
      F_dot_P > 0,
      jnp.where(n_pos > n_min, jnp.min(dt_choice), dt),
      dt,
    )
    dt = jnp.where(F_dot_P < 0, dt * f_dec, dt)
    alpha = jnp.where(
      F_dot_P > 0,
      jnp.where(n_pos > n_min, alpha * f_alpha, alpha),
      alpha,
    )
    alpha = jnp.where(F_dot_P < 0, alpha_start, alpha)

    # reset momenta on overshoot
    keep = f32(F_dot_P >= 0)
    P = keep * P
    P_b = keep * P_b

    return FireBoxDescentState(
      R,
      P,
      F,
      M,
      box,
      ref_box,
      X_b,
      P_b,
      F_b,
      M_b,
      bf,
      dt,
      alpha,
      n_pos,
    )  # pytype: disable=wrong-arg-count

  return init_fn, apply_fn


@dataclasses.dataclass
class PreconFireBoxDescentState:
  """State for the atom + box preconditioned FIRE minimizer.

  With ``F_deform = box_position / box_factor``, the positions handed to the
  energy function are ``position @ F_deform.T`` and the current box is
  ``reference_box @ F_deform.T``.

  Attributes:
    position: Generalized positions (positions with the deformation gradient
      removed).
    velocity: Optimizer velocity of the generalized positions.
    force: Generalized force, ``F_real @ F_deform``.
    box: Current simulation box.
    reference_box: Box supplied at initialization (constant).
    box_position: ``box_factor * F_deform``.
    box_velocity: Optimizer velocity of ``box_position``.
    box_force: Force on the deformation-gradient degrees of freedom.
    box_factor: Scaling between ``box_position`` and ``F_deform``.
    dt: Current step size.
    alpha: Current FIRE mixing parameter.
    n_pos: Number of consecutive steps with positive power.
    initialized: Whether at least one step has been taken.
    preconditioner_position: Reference positions passed to the preconditioner
      as ``preconditioner_position``.
    preconditioner_previous_position: Positions recorded on the previous step,
      used for the preconditioner refresh check.
    preconditioner_previous_initialized: Whether
      ``preconditioner_previous_position`` has been recorded.
    momentum: Alias for ``velocity``.
    box_momentum: Alias for ``box_velocity``.
  """

  position: Array
  velocity: Array
  force: Array
  box: Array
  reference_box: Array
  box_position: Array
  box_velocity: Array
  box_force: Array
  box_factor: Array
  dt: Array | float
  alpha: Array | float
  n_pos: Array | int
  initialized: Array | bool
  preconditioner_position: PyTree
  preconditioner_previous_position: PyTree
  preconditioner_previous_initialized: Array | bool

  @property
  def momentum(self) -> Array:
    return self.velocity

  @property
  def box_momentum(self) -> Array:
    return self.box_velocity


def precon_fire_descent_box(
  energy_fn: Callable[..., Array],
  shift_fn: ShiftFn,
  dt_start: float = 0.1,
  dt_max: float = 1.0,
  max_move: float = 0.2,
  n_min: float = 5,
  f_inc: float = 1.1,
  f_dec: float = 0.5,
  alpha_start: float = 0.1,
  f_alpha: float = 0.99,
  theta: float = 0.1,
  armijo_tol: float = 0.0,
  use_armijo: bool = True,
  preconditioner: Callable[..., PyTree] | None = None,
  preconditioner_dot: Callable[..., Array] | None = None,
  cell_preconditioner: Array | float = 1.0,
  preconditioner_update_threshold: float | None = None,
  scalar_pressure: float = 0.0,
  hydrostatic_strain: bool = False,
  constant_volume: bool = False,
  mask: Array | None = None,
) -> Minimizer[PreconFireBoxDescentState]:
  """Preconditioned FIRE minimization of atoms and the simulation box.

  Atomic and box degrees of freedom share one FIRE state (``dt``, ``alpha``,
  power check, Armijo test and ``max_move`` cap are all evaluated on the
  combined atom + box quantities).

  Args:
    energy_fn: Energy function called as ``energy_fn(R, box=box, **kwargs)``;
      it must also accept a ``perturbation`` keyword, which is used to
      differentiate the energy with respect to strain.
    shift_fn: Accepted for signature parity with the other minimizers but not
      used; positions are updated by plain addition.
    dt_start: Initial step size.
    dt_max: Maximum step size.
    max_move: Maximum Euclidean norm of the combined atom + box displacement
      in one step.
    n_min: Number of positive-power steps before ``dt`` and ``alpha`` change.
    f_inc: Factor by which ``dt`` grows.
    f_dec: Factor by which ``dt`` shrinks.
    alpha_start: Initial FIRE mixing parameter.
    f_alpha: Factor applied to ``alpha`` when ``dt`` grows.
    theta: Armijo sufficient-decrease parameter.
    armijo_tol: Tolerance subtracted from the Armijo bound.
    use_armijo: Whether to apply the Armijo trial-step rejection.
    preconditioner: Optional callable ``(R, F, **kwargs)`` returning the
      preconditioned atomic force direction.
    preconditioner_dot: Metric ``(R, X, Y, **kwargs) -> X^T H Y``; required
      when ``preconditioner`` is supplied.
    cell_preconditioner: Scalar or array (broadcastable against the box
      force) that divides the box force when ``preconditioner`` is supplied.
    preconditioner_update_threshold: Optional displacement threshold for
      refreshing the cached preconditioner reference positions.
    scalar_pressure: External pressure term subtracted from the virial.
    hydrostatic_strain: Replace the virial by its isotropic part.
    constant_volume: Remove the trace of the virial.
    mask: Optional elementwise mask applied to the virial.

  Returns:
    ``(init_fn, apply_fn)`` pair. ``init_fn`` takes ``(R, box, box_factor=None,
    **kwargs)``; ``box_factor`` defaults to the number of particles.

  Raises:
    ValueError: If ``preconditioner`` is supplied without
      ``preconditioner_dot``.
  """
  del shift_fn
  if preconditioner is not None and preconditioner_dot is None:
    raise ValueError(
      'preconditioner_dot must be supplied when preconditioner is supplied.'
    )

  force_fn = quantity.canonicalize_force(energy_fn)

  def tree_dot(X: PyTree, Y: PyTree) -> Array | float:
    return tree_reduce(
      lambda accum, x_dot_y: accum + x_dot_y,
      tree_map(lambda x, y: jnp.sum(x * y), X, Y),
      0.0,
    )

  def preconditioned_force(R: PyTree, F: PyTree, **kwargs) -> PyTree:
    if preconditioner is None:
      return F
    return preconditioner(R, F, **kwargs)

  def velocity_metric_dot(
    R: PyTree, X: PyTree, Y: PyTree, **kwargs
  ) -> Array | float:
    if preconditioner is None:
      return tree_dot(X, Y)
    metric = preconditioner_dot
    # Already validated above; repeated here so ``metric`` is known non-None.
    if metric is None:
      raise ValueError(
        'preconditioner_dot must be supplied when preconditioner is supplied.'
      )
    return metric(R, X, Y, **kwargs)

  def tree_max_abs(X: PyTree) -> Array | float:
    return tree_reduce(
      lambda accum, x: jnp.maximum(accum, jnp.max(jnp.abs(x))),
      X,
      0.0,
    )

  def virial_fn(R, box, **kwargs):
    # Symmetrized -dE/deps with the optional pressure term and constraint
    # projections applied.
    dim = R.shape[1]
    I_d = jnp.eye(dim, dtype=box.dtype)
    zero = jnp.zeros((dim, dim), dtype=box.dtype)

    def U(eps):
      return energy_fn(R, box=box, perturbation=(I_d + eps), **kwargs)

    dUdeps = jax.grad(U)(zero)
    v = -dUdeps
    v = (v + v.T) / 2.0

    if scalar_pressure != 0.0:
      vol = quantity.volume(dim, box)
      v = v - scalar_pressure * vol * I_d

    if hydrostatic_strain:
      tr = jnp.trace(v)
      v = jnp.eye(dim, dtype=v.dtype) * (tr / dim)

    if mask is not None:
      v = v * jnp.asarray(mask, dtype=v.dtype)

    if constant_volume:
      tr = jnp.trace(v)
      v = v - jnp.eye(dim, dtype=v.dtype) * (tr / dim)

    return v

  def actual_box(ref_box, F_deform):
    return ref_box @ F_deform.T

  def actual_position(q_R, F_deform):
    return q_R @ F_deform.T

  def generalized_position(R, F_deform):
    # Inverse of ``actual_position``.
    return jnp.linalg.solve(F_deform, R.T).T

  def generalized_force(F_real, F_deform):
    return F_real @ F_deform

  def box_force_fn(virial, F_deform, bf):
    W_F = jnp.linalg.solve(F_deform, virial.T).T
    return W_F / bf

  def init_fn(
    R: Array,
    box: Array,
    box_factor: Array | None = None,
    **kwargs,
  ) -> PreconFireBoxDescentState:
    dtype = tree_leaves(R)[0].dtype
    N = R.shape[0]
    dim = R.shape[1]

    if box_factor is None:
      box_factor = jnp.asarray(N, dtype=dtype)

    # The deformation gradient starts at the identity.
    I_d = jnp.eye(dim, dtype=box.dtype)
    F_deform = I_d
    q_R = generalized_position(R, F_deform)
    R_actual = actual_position(q_R, F_deform)
    F_real = force_fn(R_actual, box=box, **kwargs)
    F = generalized_force(F_real, F_deform)
    F_b = box_force_fn(virial_fn(R_actual, box, **kwargs), F_deform, box_factor)

    return PreconFireBoxDescentState(  # type: ignore[call-arg]
      q_R,  # type: ignore[call-arg]
      tree_map(jnp.zeros_like, q_R),
      F,
      box,
      box,
      box_factor * I_d,
      jnp.zeros_like(box),
      F_b,
      box_factor,
      jnp.asarray(dt_start, dtype=dtype),
      jnp.asarray(alpha_start, dtype=dtype),
      jnp.zeros((), jnp.int32),
      jnp.array(False),
      R,
      R,
      jnp.array(False),
    )  # pytype: disable=wrong-arg-count

  def apply_fn(
    state: PreconFireBoxDescentState, **kwargs
  ) -> PreconFireBoxDescentState:
    (
      R,
      V,
      F,
      _,
      ref_box,
      X_b,
      V_b,
      F_b,
      bf,
      dt,
      alpha,
      n_pos,
      initialized,
      precon_R,
      precon_previous_R,
      precon_previous_initialized,
    ) = dataclasses.unpack(state)

    F_deform = X_b / bf
    box = actual_box(ref_box, F_deform)
    R_actual = actual_position(R, F_deform)

    if preconditioner_update_threshold is not None:
      # Refresh the cached reference positions when the largest coordinate
      # change since the previously recorded positions reaches the threshold.
      max_old_disp = tree_max_abs(
        tree_map(lambda r, r0: r - r0, R, precon_previous_R)
      )
      update_precon = jnp.logical_and(
        initialized,
        jnp.logical_and(
          precon_previous_initialized,
          max_old_disp >= preconditioner_update_threshold,
        ),
      )
      precon_R = tree_map(
        lambda r, r0: jnp.where(update_precon, r, r0), R, precon_R
      )
      precon_previous_R = tree_map(
        lambda r, r0: jnp.where(initialized, r, r0), R, precon_previous_R
      )
      precon_previous_initialized = jnp.logical_or(
        precon_previous_initialized, initialized
      )

    precon_kwargs = dict(kwargs)
    precon_kwargs['box'] = box
    precon_kwargs['preconditioner_position'] = precon_R
    G = preconditioned_force(R, F, **precon_kwargs)
    cell_precon = jnp.asarray(cell_preconditioner, dtype=F_b.dtype)
    G_b = F_b if preconditioner is None else F_b / cell_precon

    F_dot_V = tree_dot(F, V) + jnp.sum(F_b * V_b)
    is_first = jnp.logical_not(initialized)
    is_positive = jnp.logical_and(initialized, F_dot_V > 0)

    V_zero = tree_map(jnp.zeros_like, V)
    V_b_zero = jnp.zeros_like(V_b)
    # Box contribution to V^T H V. The weight is applied inside the sum so an
    # array-valued ``cell_preconditioner`` (applied elementwise in ``G_b``)
    # still yields a scalar; for a scalar the value is unchanged.
    box_velocity_norm = (
      jnp.sum(V_b**2)
      if preconditioner is None
      else jnp.sum(cell_precon * V_b**2)
    )
    V_norm = jnp.sqrt(
      velocity_metric_dot(R, V, V, **precon_kwargs) + box_velocity_norm
    )
    F_norm = jnp.sqrt(tree_dot(F, G) + jnp.sum(F_b * G_b))
    V_positive = tree_map(
      lambda v, g: (1.0 - alpha) * v + alpha * (V_norm / F_norm) * g,
      V,
      G,
    )
    V_b_positive = (1.0 - alpha) * V_b + alpha * (V_norm / F_norm) * G_b

    grow = n_pos > n_min
    dt_choice = jnp.array([dt * f_inc, dt_max])
    dt_positive = jnp.where(grow, jnp.min(dt_choice), dt)
    alpha_positive = jnp.where(grow, alpha * f_alpha, alpha)
    n_pos_positive = n_pos + 1

    # Positive power: mixed velocities and possibly larger dt. Otherwise the
    # velocities are zeroed, dt shrinks and alpha / n_pos are reset.
    V_bar = tree_map(
      lambda v_pos, v_zero: jnp.where(is_positive, v_pos, v_zero),
      V_positive,
      V_zero,
    )
    V_b_bar = jnp.where(is_positive, V_b_positive, V_b_zero)
    dt_bar = jnp.where(is_positive, dt_positive, dt * f_dec)
    alpha_bar = jnp.where(is_positive, alpha_positive, alpha_start)
    n_pos_bar = jnp.where(is_positive, n_pos_positive, jnp.zeros((), jnp.int32))

    # First step: start from zero velocity and leave dt / alpha / n_pos alone.
    V_bar = tree_map(
      lambda v_first, v_later: jnp.where(is_first, v_first, v_later),
      V_zero,
      V_bar,
    )
    V_b_bar = jnp.where(is_first, V_b_zero, V_b_bar)
    dt_bar = jnp.where(is_first, dt, dt_bar)
    alpha_bar = jnp.where(is_first, alpha, alpha_bar)
    n_pos_bar = jnp.where(is_first, n_pos, n_pos_bar)

    armijo_fail = jnp.array(False)
    if use_armijo:
      # Trial step with the unmodified velocities; a failed sufficient-decrease
      # test (or a non-finite trial energy) overrides the FIRE update above.
      V_test = tree_map(lambda v, g: v + dt * g, V, G)
      V_b_test = V_b + dt * G_b
      dR_test = tree_map(lambda v: dt * v, V_test)
      X_b_test = X_b + dt * V_b_test
      box_test = actual_box(ref_box, X_b_test / bf)
      q_R_test = R + dR_test
      E = energy_fn(R_actual, box=box, **kwargs)
      E_test = energy_fn(
        actual_position(q_R_test, X_b_test / bf), box=box_test, **kwargs
      )
      power_test = tree_dot(V_test, F) + jnp.sum(V_b_test * F_b)
      armijo_rhs = E - theta * dt * power_test
      failed_decrease = E_test > armijo_rhs - armijo_tol
      armijo_fail = jnp.logical_and(
        initialized, jnp.logical_or(failed_decrease, ~jnp.isfinite(E_test))
      )

    V_bar = tree_map(
      lambda v: jnp.where(armijo_fail, jnp.zeros_like(v), v), V_bar
    )
    V_b_bar = jnp.where(armijo_fail, V_b_zero, V_b_bar)
    dt_bar = jnp.where(armijo_fail, dt * f_dec, dt_bar)
    alpha_bar = jnp.where(armijo_fail, alpha_start, alpha_bar)
    n_pos_bar = jnp.where(armijo_fail, jnp.zeros((), jnp.int32), n_pos_bar)

    # Velocity kick, then a displacement capped at ``max_move`` over the
    # combined atom + box degrees of freedom.
    V_new = tree_map(lambda v, g: v + dt_bar * g, V_bar, G)
    V_b_new = V_b_bar + dt_bar * G_b
    dR_raw = tree_map(lambda v: dt_bar * v, V_new)
    dX_b_raw = dt_bar * V_b_new
    dR_norm = jnp.sqrt(tree_dot(dR_raw, dR_raw) + jnp.sum(dX_b_raw**2))
    dR_scale = jnp.where(dR_norm > max_move, max_move / dR_norm, 1.0)
    dR = tree_map(lambda dr: dR_scale * dr, dR_raw)
    dX_b = dR_scale * dX_b_raw

    R_new = R + dR
    X_b_new = X_b + dX_b
    F_deform_new = X_b_new / bf
    box_new = actual_box(ref_box, F_deform_new)
    R_actual_new = actual_position(R_new, F_deform_new)
    F_real_new = force_fn(R_actual_new, box=box_new, **kwargs)
    F_new = generalized_force(F_real_new, F_deform_new)
    F_b_new = box_force_fn(
      virial_fn(R_actual_new, box_new, **kwargs), F_deform_new, bf
    )

    return PreconFireBoxDescentState(  # type: ignore[call-arg]
      R_new,  # type: ignore[call-arg]
      V_new,
      F_new,
      box_new,
      ref_box,
      X_b_new,
      V_b_new,
      F_b_new,
      bf,
      dt_bar,
      alpha_bar,
      n_pos_bar,
      jnp.array(True),
      precon_R,
      precon_previous_R,
      precon_previous_initialized,
    )  # pytype: disable=wrong-arg-count

  return init_fn, apply_fn