Source code for jax_md.mm_forcefields.reaxff.reaxff_helper

"""
Contains helper functions ReaxFF

Author: Mehmet Cagri Kaymak
"""

import jax
import jax.numpy as jnp
import numpy as onp
from jax_md.mm_forcefields.reaxff.reaxff_forcefield import ForceField
from dataclasses import fields
from jax import custom_jvp
from frozendict import frozendict


@custom_jvp
def safe_sqrt(x):
  """Safe sqrt function (no nan gradients)."""
  return jnp.sqrt(x)


[docs] @safe_sqrt.defjvp def safe_sqrt_jvp(primals, tangents): x = primals[0] x_dot = tangents[0] # print(x[0]) primal_out = safe_sqrt(x) tangent_out = 0.5 * x_dot / jnp.where(x > 0, primal_out, jnp.inf) return primal_out, tangent_out
# it fixes nan values issue, from: https://github.com/google/jax/issues/1052
[docs] def vectorized_cond(pred, true_fun, false_fun, operand): # true_fun and false_fun must act elementwise (i.e. be vectorized) true_op = jnp.where(pred, operand, 0) false_op = jnp.where(pred, 0, operand) return jnp.where(pred, true_fun(true_op), false_fun(false_op))
[docs] def init_params_for_filler_atom_type(FF_field_dict): # TODO: make sure that index -1 doesnt belong to a real atom!!! FF_field_dict['rvdw'][-1] = 1 FF_field_dict['eps'][-1] = 1 FF_field_dict['alf'][-1] = 1 FF_field_dict['vop'][-1] = 1 FF_field_dict['gamma'][-1] = 1 FF_field_dict['electronegativity'][-1] = 1 FF_field_dict['idempotential'][-1] = 1 FF_field_dict['bo131'][-1] = 1 FF_field_dict['bo132'][-1] = 1 FF_field_dict['bo133'][-1] = 1
[docs] def read_force_field( force_field_file, cutoff2=1e-3, hbond_close_cutoff=0.01, hbond_far_cutoff=7.5, dtype=jnp.float32, ): # to store all arguments together before creating the class FF_field_dict = {f.name: None for f in fields(ForceField) if f.init} FF_param_to_index = {} f = open(force_field_file, 'r') header = f.readline().strip() num_params = int(f.readline().strip().split()[0]) global_params = onp.zeros(shape=(num_params, 1), dtype=dtype) name_to_index = dict() body_3_indices_src = [[], [], []] body_3_indices_dst = [[], [], []] body_4_indices_src = [[], [], [], []] body_4_indices_dst = [[], [], [], []] for i in range(num_params): line = f.readline().strip() # to seperate the comment line = line.replace('!', ' ! ') global_params[i] = float(line.split()[0]) FF_field_dict['low_tap_rad'] = global_params[11] FF_field_dict['up_tap_rad'] = global_params[12] FF_field_dict['vdw_shiedling'] = global_params[28] FF_field_dict['cutoff'] = global_params[29] * 0.01 FF_field_dict['cutoff2'] = cutoff2 FF_field_dict['hbond_close_cutff'] = hbond_close_cutoff FF_field_dict['hbond_far_cutoff'] = hbond_far_cutoff FF_field_dict['over_coord1'] = global_params[0] FF_field_dict['over_coord2'] = global_params[1] FF_field_dict['trip_stab4'] = global_params[3] FF_field_dict['trip_stab5'] = global_params[4] FF_field_dict['trip_stab8'] = global_params[7] FF_field_dict['trip_stab11'] = global_params[10] # FF_param_to_index[(1,12,1)] = ("low_tap_rad", (0,)) # FF_param_to_index[(1,13,1)] = ("up_tap_rad", (0,)) FF_param_to_index[(1, 29, 1)] = ('vdw_shiedling', (0,)) # FF_param_to_index[(1,30,1)] = ("cutoff", (0,)) FF_param_to_index[(1, 1, 1)] = ('over_coord1', (0,)) FF_param_to_index[(1, 2, 1)] = ('over_coord2', (0,)) FF_param_to_index[(1, 4, 1)] = ('trip_stab4', (0,)) FF_param_to_index[(1, 5, 1)] = ('trip_stab5', (0,)) FF_param_to_index[(1, 8, 1)] = ('trip_stab8', (0,)) FF_param_to_index[(1, 11, 1)] = ('trip_stab11', (0,)) FF_field_dict['val_par3'] = global_params[2] FF_field_dict['val_par15'] = global_params[14] FF_field_dict['par_16'] = global_params[15] FF_field_dict['val_par17'] = global_params[16] FF_field_dict['val_par18'] = global_params[17] FF_field_dict['val_par20'] = global_params[19] FF_field_dict['val_par21'] = global_params[20] FF_field_dict['val_par22'] = global_params[21] FF_field_dict['val_par31'] = global_params[30] FF_field_dict['val_par34'] = global_params[33] FF_field_dict['val_par39'] = global_params[38] FF_param_to_index[(1, 3, 1)] = ('val_par3', (0,)) FF_param_to_index[(1, 15, 1)] = ('val_par15', (0,)) FF_param_to_index[(1, 16, 1)] = ('par_16', (0,)) FF_param_to_index[(1, 17, 1)] = ('val_par17', (0,)) FF_param_to_index[(1, 18, 1)] = ('val_par18', (0,)) FF_param_to_index[(1, 20, 1)] = ('val_par20', (0,)) FF_param_to_index[(1, 21, 1)] = ('val_par21', (0,)) FF_param_to_index[(1, 22, 1)] = ('val_par22', (0,)) FF_param_to_index[(1, 31, 1)] = ('val_par31', (0,)) FF_param_to_index[(1, 34, 1)] = ('val_par34', (0,)) FF_param_to_index[(1, 39, 1)] = ('val_par39', (0,)) # over under FF_field_dict['par_6'] = global_params[5] FF_field_dict['par_7'] = global_params[6] FF_field_dict['par_9'] = global_params[8] FF_field_dict['par_10'] = global_params[9] FF_field_dict['par_32'] = global_params[31] FF_field_dict['par_33'] = global_params[32] FF_param_to_index[(1, 6, 1)] = ('par_6', (0,)) FF_param_to_index[(1, 7, 1)] = ('par_7', (0,)) FF_param_to_index[(1, 9, 1)] = ('par_9', (0,)) FF_param_to_index[(1, 10, 1)] = ('par_10', (0,)) FF_param_to_index[(1, 32, 1)] = ('par_32', (0,)) FF_param_to_index[(1, 33, 1)] = ('par_33', (0,)) # torsion par_24,par_25, par_26,par_28 FF_field_dict['par_24'] = global_params[23] FF_field_dict['par_25'] = global_params[24] FF_field_dict['par_26'] = global_params[25] FF_field_dict['par_28'] = global_params[27] FF_param_to_index[(1, 24, 1)] = ('par_24', (0,)) FF_param_to_index[(1, 25, 1)] = ('par_25', (0,)) FF_param_to_index[(1, 26, 1)] = ('par_26', (0,)) FF_param_to_index[(1, 28, 1)] = ('par_28', (0,)) # ACKS2 FF_field_dict['par_35'] = global_params[34] FF_param_to_index[(1, 35, 1)] = ('par_35', (0,)) real_num_atom_types = int(f.readline().strip().split()[0]) num_atom_types = real_num_atom_types + 1 # 1 extra to store dummy atoms # self energies of the atoms FF_field_dict['self_energies'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['shift'] = onp.zeros(1, dtype=dtype) FF_field_dict['num_atom_types'] = num_atom_types # skip 3 lines of comment f.readline() f.readline() f.readline() atom_names = [] line_ctr = 0 # line 1 FF_field_dict['rat'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['aval'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['amas'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['rvdw'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['eps'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['gamma'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['rapt'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['stlp'] = onp.zeros(num_atom_types, dtype=dtype) # line 2 FF_field_dict['alf'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['vop'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['valf'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['valp1'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['electronegativity'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['idempotential'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['nphb'] = onp.zeros(num_atom_types, dtype=jnp.int32) # line 3 FF_field_dict['vnq'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['vlp1'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['bo131'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['bo132'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['bo133'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['softcut'] = onp.zeros(num_atom_types, dtype=dtype) # line 4 FF_field_dict['vovun'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['vval1'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['vval3'] = onp.zeros(num_atom_types, dtype=dtype) FF_field_dict['vval4'] = onp.zeros(num_atom_types, dtype=dtype) for i in range(real_num_atom_types): # first line line = f.readline().strip() split_line = line.split() atom_names.append(str(split_line[0])) name_to_index[atom_names[i]] = i FF_field_dict['rat'][i] = float(split_line[1]) FF_field_dict['aval'][i] = float(split_line[2]) FF_field_dict['amas'][i] = float(split_line[3]) FF_field_dict['rvdw'][i] = float(split_line[4]) # vdw FF_field_dict['eps'][i] = float(split_line[5]) # vdw FF_field_dict['gamma'][i] = float(split_line[6]) # coulomb FF_field_dict['rapt'][i] = float(split_line[7]) FF_field_dict['stlp'][i] = float(split_line[8]) # valency FF_param_to_index[(2, i + 1, 1)] = ('rat', (i,)) FF_param_to_index[(2, i + 1, 2)] = ('aval', (i,)) # FF_param_to_index[(2,i+1,3)] = ("amas", (i,)) FF_param_to_index[(2, i + 1, 4)] = ('rvdw', (i,)) FF_param_to_index[(2, i + 1, 5)] = ('eps', (i,)) FF_param_to_index[(2, i + 1, 6)] = ('gamma', (i,)) FF_param_to_index[(2, i + 1, 7)] = ('rapt', (i,)) FF_param_to_index[(2, i + 1, 8)] = ('stlp', (i,)) # second line line = f.readline().strip() split_line = line.split() FF_field_dict['alf'][i] = float(split_line[0]) # vdw FF_field_dict['vop'][i] = float(split_line[1]) # vdw FF_field_dict['valf'][i] = float(split_line[2]) # valency FF_field_dict['valp1'][i] = float(split_line[3]) # over-under coord FF_field_dict['electronegativity'][i] = float(split_line[5]) # coulomb # eta will be mult. by 2 FF_field_dict['idempotential'][i] = float(split_line[6]) # needed for hbond #needed to find acceptor-donor FF_field_dict['nphb'][i] = int(float(split_line[7])) FF_param_to_index[(2, i + 1, 9)] = ('alf', (i,)) FF_param_to_index[(2, i + 1, 10)] = ('vop', (i,)) FF_param_to_index[(2, i + 1, 11)] = ('valf', (i,)) FF_param_to_index[(2, i + 1, 12)] = ('valp1', (i,)) FF_param_to_index[(2, i + 1, 14)] = ('electronegativity', (i,)) FF_param_to_index[(2, i + 1, 15)] = ('idempotential', (i,)) # third line line = f.readline().strip() split_line = line.split() FF_field_dict['vnq'][i] = float(split_line[0]) FF_field_dict['vlp1'][i] = float(split_line[1]) FF_field_dict['bo131'][i] = float(split_line[3]) FF_field_dict['bo132'][i] = float(split_line[4]) FF_field_dict['bo133'][i] = float(split_line[5]) FF_field_dict['softcut'][i] = float(split_line[6]) # ACKS2 FF_param_to_index[(2, i + 1, 17)] = ('vnq', (i,)) FF_param_to_index[(2, i + 1, 18)] = ('vlp1', (i,)) FF_param_to_index[(2, i + 1, 20)] = ('bo131', (i,)) FF_param_to_index[(2, i + 1, 21)] = ('bo132', (i,)) FF_param_to_index[(2, i + 1, 22)] = ('bo133', (i,)) FF_param_to_index[(2, i + 1, 23)] = ('softcut', (i,)) # fourth line line = f.readline().strip() split_line = line.split() FF_field_dict['vovun'][i] = float(split_line[0]) # over-under coord FF_field_dict['vval1'][i] = float(split_line[1]) FF_field_dict['vval3'][i] = float(split_line[3]) FF_field_dict['vval4'][i] = float(split_line[4]) FF_param_to_index[(2, i + 1, 25)] = ('vovun', (i,)) FF_param_to_index[(2, i + 1, 26)] = ('vval1', (i,)) FF_param_to_index[(2, i + 1, 28)] = ('vval3', (i,)) FF_param_to_index[(2, i + 1, 29)] = ('vval4', (i,)) # This part is moved to the related part in energy calculation # if FF_field_dict['amas'][i] < 21.0: # FF_field_dict['vval3'][i] = FF_field_dict['valf'][i] FF_field_dict['name_to_index'] = name_to_index FF_field_dict['body2_params_mask'] = onp.zeros( (num_atom_types, num_atom_types), dtype=jnp.bool_ ) # line 1 FF_field_dict['de1'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['de2'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['de3'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['psi'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['pdo'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['v13cor'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['popi'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vover'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) # line 2 FF_field_dict['psp'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['pdp'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['ptp'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['bop1'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['bop2'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['ovc'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) line = f.readline().strip() num_bonds = int(line.split()[0]) f.readline() # skip next line (comment) for b in range(num_bonds): # first line line = f.readline().strip() split_line = line.split() i = int(split_line[0]) - 1 # index starts at 0 j = int(split_line[1]) - 1 FF_field_dict['body2_params_mask'][i, j] = 1 FF_field_dict['body2_params_mask'][j, i] = 1 FF_field_dict['de1'][i, j] = float(split_line[2]) FF_field_dict['de2'][i, j] = float(split_line[3]) FF_field_dict['de3'][i, j] = float(split_line[4]) FF_field_dict['psi'][i, j] = float(split_line[5]) FF_field_dict['pdo'][i, j] = float(split_line[6]) FF_field_dict['v13cor'][i, j] = float(split_line[7]) FF_field_dict['popi'][i, j] = float(split_line[8]) FF_field_dict['vover'][i, j] = float(split_line[9]) FF_param_to_index[(3, b + 1, 1)] = ('de1', (i, j)) FF_param_to_index[(3, b + 1, 2)] = ('de2', (i, j)) FF_param_to_index[(3, b + 1, 3)] = ('de3', (i, j)) FF_param_to_index[(3, b + 1, 4)] = ('psi', (i, j)) FF_param_to_index[(3, b + 1, 5)] = ('pdo', (i, j)) # FF_param_to_index[(3,b+1,6)] = ("v13cor", (i,j)) FF_param_to_index[(3, b + 1, 7)] = ('popi', (i, j)) FF_param_to_index[(3, b + 1, 8)] = ('vover', (i, j)) # v13cor is static, so content needs to be finaized here # hence symm. part FF_field_dict['v13cor'][j, i] = FF_field_dict['v13cor'][i, j] # second line line = f.readline().strip() split_line = line.split() FF_field_dict['psp'][i, j] = float(split_line[0]) FF_field_dict['pdp'][i, j] = float(split_line[1]) FF_field_dict['ptp'][i, j] = float(split_line[2]) FF_field_dict['bop1'][i, j] = float(split_line[4]) FF_field_dict['bop2'][i, j] = float(split_line[5]) FF_field_dict['ovc'][i, j] = float(split_line[6]) # v13cor is static, so content needs to be finaized here FF_field_dict['ovc'][j, i] = FF_field_dict['ovc'][i, j] FF_param_to_index[(3, b + 1, 9)] = ('psp', (i, j)) FF_param_to_index[(3, b + 1, 10)] = ('pdp', (i, j)) FF_param_to_index[(3, b + 1, 11)] = ('ptp', (i, j)) FF_param_to_index[(3, b + 1, 13)] = ('bop1', (i, j)) FF_param_to_index[(3, b + 1, 14)] = ('bop2', (i, j)) # FF_param_to_index[(3,b+1,8)] = ("ovc", (i,j)) line = f.readline().strip() num_off_diag = int(line.split()[0]) FF_field_dict['rob1_off'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['rob1_off_mask'] = onp.zeros( (num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['rob2_off'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['rob2_off_mask'] = onp.zeros( (num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['rob3_off'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['rob3_off_mask'] = onp.zeros( (num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['p1co_off'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['p1co_off_mask'] = onp.zeros( (num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['p2co_off'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['p2co_off_mask'] = onp.zeros( (num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['p3co_off'] = onp.zeros( (num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['p3co_off_mask'] = onp.zeros( (num_atom_types, num_atom_types), dtype=jnp.bool_ ) for i in range(num_off_diag): line = f.readline().strip() split_line = line.split() nodm1 = int(split_line[0]) nodm2 = int(split_line[1]) deodmh = float(split_line[2]) rodmh = float(split_line[3]) godmh = float(split_line[4]) rsig = float(split_line[5]) rpi = float(split_line[6]) rpi2 = float(split_line[7]) # TODO: handle the mapping of the "params" later nodm1 = nodm1 - 1 # index starts from 0 nodm2 = nodm2 - 1 # index starts from 0 FF_field_dict['rob1_off'][nodm1, nodm2] = rsig FF_field_dict['rob2_off'][nodm1, nodm2] = rpi FF_field_dict['rob3_off'][nodm1, nodm2] = rpi2 FF_field_dict['p1co_off'][nodm1, nodm2] = rodmh FF_field_dict['p2co_off'][nodm1, nodm2] = deodmh FF_field_dict['p3co_off'][nodm1, nodm2] = godmh if ( rsig > 0 and FF_field_dict['rat'][nodm1] > 0 and FF_field_dict['rat'][nodm2] > 0 ): FF_field_dict['rob1_off_mask'][nodm1, nodm2] = 1 FF_param_to_index[(4, i + 1, 4)] = ('rob1_off', (nodm1, nodm2)) if ( rpi > 0 and FF_field_dict['rapt'][nodm1] > 0 and FF_field_dict['rapt'][nodm2] > 0 ): FF_field_dict['rob2_off_mask'][nodm1, nodm2] = 1 FF_param_to_index[(4, i + 1, 5)] = ('rob2_off', (nodm1, nodm2)) if ( rpi2 > 0 and FF_field_dict['vnq'][nodm1] > 0 and FF_field_dict['vnq'][nodm2] > 0 ): FF_field_dict['rob3_off_mask'][nodm1, nodm2] = 1 FF_param_to_index[(4, i + 1, 6)] = ('rob3_off', (nodm1, nodm2)) if rodmh > 0: FF_field_dict['p1co_off_mask'][nodm1, nodm2] = 1 FF_param_to_index[(4, i + 1, 2)] = ('p1co_off', (nodm1, nodm2)) if deodmh > 0: FF_field_dict['p2co_off_mask'][nodm1, nodm2] = 1 FF_param_to_index[(4, i + 1, 1)] = ('p2co_off', (nodm1, nodm2)) if godmh > 0: FF_field_dict['p3co_off_mask'][nodm1, nodm2] = 1 FF_param_to_index[(4, i + 1, 3)] = ('p3co_off', (nodm1, nodm2)) # valency angle parameters line = f.readline().strip() num_val_params = int(line.split()[0]) FF_field_dict['body3_params_mask'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['th0'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vka'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vka3'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vka8'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vkac'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vkap'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vval2'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) for val in range(num_val_params): line = f.readline().strip() split_line = line.split() ind1 = int(split_line[0]) ind2 = int(split_line[1]) ind3 = int(split_line[2]) th0 = float(split_line[3]) vka = float(split_line[4]) vka3 = float(split_line[5]) vka8 = float(split_line[6]) vkac = float(split_line[7]) vkap = float(split_line[8]) vval2 = float(split_line[9]) ind1 = ind1 - 1 # index starts from 0 ind2 = ind2 - 1 # index starts from 0 ind3 = ind3 - 1 # index starts from 0 FF_field_dict['th0'][ind1, ind2, ind3] = th0 FF_field_dict['vka'][ind1, ind2, ind3] = vka FF_field_dict['vka3'][ind1, ind2, ind3] = vka3 FF_field_dict['vka8'][ind1, ind2, ind3] = vka8 FF_field_dict['vkac'][ind1, ind2, ind3] = vkac FF_field_dict['vkap'][ind1, ind2, ind3] = vkap FF_field_dict['vval2'][ind1, ind2, ind3] = vval2 FF_param_to_index[(5, val + 1, 1)] = ('th0', (ind1, ind2, ind3)) FF_param_to_index[(5, val + 1, 2)] = ('vka', (ind1, ind2, ind3)) FF_param_to_index[(5, val + 1, 3)] = ('vka3', (ind1, ind2, ind3)) FF_param_to_index[(5, val + 1, 4)] = ('vka8', (ind1, ind2, ind3)) FF_param_to_index[(5, val + 1, 5)] = ('vkac', (ind1, ind2, ind3)) FF_param_to_index[(5, val + 1, 6)] = ('vkap', (ind1, ind2, ind3)) FF_param_to_index[(5, val + 1, 7)] = ('vval2', (ind1, ind2, ind3)) body_3_indices_dst[0].append(ind3) body_3_indices_dst[1].append(ind2) body_3_indices_dst[2].append(ind1) body_3_indices_src[0].append(ind1) body_3_indices_src[1].append(ind2) body_3_indices_src[2].append(ind3) if abs(vka) > 0.001: FF_field_dict['body3_params_mask'][ind1, ind2, ind3] = 1.0 FF_field_dict['body3_params_mask'][ind3, ind2, ind1] = 1.0 # torsion parameters line = f.readline().strip() num_tors_params = int(line.split()[0]) FF_field_dict['body34_params_mask'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['body4_params_mask'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types, num_atom_types), dtype=jnp.bool_, ) FF_field_dict['v1'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types, num_atom_types), dtype=dtype, ) FF_field_dict['v2'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types, num_atom_types), dtype=dtype, ) FF_field_dict['v3'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types, num_atom_types), dtype=dtype, ) FF_field_dict['v4'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types, num_atom_types), dtype=dtype, ) FF_field_dict['vconj'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types, num_atom_types), dtype=dtype, ) torsion_param_sets = set() lines_with_negative_vals = [] for tors in range(num_tors_params): line = f.readline().strip() split_line = line.split() ind1 = int(split_line[0]) ind2 = int(split_line[1]) ind3 = int(split_line[2]) ind4 = int(split_line[3]) v1 = float(split_line[4]) v2 = float(split_line[5]) v3 = float(split_line[6]) v4 = float(split_line[7]) vconj = float(split_line[8]) # v2bo = float(split_line[9]) # v3bo = float(split_line[10]) ind1 = ind1 - 1 # index starts from 0 ind2 = ind2 - 1 # index starts from 0 ind3 = ind3 - 1 # index starts from 0 ind4 = ind4 - 1 # index starts from 0 # if all parameters are 0, skip if v1 == 0.0 and v2 == 0.0 and v3 == 0.0 and v4 == 0.0 and vconj == 0.0: continue # TODO: handle 0 indices in the param. file later if ind1 > -1 and ind4 > -1: if (ind1, ind2, ind3, ind4) in torsion_param_sets: print( f'[WARNING] 4-body parameters for ({ind1 + 1},{ind2 + 1},{ind3 + 1},{ind4 + 1}) appeared twice!' ) print('Might cause numerical inaccuracies!') print('Skipping the dublicate occurance...') continue FF_field_dict['v1'][ind1, ind2, ind3, ind4] = v1 FF_field_dict['v2'][ind1, ind2, ind3, ind4] = v2 FF_field_dict['v3'][ind1, ind2, ind3, ind4] = v3 FF_field_dict['v4'][ind1, ind2, ind3, ind4] = v4 FF_field_dict['vconj'][ind1, ind2, ind3, ind4] = vconj FF_param_to_index[(6, tors + 1, 1)] = ('v1', (ind1, ind2, ind3, ind4)) FF_param_to_index[(6, tors + 1, 2)] = ('v2', (ind1, ind2, ind3, ind4)) FF_param_to_index[(6, tors + 1, 3)] = ('v3', (ind1, ind2, ind3, ind4)) FF_param_to_index[(6, tors + 1, 4)] = ('v4', (ind1, ind2, ind3, ind4)) FF_param_to_index[(6, tors + 1, 5)] = ('vconj', (ind1, ind2, ind3, ind4)) FF_field_dict['body4_params_mask'][ind1, ind2, ind3, ind4] = 1 FF_field_dict['body4_params_mask'][ind4, ind3, ind2, ind1] = 1 FF_field_dict['body34_params_mask'][ind1, ind2, ind3] = 1 FF_field_dict['body34_params_mask'][ind3, ind2, ind1] = 1 FF_field_dict['body34_params_mask'][ind2, ind3, ind4] = 1 FF_field_dict['body34_params_mask'][ind4, ind3, ind2] = 1 body_4_indices_dst[0].append(ind4) body_4_indices_dst[1].append(ind3) body_4_indices_dst[2].append(ind2) body_4_indices_dst[3].append(ind1) body_4_indices_src[0].append(ind1) body_4_indices_src[1].append(ind2) body_4_indices_src[2].append(ind3) body_4_indices_src[3].append(ind4) torsion_param_sets.add((ind1, ind2, ind3, ind4)) torsion_param_sets.add((ind4, ind3, ind2, ind1)) elif ind1 == -1 and ind4 == -1: lines_with_negative_vals.append( [tors, ind1, ind2, ind3, ind4, v1, v2, v3, v4, vconj] ) else: print(f'Invalid torsion parameter section, line:{tors + 1}') return None # the lines with negative values affect mutliple types, so they need to # be processed at the end # if line dedicated to [ind1, ind2, ind3, ind4] is not available # then [-1, line2, line3, -1] will be used instead for vals in lines_with_negative_vals: tors, ind1, ind2, ind3, ind4, v1, v2, v3, v4, vconj = vals # Last index is reserved for this part sel_ind = FF_field_dict['num_atom_types'] - 1 FF_field_dict['v1'][sel_ind, ind2, ind3, sel_ind] = v1 FF_field_dict['v2'][sel_ind, ind2, ind3, sel_ind] = v2 FF_field_dict['v3'][sel_ind, ind2, ind3, sel_ind] = v3 FF_field_dict['v4'][sel_ind, ind2, ind3, sel_ind] = v4 FF_field_dict['vconj'][sel_ind, ind2, ind3, sel_ind] = vconj FF_param_to_index[(6, tors + 1, 1)] = ('v1', (sel_ind, ind2, ind3, sel_ind)) FF_param_to_index[(6, tors + 1, 2)] = ('v2', (sel_ind, ind2, ind3, sel_ind)) FF_param_to_index[(6, tors + 1, 3)] = ('v3', (sel_ind, ind2, ind3, sel_ind)) FF_param_to_index[(6, tors + 1, 4)] = ('v4', (sel_ind, ind2, ind3, sel_ind)) FF_param_to_index[(6, tors + 1, 5)] = ( 'vconj', (sel_ind, ind2, ind3, sel_ind), ) for i in range(real_num_atom_types): for j in range(real_num_atom_types): if FF_field_dict['body4_params_mask'][i, ind2, ind3, j] == 0: body_4_indices_src[0].append(sel_ind) body_4_indices_src[1].append(ind2) body_4_indices_src[2].append(ind3) body_4_indices_src[3].append(sel_ind) body_4_indices_dst[0].append(i) body_4_indices_dst[1].append(ind2) body_4_indices_dst[2].append(ind3) body_4_indices_dst[3].append(j) FF_field_dict['body4_params_mask'][i, ind2, ind3, j] = 1 FF_field_dict['body34_params_mask'][i, ind2, ind3] = 1 FF_field_dict['body34_params_mask'][ind3, ind2, i] = 1 FF_field_dict['body34_params_mask'][ind2, ind3, j] = 1 FF_field_dict['body34_params_mask'][j, ind3, ind2] = 1 if FF_field_dict['body4_params_mask'][j, ind3, ind2, i] == 0: body_4_indices_src[0].append(sel_ind) body_4_indices_src[1].append(ind2) body_4_indices_src[2].append(ind3) body_4_indices_src[3].append(sel_ind) body_4_indices_dst[0].append(j) body_4_indices_dst[1].append(ind3) body_4_indices_dst[2].append(ind2) body_4_indices_dst[3].append(i) FF_field_dict['body4_params_mask'][j, ind3, ind2, i] = 1 FF_field_dict['body34_params_mask'][i, ind2, ind3] = 1 FF_field_dict['body34_params_mask'][ind3, ind2, i] = 1 FF_field_dict['body34_params_mask'][ind2, ind3, j] = 1 FF_field_dict['body34_params_mask'][j, ind3, ind2] = 1 # hbond parameters line = f.readline().strip() num_hbond_params = int(line.split()[0]) FF_field_dict['hb_params_mask'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=jnp.bool_ ) FF_field_dict['rhb'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['dehb'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vhb1'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) FF_field_dict['vhb2'] = onp.zeros( (num_atom_types, num_atom_types, num_atom_types), dtype=dtype ) for i in range(num_hbond_params): line = f.readline().strip() split_line = line.split() ind1 = int(split_line[0]) - 1 ind2 = int(split_line[1]) - 1 ind3 = int(split_line[2]) - 1 rhb = float(split_line[3]) dehb = float(split_line[4]) vhb1 = float(split_line[5]) vhb2 = float(split_line[6]) FF_field_dict['rhb'][ind1, ind2, ind3] = rhb FF_field_dict['dehb'][ind1, ind2, ind3] = dehb FF_field_dict['vhb1'][ind1, ind2, ind3] = vhb1 FF_field_dict['vhb2'][ind1, ind2, ind3] = vhb2 FF_field_dict['hb_params_mask'][ind1, ind2, ind3] = 1 FF_param_to_index[(7, i + 1, 1)] = ('rhb', (ind1, ind2, ind3)) FF_param_to_index[(7, i + 1, 2)] = ('dehb', (ind1, ind2, ind3)) FF_param_to_index[(7, i + 1, 3)] = ('vhb1', (ind1, ind2, ind3)) FF_param_to_index[(7, i + 1, 4)] = ('vhb2', (ind1, ind2, ind3)) f.close() for i in range(3): body_3_indices_src[i] = onp.array(body_3_indices_src[i], dtype=onp.int32) body_3_indices_dst[i] = onp.array(body_3_indices_dst[i], dtype=onp.int32) for i in range(4): body_4_indices_src[i] = onp.array(body_4_indices_src[i], dtype=onp.int32) body_4_indices_dst[i] = onp.array(body_4_indices_dst[i], dtype=onp.int32) FF_field_dict['body3_indices_src'] = tuple(body_3_indices_src) FF_field_dict['body3_indices_dst'] = tuple(body_3_indices_dst) FF_field_dict['body4_indices_src'] = tuple(body_4_indices_src) FF_field_dict['body4_indices_dst'] = tuple(body_4_indices_dst) # TODO: this function call is not needed after the energy function is refactored init_params_for_filler_atom_type(FF_field_dict) # placeholders for params to be filled later FF_field_dict['rob1'] = onp.zeros_like(FF_field_dict['rob1_off']) FF_field_dict['rob2'] = onp.zeros_like(FF_field_dict['rob1_off']) FF_field_dict['rob3'] = onp.zeros_like(FF_field_dict['rob1_off']) FF_field_dict['p1co'] = onp.zeros_like(FF_field_dict['p1co_off']) FF_field_dict['p2co'] = onp.zeros_like(FF_field_dict['p1co_off']) FF_field_dict['p3co'] = onp.zeros_like(FF_field_dict['p1co_off']) FF_field_dict['softcut_2d'] = onp.zeros_like(FF_field_dict['p1co_off']) FF_field_dict['params_to_indices'] = frozendict(FF_param_to_index) FF_fields = ForceField.__dataclass_fields__ for k in FF_field_dict: is_static = k in FF_fields and FF_fields[k].metadata.get('static', False) if type(FF_field_dict[k]) == onp.ndarray: FF_field_dict[k] = jnp.array(FF_field_dict[k]) elif type(FF_field_dict[k]) == float: FF_field_dict[k] = jnp.array(FF_field_dict[k], dtype=dtype) force_field = ForceField.init_from_arg_dict(FF_field_dict) return force_field