Energy Minimization Routines
Code to minimize the energy of a system.
This file contains several methods for finding the nearest local minimum (inherent structure) of a system, starting from initial particle positions R.
Minimization code follows the same overall structure as optimizers in JAX. Optimizers return two functions:
- init_fn:
Function that initializes the state of an optimizer. Should take positions as an ndarray of shape [n, spatial_dimension]. Returns the optimizer state, typically a dataclass.
- apply_fn:
Function that takes a state and produces a new state after one step of optimization.
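A minimal sketch of this two-function pattern, assuming only `jax` is installed. The names `ToyState`, `toy_optimizer`, and `run` are hypothetical and not part of JAX MD. A `NamedTuple` stands in for the dataclass state here because JAX treats it as a pytree, so `jax.lax.fori_loop` can carry it between iterations.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical optimizer state; a NamedTuple is a JAX pytree."""
    position: jnp.ndarray  # [n, spatial_dimension]
    step: jnp.ndarray      # number of apply_fn calls so far


def toy_optimizer(step_size=0.5):
    """Hypothetical optimizer for E(R) = 0.5 * sum(R**2), whose force is -R."""

    def init_fn(R):
        # Build the initial state from positions of shape [n, spatial_dimension].
        return ToyState(position=R, step=jnp.asarray(0, dtype=jnp.int32))

    def apply_fn(state):
        # One step along the force -R: R <- R + step_size * (-R).
        R = state.position + step_size * (-state.position)
        return ToyState(position=R, step=state.step + 1)

    return init_fn, apply_fn


def run(init_fn, apply_fn, R, n_steps):
    """Drive any (init_fn, apply_fn) pair for a fixed number of steps."""
    state = init_fn(R)
    # fori_loop carries the pytree state through n_steps compiled iterations.
    return jax.lax.fori_loop(0, n_steps, lambda i, s: apply_fn(s), state)


init_fn, apply_fn = toy_optimizer(step_size=0.5)
final = run(init_fn, apply_fn, jnp.ones((4, 2)), n_steps=3)
```

Keeping the loop outside the optimizer means the same driver works for any pair that follows this convention.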
Minimization Routines
- jax_md.minimize.gradient_descent(energy_or_force, shift_fn, step_size)
Defines gradient descent minimization.
This is the simplest optimization strategy: it moves particles along the force (the negative gradient of the energy) toward the nearest minimum. Gradient descent is generally slower than the other methods and is included mostly for its simplicity.
- Parameters:
energy_or_force (Callable[..., Array]) – A function that produces either an energy or a force from a set of particle positions specified as an ndarray of shape [n, spatial_dimension].
shift_fn (ShiftFn) – A function that displaces positions, R, by an amount dR. Both R and dR should be ndarrays of shape [n, spatial_dimension].
step_size (float) – A float specifying the size of each step; each step displaces particles by step_size times the force.
- Return type:
- Returns:
See above.
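The update rule is R ← shift_fn(R, step_size · F(R)), where F is the force. The following is a from-scratch sketch in plain JAX, not the JAX MD source; `canonicalize_force`, `harmonic_energy`, `harmonic_force`, `shift`, and `minimize` are hypothetical helpers. It assumes `energy_or_force` returns a scalar when it is an energy and an array shaped like R when it is a force, and uses `jax.eval_shape` to tell the two apart without evaluating the function.

```python
import jax
import jax.numpy as jnp


def canonicalize_force(energy_or_force):
    """Hypothetical helper: turn an energy or a force function into a force function."""
    def force_fn(R):
        out = jax.eval_shape(energy_or_force, R)  # shape/dtype only, no computation
        if out.shape == ():
            return -jax.grad(energy_or_force)(R)  # scalar output: force = -dE/dR
        return energy_or_force(R)                 # array output: already a force
    return force_fn


def gradient_descent(energy_or_force, shift_fn, step_size):
    """Sketch of gradient descent in the init_fn / apply_fn style."""
    force_fn = canonicalize_force(energy_or_force)

    def init_fn(R):
        return R  # the state is just the positions

    def apply_fn(R):
        # Displace every particle by step_size times the force acting on it.
        return shift_fn(R, step_size * force_fn(R))

    return init_fn, apply_fn


def harmonic_energy(R):
    # Each particle sits in a unit harmonic well centered at the origin.
    return 0.5 * jnp.sum(R ** 2)


def harmonic_force(R):
    return -R  # analytic -dE/dR for harmonic_energy


def shift(R, dR):
    return R + dR  # free space: no periodic wrapping


def minimize(energy_or_force, R, n_steps=100, step_size=0.1):
    init_fn, apply_fn = gradient_descent(energy_or_force, shift, step_size)
    return jax.lax.fori_loop(0, n_steps, lambda i, R: apply_fn(R), init_fn(R))


R0 = jnp.array([[1.0, -2.0], [0.5, 3.0]])
R_from_energy = minimize(harmonic_energy, R0)
R_from_force = minimize(harmonic_force, R0)
```

With step_size = 0.1 and force −R, each step multiplies R by 0.9, so both calls shrink R0 by roughly 0.9¹⁰⁰ ≈ 2.7 × 10⁻⁵, whether an energy or a force is passed in.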
- jax_md.minimize.fire_descent(energy_or_force, shift_fn, dt_start=0.1, dt_max=0.4, n_min=5, f_inc=1.1, f_dec=0.5, alpha_start=0.1, f_alpha=0.99)
Defines FIRE minimization.
This code implements the “Fast Inertial Relaxation Engine” from Bitzek et al. [1].
- Parameters:
energy_or_force (Callable[..., Array]) – A function that produces either an energy or a force from a set of particle positions specified as an ndarray of shape [n, spatial_dimension].
shift_fn (ShiftFn) – A function that displaces positions, R, by an amount dR. Both R and dR should be ndarrays of shape [n, spatial_dimension].
dt_start (float) – The initial step size during minimization.
dt_max (float) – The maximum step size during minimization.
n_min (float) – An integer specifying the minimum number of consecutive steps with positive power (moving downhill) before dt is increased and alpha is decreased.
f_inc (float) – A float specifying the factor by which the step size is multiplied when it is increased, up to dt_max.
f_dec (float) – A float specifying the factor by which the step size is multiplied when the power becomes negative (moving uphill).
alpha_start (float) – A float specifying the initial value of the FIRE mixing parameter alpha.
f_alpha (float) – A float specifying the factor by which alpha is multiplied each time the step size is increased.
- Return type:
Tuple[Callable[..., FireDescentState], Callable[..., FireDescentState]]
- Returns:
See above.
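Each FIRE step runs one velocity Verlet step, mixes the momentum toward the force direction, P ← (1 − α) P + α |P| F̂, and then adapts dt and α from the sign of the power F · P, zeroing the momentum when the system moves uphill. The block below is a simplified re-implementation in plain JAX for illustration, not the JAX MD source. It assumes a scalar mass, a `force_fn` that already returns forces (no energy canonicalization), and free-space displacement; `FireState` and the harmonic test problem are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FireState(NamedTuple):
    """Hypothetical stand-in mirroring the fields of FireDescentState."""
    position: jnp.ndarray
    momentum: jnp.ndarray
    force: jnp.ndarray
    mass: jnp.ndarray  # scalar mass in this sketch
    dt: jnp.ndarray
    alpha: jnp.ndarray
    n_pos: jnp.ndarray


def fire_descent(force_fn, shift_fn, dt_start=0.1, dt_max=0.4, n_min=5,
                 f_inc=1.1, f_dec=0.5, alpha_start=0.1, f_alpha=0.99):
    def init_fn(R, mass=1.0):
        return FireState(
            position=R,
            momentum=jnp.zeros_like(R),
            force=force_fn(R),
            mass=jnp.asarray(mass, dtype=R.dtype),
            dt=jnp.asarray(dt_start, dtype=R.dtype),
            alpha=jnp.asarray(alpha_start, dtype=R.dtype),
            n_pos=jnp.asarray(0, dtype=jnp.int32),
        )

    def apply_fn(state):
        R, P, F, M, dt, alpha, n_pos = state
        # 1. One velocity Verlet step.
        P = P + 0.5 * dt * F
        R = shift_fn(R, dt * P / M)
        F = force_fn(R)
        P = P + 0.5 * dt * F

        # 2. Power, then mixing: P <- (1 - alpha) P + alpha |P| F_hat.
        power = jnp.sum(F * P)
        F_norm = jnp.sqrt(jnp.sum(F ** 2) + 1e-6)  # small offset avoids 0 / 0
        P_norm = jnp.sqrt(jnp.sum(P ** 2))
        P = P + alpha * (F * P_norm / F_norm - P)

        # 3. Adapt dt and alpha from the sign of the power.
        n_pos = jnp.where(power >= 0, n_pos + 1, 0)
        grow = (power > 0) & (n_pos > n_min)
        dt = jnp.where(grow, jnp.minimum(dt * f_inc, dt_max), dt)
        alpha = jnp.where(grow, alpha * f_alpha, alpha)
        uphill = power < 0
        dt = jnp.where(uphill, dt * f_dec, dt)
        alpha = jnp.where(uphill, alpha_start, alpha)
        P = jnp.where(uphill, jnp.zeros_like(P), P)  # stop when moving uphill
        return FireState(R, P, F, M, dt, alpha, n_pos)

    return init_fn, apply_fn


# Hypothetical test problem: a unit harmonic well around each target site.
target = jnp.array([[1.0, 2.0], [-1.0, 0.5], [0.0, -3.0]])


def harmonic_force(R):
    return -(R - target)


init_fn, apply_fn = fire_descent(harmonic_force, lambda R, dR: R + dR)
state = jax.lax.fori_loop(
    0, 1000, lambda i, s: apply_fn(s), init_fn(jnp.zeros_like(target)))
```

Using explicit dtypes in init_fn keeps every field's type fixed across iterations, which `jax.lax.fori_loop` requires of its carried state.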
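References
[1] Bitzek, Erik, Pekka Koskinen, Franz Gähler, Michael Moseler, and Peter Gumbsch. “Structural relaxation made simple.” Physical Review Letters 97, no. 17 (2006): 170201.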
Data Types
- class jax_md.minimize.FireDescentState(position, momentum, force, mass, dt, alpha, n_pos)
A dataclass containing state information for the Fire Descent minimizer.
- position
The current position of particles. An ndarray of floats with shape [n, spatial_dimension].
- momentum
The current momentum of particles. An ndarray of floats with shape [n, spatial_dimension].
- force
The current force on particles. An ndarray of floats with shape [n, spatial_dimension].
- mass
The mass of particles. A float or an ndarray of floats with shape [n].
- dt
A float specifying the current step size.
- alpha
A float specifying the current FIRE mixing parameter.
- n_pos
The number of consecutive steps with positive power.
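Because mass may be a float or an ndarray of shape [n] while momentum has shape [n, spatial_dimension], a per-particle mass has to be reshaped to [n, 1] before it can divide the momentum; dividing directly fails to broadcast, or broadcasts along the wrong axis when n happens to equal spatial_dimension. The sketch below uses a hypothetical `NamedTuple` mirroring the documented fields (not the real `FireDescentState` class) to show that reshaping, plus one possible convergence check on the stored force. `velocity` and `max_force_norm` are illustrative helpers, not JAX MD functions.

```python
from typing import NamedTuple, Union

import jax.numpy as jnp


class FireDescentStateSketch(NamedTuple):
    """Hypothetical mirror of the documented FireDescentState fields."""
    position: jnp.ndarray            # [n, spatial_dimension]
    momentum: jnp.ndarray            # [n, spatial_dimension]
    force: jnp.ndarray               # [n, spatial_dimension]
    mass: Union[float, jnp.ndarray]  # float or [n]
    dt: float
    alpha: float
    n_pos: int


def velocity(state):
    """Momentum divided by mass, broadcasting a per-particle mass correctly."""
    mass = jnp.asarray(state.mass)
    if mass.ndim == 1:
        mass = mass[:, None]  # [n] -> [n, 1] so it broadcasts over spatial_dimension
    return state.momentum / mass


def max_force_norm(state):
    """Largest per-particle force magnitude, usable as a stopping criterion."""
    return jnp.max(jnp.linalg.norm(state.force, axis=-1))


state = FireDescentStateSketch(
    position=jnp.zeros((3, 2)),
    momentum=jnp.ones((3, 2)),
    force=jnp.array([[3.0, 4.0], [0.0, 0.0], [0.0, 1.0]]),
    mass=jnp.array([1.0, 2.0, 4.0]),
    dt=0.1,
    alpha=0.1,
    n_pos=0,
)
v = velocity(state)
converged = max_force_norm(state) < 1e-4
```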