# Source code for jax_md.mm_forcefields.reaxff.reaxff_forcefield

"""
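ReaxFF force field code: the ForceField parameter container and helpers that fill symmetric (fill_symm) and off-diagonal (fill_off_diag) parameter entries.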

Author: Mehmet Cagri Kaymak
"""

from jax_md import dataclasses, util
from dataclasses import fields
import jax
import jax.numpy as jnp

Array = util.Array


@dataclasses.dataclass
class ForceField:
  """Container for ReaxFF parameters.

  Fields created with ``dataclasses.static_field()`` are static metadata
  (NOTE(review): presumably excluded from JAX tracing by jax_md.dataclasses --
  confirm); the remaining fields hold parameter arrays.

  Pairwise parameters that have ``*_off`` / ``*_off_mask`` companions are
  completed by ``fill_off_diag``: entries whose mask is 0 come from a mixing
  rule over the per-atom-type values, the others from the explicit ``*_off``
  array. ``fill_symm`` mirrors symmetric 2-, 3- and 4-body entries.
  """
  num_atom_types: int = dataclasses.static_field()
  name_to_index: dict = dataclasses.static_field()
  params_to_indices: dict = dataclasses.static_field()
  # these tuples are used to handle symmetric parameters in 3 and 4 body param
  # lists (fill_symm copies the entries at *_src onto the entries at *_dst)
  body3_indices_src: tuple = dataclasses.static_field()
  body3_indices_dst: tuple = dataclasses.static_field()
  body4_indices_src: tuple = dataclasses.static_field()
  body4_indices_dst: tuple = dataclasses.static_field()
  # self energies for each atom type
  self_energies: Array
  # overall energy shift
  shift: Array
  low_tap_rad: Array = dataclasses.static_field()
  up_tap_rad: Array = dataclasses.static_field()
  cutoff: Array = dataclasses.static_field()
  cutoff2: Array = dataclasses.static_field()
  hb_close_cutoff: Array = dataclasses.static_field()
  hb_far_cutoff: Array = dataclasses.static_field()
  body2_params_mask: Array = dataclasses.static_field()
  body3_params_mask: Array = dataclasses.static_field()
  body4_params_mask: Array = dataclasses.static_field()
  # since 4-body interactions are created from 3-body ones, the 3-body mask
  # has to be extended with the 4-body interactions so that no 4-body
  # interaction is missed.
  body34_params_mask: Array = dataclasses.static_field()
  hb_params_mask: Array = dataclasses.static_field()
  global_params: Array
  electronegativity: Array
  idempotential: Array
  gamma: Array
  # *_off hold explicit off-diagonal values; they are used wherever the
  # matching *_off_mask entry is nonzero (see fill_off_diag).
  rvdw: Array
  p1co: Array
  p1co_off: Array
  p1co_off_mask: Array = dataclasses.static_field()
  eps: Array
  p2co: Array
  p2co_off: Array
  p2co_off_mask: Array = dataclasses.static_field()
  alf: Array
  p3co: Array
  p3co_off: Array
  p3co_off_mask: Array = dataclasses.static_field()
  vop: Array
  amas: Array
  rat: Array
  rob1: Array
  rob1_off: Array
  rob1_off_mask: Array = dataclasses.static_field()
  rapt: Array
  rob2: Array
  rob2_off: Array
  rob2_off_mask: Array = dataclasses.static_field()
  vnq: Array
  rob3: Array
  rob3_off: Array
  rob3_off_mask: Array = dataclasses.static_field()
  ptp: Array
  pdp: Array
  popi: Array
  pdo: Array
  bop1: Array
  bop2: Array
  de1: Array
  de2: Array
  de3: Array
  psp: Array
  psi: Array
  aval: Array
  vval3: Array
  bo131: Array
  bo132: Array
  bo133: Array
  ovc: Array = dataclasses.static_field()
  v13cor: Array = dataclasses.static_field()
  softcut: Array  # acks2 parameter
  softcut_2d: Array  # softcut_2d[i,j] = 0.5 * (softcut[i] + softcut[j])
  stlp: Array
  valf: Array
  vval1: Array
  vval2: Array
  # (vval3 is declared once, above; a duplicate declaration was removed here
  # without changing the field order.)
  vval4: Array
  vkac: Array
  th0: Array
  vka: Array
  vkap: Array
  vka3: Array
  vka8: Array
  vlp1: Array
  valp1: Array
  vovun: Array
  vover: Array
  v1: Array
  v2: Array
  v3: Array
  v4: Array
  vconj: Array
  nphb: Array = dataclasses.static_field()
  rhb: Array
  dehb: Array
  vhb1: Array
  vhb2: Array
  # global parameters
  # NOTE: "shiedling" is a misspelling of "shielding"; the name is kept to
  # preserve the public field name.
  vdw_shiedling: Array
  trip_stab4: Array
  trip_stab5: Array
  trip_stab8: Array
  trip_stab11: Array
  over_coord1: Array
  over_coord2: Array
  val_par3: Array
  val_par15: Array
  val_par17: Array
  val_par18: Array
  val_par20: Array
  val_par21: Array
  val_par22: Array
  val_par31: Array
  val_par34: Array
  val_par39: Array
  par_16: Array
  par_6: Array
  par_7: Array
  par_9: Array
  par_10: Array
  par_32: Array
  par_33: Array
  par_24: Array
  par_25: Array
  par_26: Array
  par_28: Array
  par_35: Array  # ACKS2

  @classmethod
  def init_from_arg_dict(cls, kwargs: dict) -> "ForceField":
    """Build an instance from a dict that may contain extra keys.

    Keys that are not ``init`` fields of the class are ignored.

    Args:
      kwargs: mapping from field name to value.

    Returns:
      A new instance of ``cls``.

    Raises:
      TypeError: if any ``init`` field is absent from ``kwargs``.
    """
    field_names = {f.name for f in fields(cls) if f.init}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
    missing = field_names - filtered_kwargs.keys()
    if missing:
      # The constructor would raise TypeError anyway; raise it here with the
      # names of the absent fields so the error is actionable.
      raise TypeError(
          f'Missing arguments for {cls.__name__}: {sorted(missing)}')
    return cls(**filtered_kwargs)
def fill_symm(force_field):
  """Copy symmetric parameter entries so that equivalent index tuples agree.

  * 2-body arrays: every strict-lower-triangle entry ``[i, j]`` (``i > j``) is
    overwritten with ``[j, i]``, so the upper triangle is authoritative.
  * 3-body / 4-body arrays: the values found at ``body{3,4}_indices_src`` are
    written to the positions ``body{3,4}_indices_dst``.

  Args:
    force_field: ForceField (or any dataclass exposing the same attributes).

  Returns:
    A new instance produced by ``dataclasses.replace`` with the listed
    parameter arrays updated.
  """
  lower = jnp.tril_indices(force_field.num_atom_types, k=-1)

  def mirror_upper_to_lower(arr):
    # Strict lower triangle <- matching entries of the transpose.
    return arr.at[lower].set(arr.transpose()[lower])

  def copy_indexed(src, dst):
    return lambda arr: arr.at[dst].set(arr[src])

  # (attribute names, function producing the symmetrized array)
  rules = (
      (
          ('p1co', 'p2co', 'p3co', 'p1co_off', 'p2co_off', 'p3co_off',
           'rob1', 'rob2', 'rob3', 'rob1_off', 'rob2_off', 'rob3_off',
           'ptp', 'pdp', 'popi', 'pdo', 'bop1', 'bop2',
           'de1', 'de2', 'de3', 'psp', 'psi', 'vover'),
          mirror_upper_to_lower,
      ),
      (
          ('vval2', 'vkac', 'th0', 'vka', 'vkap', 'vka3', 'vka8'),
          copy_indexed(force_field.body3_indices_src,
                       force_field.body3_indices_dst),
      ),
      (
          ('v1', 'v2', 'v3', 'v4', 'vconj'),
          copy_indexed(force_field.body4_indices_src,
                       force_field.body4_indices_dst),
      ),
  )

  updated = {}
  for names, fill in rules:
    for name in names:
      updated[name] = fill(getattr(force_field, name))
  return dataclasses.replace(force_field, **updated)
def fill_off_diag(force_field):
  """Build the pairwise parameter matrices from per-atom-type values.

  For every entry whose ``*_off_mask`` value is 0, a mixing rule over the
  per-type vectors is used; every other entry takes the explicit ``*_off``
  value:

  * rob1 / rob2 / rob3: ``0.5 * (v[i] + v[j])`` for v = rat / rapt / vnq,
    or 0.0 unless both ``v[i]`` and ``v[j]`` are positive.
  * p1co: ``sqrt(4 * rvdw[i] * rvdw[j])``; explicit values are ``2 * p1co_off``.
  * p2co: ``sqrt(eps[i] * eps[j])``;  p3co: ``sqrt(alf[i] * alf[j])``.
    The square roots go through ``util.safe_mask`` and are only taken where
    the product is positive (elsewhere safe_mask's placeholder is used --
    NOTE(review): 0 by default in jax_md, confirm).
  * softcut_2d: ``0.5 * (softcut[i] + softcut[j])`` with no mask.

  Args:
    force_field: ForceField holding the per-type vectors and *_off arrays.

  Returns:
    A new instance (via ``dataclasses.replace``) with rob1, rob2, rob3, p1co,
    p2co, p3co and softcut_2d replaced.
  """
  n_types = force_field.num_atom_types

  def mean_if_both_positive(values):
    # cols[i, j] = values[j], rows[i, j] = values[i]
    cols = jnp.tile(values.reshape(1, -1), (n_types, 1))
    rows = cols.transpose()
    both_positive = (cols > 0.0) & (rows > 0.0)
    return jnp.where(both_positive, (cols + rows) * 0.5, 0.0)

  def outer(values):
    # (n, 1) . (1, n): entry [i, j] is values[i] * values[j].
    return values.reshape(-1, 1).dot(values.reshape(1, -1))

  def positive_sqrt(x):
    return util.safe_mask(x > 0, jnp.sqrt, x)

  def choose(mask, derived, explicit):
    # Mask entries equal to 0 take the derived value, others the explicit one.
    return jnp.where(mask == 0, derived, explicit)

  ff = force_field
  softcut = ff.softcut
  return dataclasses.replace(
      ff,
      rob1=choose(ff.rob1_off_mask, mean_if_both_positive(ff.rat),
                  ff.rob1_off),
      rob2=choose(ff.rob2_off_mask, mean_if_both_positive(ff.rapt),
                  ff.rob2_off),
      rob3=choose(ff.rob3_off_mask, mean_if_both_positive(ff.vnq),
                  ff.rob3_off),
      p1co=choose(ff.p1co_off_mask, positive_sqrt(4.0 * outer(ff.rvdw)),
                  ff.p1co_off * 2.0),
      p2co=choose(ff.p2co_off_mask, positive_sqrt(outer(ff.eps)),
                  ff.p2co_off),
      p3co=choose(ff.p3co_off_mask, positive_sqrt(outer(ff.alf)),
                  ff.p3co_off),
      softcut_2d=0.5 * (softcut.reshape(-1, 1) + softcut.reshape(1, -1)),
  )