Source code for jax_md.mm_forcefields.reaxff.reaxff_energy

"""
Contains energy related functions for ReaxFF

Author: Mehmet Cagri Kaymak
"""

from __future__ import annotations
import numpy as onp
import jax.numpy as jnp
import jax
from jax.scipy.sparse import linalg
from jax_md import util
from jax_md.util import safe_mask
from jax_md.util import high_precision_sum
from jax_md.mm_forcefields.reaxff.reaxff_helper import (
  vectorized_cond,
  safe_sqrt,
)
from jax_md.mm_forcefields.reaxff.reaxff_forcefield import ForceField

# to resolve circular dependency
from typing import TYPE_CHECKING

if TYPE_CHECKING:
  from jax_md.mm_forcefields.reaxff.reaxff_interactions import (
    ReaxFFNeighborLists,
  )
from jax import custom_jvp

# Types
f32 = util.f32
f64 = util.f64
Array = util.Array

# Unit-conversion constants used by the energy terms in this module
c1c = 332.0638  # Coulomb energy conversion
rdndgr = 180.0 / onp.pi  # factor converting radians to degrees
dgrrdn = 1.0 / rdndgr  # factor converting degrees to radians


def calculate_reaxff_energy(
  species: Array,
  atomic_numbers: Array,
  nbr_lists: 'ReaxFFNeighborLists',
  close_nbr_dists: Array,
  far_nbr_dists: Array,
  body_3_angles: Array,
  body_4_angles: Array,
  hb_ang_dist: Array,
  force_field: ForceField,
  init_charges: Array = None,
  total_charge: float = 0.0,
  tol: float = 1e-06,
  max_solver_iter: int = 500,
  backprop_solve: bool = False,
  tors_2013: bool = False,
  tapered_reaxff: bool = False,
  solver_model: str = 'EEM',
):
  """Calculate full ReaxFF potential.

  Args:
    species: An ndarray of shape `[n, ]` for the atom types. Negative entries
      are treated as padding (masked out).
    atomic_numbers: An ndarray of shape `[n, ]` for the atomic numbers
      of the atoms.
    nbr_lists: Contains the interaction lists for ReaxFF potential.
    close_nbr_dists: An ndarray of shape `[n,m]` for bonded interaction
      distances
    far_nbr_dists: An ndarray of shape `[n,m]` for non-bonded interaction
      distances (vdw and Coulomb)
    body_3_angles, body_4_angles, hb_ang_dist: Angles and distances for
      many-body interactions
    force_field: ReaxFF parameters
    init_charges: Initial charges for the iterative solver
      An ndarray of shape `[n, ]` or None
    total_charge: Total charge of the system (float)
    tol: Tolerance for the charge solver
    max_solver_iter: Maximum number of solver iterations
      If set to -1, use direct solve
    backprop_solve: Control variable to decide whether to do a solve to
      calculate the gradients of the charges wrt positions. By definition,
      the gradients should be 0 but if the solver tolerance is high,
      the gradients might be non-ignorable.
    tors_2013: Control variable to decide whether to use more stable
      version of the torsion interactions
    tapered_reaxff: Control variable to decide whether to use tapered
      cutoffs for various bonded interactions, causes computational overhead
    solver_model: Control variable for the solver model
      ("EEM" or "ACKS")

  Returns:
    A tuple `(energy, charges)`. `energy` is the system energy in kcal/mol.
    `charges` is the raw solver output: for "EEM" an array of shape `[n+1, ]`
    (atomic charges followed by the electronegativity equalization value),
    otherwise the `[n, ]` atomic charges returned by the ACKS2 solver.
  """
  N = len(species)
  atom_mask = species >= 0
  self_energy = (
    jnp.sum(force_field.self_energies[species] * atom_mask)
    + force_field.shift[0]
  )

  far_nbr_inds = nbr_lists.far_nbrs.idx
  far_neigh_types = species[far_nbr_inds]
  atom_inds = jnp.arange(N).reshape(-1, 1)
  # Keep only the filtered close neighbors; filtered-out slots point to the
  # padding index N.
  close_nbr_inds = nbr_lists.close_nbrs.idx[atom_inds, nbr_lists.filter2.idx]
  close_nbr_inds = jnp.where(nbr_lists.filter2.idx != -1, close_nbr_inds, N)
  body_3_inds = nbr_lists.filter3.idx
  body_4_inds = nbr_lists.filter4.idx
  # Identity check: `!= None` may dispatch to an overloaded comparison.
  if nbr_lists.filter_hb is not None:
    hb_inds = nbr_lists.filter_hb.idx
  else:
    hb_inds = None

  # shared across charge calc, coulomb, and vdw
  far_nbr_mask = (far_nbr_inds != N) & (
    atom_mask.reshape(-1, 1) & atom_mask[far_nbr_inds]
  )
  far_nbr_dists = far_nbr_dists * far_nbr_mask
  tapered_dists = taper(far_nbr_dists, 0.0, 10.0)
  tapered_dists = jnp.where(
    (far_nbr_dists > 10.0) | (far_nbr_dists < 0.001), 0.0, tapered_dists
  )

  # shared across charge calc and coulomb: shielded pair distances
  gamma = jnp.power(force_field.gamma.reshape(-1, 1), 3 / 2)
  gamma_mat = gamma * gamma.transpose()
  gamma_mat = gamma_mat[far_neigh_types, species.reshape(-1, 1)]
  hulp1_mat = far_nbr_dists**3 + (1 / gamma_mat)
  hulp2_mat = jnp.power(hulp1_mat, 1.0 / 3.0) * far_nbr_mask

  # Extra ACKS2 contributions stay zero for the EEM model.
  coulomb_acks2 = 0
  charge_pot_acks2 = 0
  if solver_model == 'EEM':
    charges = calculate_eem_charges(
      species,
      atom_mask,
      far_nbr_inds,
      hulp2_mat,
      tapered_dists,
      force_field.idempotential,
      force_field.electronegativity,
      init_charges,
      total_charge,
      backprop_solve,
      tol,
      max_solver_iter,
    )
    # The last entry is the equalization value, not an atomic charge.
    atom_charges = charges[:-1]
  else:
    xcut = force_field.softcut_2d[far_neigh_types, species.reshape(-1, 1)]
    d = far_nbr_dists / xcut
    bond_softness = force_field.par_35 * (d**3) * ((1 - d) ** 6)
    bond_softness = safe_mask(
      far_nbr_dists < xcut, lambda x: x, bond_softness, 0.0
    )
    self_mask = jnp.arange(N).reshape(-1, 1) == far_nbr_inds
    bond_softness = jnp.where(self_mask == 1, 0, bond_softness)
    charges, effpot = calculate_acks2_charges(
      species,
      atom_mask,
      far_nbr_inds,
      hulp2_mat,
      bond_softness,
      tapered_dists,
      force_field.idempotential,
      force_field.electronegativity,
      total_charge,
      backprop_solve,
      tol,
    )
    # The ACKS2 solver already returns exactly one charge per atom, so no
    # trailing entry must be stripped here.
    atom_charges = charges
    charge_pot_acks2 = calculate_charge_energy_acks2(atom_charges, effpot)
    coulomb_acks2 = calculate_acks2_coulomb_pot(
      far_nbr_inds, atom_mask, effpot, bond_softness
    )

  cou_pot = calculate_coulomb_pot(
    far_nbr_inds, atom_mask, hulp2_mat, tapered_dists, atom_charges
  )
  cou_pot += coulomb_acks2

  charge_pot = calculate_charge_energy(
    species,
    atom_charges,
    force_field.idempotential,
    force_field.electronegativity,
  )
  charge_pot += charge_pot_acks2

  vdw_pot = calculate_vdw_pot(
    species,
    far_nbr_mask,
    far_nbr_inds,
    far_nbr_dists,
    tapered_dists,
    force_field,
  )

  atomic_num1 = atomic_numbers.reshape(-1, 1)
  atomic_num2 = atomic_numbers[close_nbr_inds]
  # O: 8, C:6
  triple_bond1 = jnp.logical_and(atomic_num1 == 8, atomic_num2 == 6)
  triple_bond2 = jnp.logical_and(atomic_num1 == 6, atomic_num2 == 8)
  triple_bond = jnp.logical_or(triple_bond1, triple_bond2)
  covbon_mask = (close_nbr_inds != N) & (
    atom_mask.reshape(-1, 1) & atom_mask[close_nbr_inds]
  )
  [cov_pot, bo, bopi, bopi2, abo] = calculate_covbon_pot(
    close_nbr_inds,
    close_nbr_dists,
    covbon_mask,
    species,
    triple_bond,
    force_field,
    tapered_reaxff=tapered_reaxff,
  )

  [lone_pot, vlp] = calculate_lonpar_pot(species, atom_mask, abo, force_field)

  overunder_pot = calculate_ovcor_pot(
    species,
    atomic_numbers,
    atom_mask,
    close_nbr_inds,
    close_nbr_dists,
    close_nbr_inds != N,
    bo,
    bopi,
    bopi2,
    abo,
    vlp,
    force_field,
  )

  [val_pot, total_penalty, total_conj] = calculate_valency_pot(
    species,
    body_3_inds,
    body_3_angles,
    body_3_inds[:, 0] != -1,
    close_nbr_inds,
    vlp,
    bo,
    bopi,
    bopi2,
    abo,
    force_field,
    tapered_reaxff=tapered_reaxff,
  )

  # Both flags are passed by keyword: `tapered_reaxff` precedes `tors_2013`
  # in the callee's signature, so a positional `tors_2013` would land in the
  # wrong slot.
  [torsion_pot, tor_conj] = calculate_torsion_pot(
    species,
    body_4_inds,
    body_4_angles,
    body_4_inds[:, 0] != -1,
    close_nbr_inds,
    bo,
    bopi,
    abo,
    force_field,
    tapered_reaxff=tapered_reaxff,
    tors_2013=tors_2013,
  )

  hb_pot = 0
  if hb_inds is not None:
    hb_mask = (hb_inds[:, 1] != -1) & (hb_inds[:, 2] != -1)
    hb_pot = calculate_hb_pot(
      species,
      hb_inds,
      hb_ang_dist,
      hb_mask,
      close_nbr_inds,
      far_nbr_inds,
      bo,
      force_field,
      tapered_reaxff=tapered_reaxff,
    )

  return (
    cou_pot
    + vdw_pot
    + charge_pot
    + cov_pot
    + lone_pot
    + val_pot
    + total_penalty
    + total_conj
    + overunder_pot
    + tor_conj
    + torsion_pot
    + hb_pot
    + self_energy
  ), charges
def calculate_eem_charges(
  species: Array,
  atom_mask: Array,
  nbr_inds: Array,
  hulp2_mat: Array,
  tapered_dists: Array,
  idempotential: Array,
  electronegativity: Array,
  init_charges: Array = None,
  total_charge: float = 0.0,
  backprop_solve: bool = False,
  tol: float = 1e-06,
  max_solver_iter: int = 500,
):
  """EEM charge solver.

  Solves the linear system made of the pair-interaction matrix (with
  `2 * idempotential` on the diagonal) bordered by the total-charge
  constraint. If max_solver_iter is set to -1, use direct solve on the
  dense matrix; otherwise use matrix-free conjugate gradient.

  Args:
    species: Atom types, one per atom.
    atom_mask: Boolean mask, False for padded atoms.
    nbr_inds: Per-atom neighbor indices; the value `len(species)` marks an
      empty slot.
    hulp2_mat: Per-neighbor shielded distances; zero entries are skipped.
    tapered_dists: Per-neighbor taper values multiplying the interaction.
    idempotential: Per-type idempotential parameters.
    electronegativity: Per-type electronegativity parameters.
    init_charges: Optional initial guess for the iterative solver.
    total_charge: Total charge constraint.
    backprop_solve: If falsy, gradients through `tapered_dists` and
      `hulp2_mat` are stopped before the solve.
    tol: Tolerance passed to the conjugate-gradient solver.
    max_solver_iter: Maximum CG iterations, or -1 for the direct solve.

  Returns:
    An array of shape [n+1,] where first n entries are the charges and
    last entry is the electronegativity equalization value.
  """
  # `not` (instead of `== False`) so that None/0 also disable backprop.
  if not backprop_solve:
    tapered_dists = jax.lax.stop_gradient(tapered_dists)
    hulp2_mat = jax.lax.stop_gradient(hulp2_mat)
  prev_dtype = tapered_dists.dtype
  N = len(species)
  # safe_mask keeps the division away from the zero entries (nan issues
  # otherwise)
  A = safe_mask(
    hulp2_mat != 0, lambda x: tapered_dists * 14.4 / x, hulp2_mat, 0.0
  )
  my_idemp = idempotential[species]
  my_elect = electronegativity[species] * atom_mask

  def to_dense():
    """Create the dense (N+1)x(N+1) system matrix."""
    A_ = jax.vmap(
      lambda j: jax.vmap(lambda i: jnp.sum(A[i] * (nbr_inds[i] == j)))(
        jnp.arange(N)
      )
    )(jnp.arange(N))
    A_ = A_.at[jnp.diag_indices(N)].add(2.0 * my_idemp)
    matrix = jnp.zeros(shape=(N + 1, N + 1), dtype=prev_dtype)
    matrix = matrix.at[:N, :N].set(A_)
    # Border row/column encode the total-charge constraint.
    matrix = matrix.at[N, :N].set(atom_mask)
    matrix = matrix.at[:N, N].set(atom_mask)
    matrix = matrix.at[N, N].set(0.0)
    return matrix

  mask = nbr_inds != N

  def matvec(vec):
    """Matrix-free mat-vec with the same operator as `to_dense`."""
    res = jnp.zeros(shape=(N + 1,), dtype=jnp.float64)
    s_vec = vec.astype(prev_dtype)[nbr_inds] * mask
    vals = jax.vmap(jnp.dot)(A, s_vec) + (my_idemp * 2.0) * vec[:N] + vec[N]
    res = res.at[:N].set(vals * atom_mask)
    res = res.at[N].set(jnp.sum(vec[:N] * atom_mask))  # sum of charges
    return res

  b = jnp.zeros(shape=(N + 1,), dtype=jnp.float64)
  b = b.at[:N].set(-1 * my_elect)
  b = b.at[N].set(total_charge)

  if max_solver_iter == -1:
    charges = jnp.linalg.solve(to_dense(), b)
  else:
    charges, _ = linalg.cg(
      matvec, b, x0=init_charges, tol=tol, maxiter=max_solver_iter
    )
  charges = charges.astype(prev_dtype)
  # Zero the charges of padded atoms; the last entry is left untouched.
  charges = charges.at[:-1].multiply(atom_mask)
  return charges
def calculate_acks2_charges(
  species: Array,
  atom_mask: Array,
  nbr_inds: Array,
  hulp2_mat: Array,
  bond_softness: Array,
  tapered_dists: Array,
  idempotential: Array,
  electronegativity: Array,
  total_charge: float,
  backprop_solve: bool = False,
  tol: float = 1e-06,
):
  """ACKS2 charge solver (dense direct solve).

  Assembles the (2n+2)x(2n+2) system coupling the n charges, the n
  effective-potential values and two constraint multipliers, and solves it
  with `jnp.linalg.solve`. The memory cost is quadratic in the number of
  atoms.

  Args:
    species: Atom types, one per atom.
    atom_mask: Boolean mask, False for padded atoms.
    nbr_inds: Per-atom neighbor indices; the value `len(species)` marks an
      empty slot.
    hulp2_mat: Per-neighbor shielded distances; zero entries are skipped.
    bond_softness: Per-neighbor bond softness values.
    tapered_dists: Per-neighbor taper values multiplying the interaction.
    idempotential: Per-type idempotential parameters.
    electronegativity: Per-type electronegativity parameters.
    total_charge: Total charge constraint.
    backprop_solve: If falsy, gradients through the geometry-dependent
      inputs are stopped before the solve.
    tol: Unused by the direct solve; kept for interface compatibility.

  Returns:
    A tuple `(charges, effpot)`, both of shape [n,], with padded atoms
    zeroed.
  """
  # `not` (instead of `== False`) so that None/0 also disable backprop.
  if not backprop_solve:
    tapered_dists = jax.lax.stop_gradient(tapered_dists)
    hulp2_mat = jax.lax.stop_gradient(hulp2_mat)
    bond_softness = jax.lax.stop_gradient(bond_softness)
  prev_dtype = tapered_dists.dtype
  N = len(species)
  # might cause nan issues if 0s not handled well
  A = jnp.where(hulp2_mat == 0, 0.0, tapered_dists * 14.4 / hulp2_mat)
  my_idemp = idempotential[species]
  my_elect = electronegativity[species]
  B = bond_softness

  def to_dense():
    """Create the dense (2N+2)x(2N+2) system matrix."""
    a_inds = jnp.arange(N)
    A_ = jax.vmap(
      lambda j: jax.vmap(lambda i: jnp.sum(A[i] * (nbr_inds[i] == j)))(a_inds)
    )(a_inds)
    diag_inds = jnp.diag_indices(N)
    A_ = A_.at[diag_inds].add(2.0 * my_idemp)
    B_ = jax.vmap(
      lambda j: jax.vmap(
        lambda i: jnp.sum(B[i] * ((nbr_inds[i] == j) & (i != j)))
      )(a_inds)
    )(a_inds)
    # Each diagonal entry is minus the column sum of the off-diagonal block.
    diags_B = jnp.sum(B_, axis=0)
    B_ = B_.at[diag_inds].add(-1 * diags_B)
    matrix = jnp.zeros(shape=(2 * N + 2, 2 * N + 2), dtype=prev_dtype)
    matrix = matrix.at[:N, :N].set(A_)
    matrix = matrix.at[N : 2 * N, N : 2 * N].set(B_)
    matrix = matrix.at[N : 2 * N, :N].set(jnp.eye(N))
    matrix = matrix.at[:N, N : 2 * N].set(jnp.eye(N))
    # Constraint rows/columns for the potentials and for the charges.
    matrix = matrix.at[2 * N, N : 2 * N].set(atom_mask)
    matrix = matrix.at[N : 2 * N, 2 * N].set(atom_mask)
    matrix = matrix.at[2 * N + 1, :N].set(atom_mask)
    matrix = matrix.at[:N, 2 * N + 1].set(atom_mask)
    return matrix

  b = jnp.zeros(shape=(2 * N + 2,), dtype=jnp.float64)
  b = b.at[:N].set(-1 * my_elect)
  b = b.at[N : 2 * N].set(total_charge / N)
  b = b.at[2 * N].set(0.0)
  b = b.at[2 * N + 1].set(total_charge)

  # The previous matrix-free operator produced N+1 outputs for this
  # (2N+2)-sized system, so the iterative solve could not run. Solve the
  # dense system directly instead.
  solution = jnp.linalg.solve(to_dense(), b)
  solution = solution.astype(prev_dtype)
  return solution[:N] * atom_mask, solution[N : 2 * N] * atom_mask
def calculate_coulomb_pot(
  nbr_inds: Array,
  atom_mask: Array,
  hulp2_mat: Array,
  tapered_dists: Array,
  charges: Array,
):
  """Sum the tapered, shielded Coulomb energy over all neighbor pairs.

  A pair contributes only when both atoms are unmasked and the neighbor
  slot is not the padding index `len(atom_mask)`. The sum is halved; this
  presumably compensates for each pair appearing twice in a full neighbor
  list.
  """
  n_atoms = len(atom_mask)
  both_real = atom_mask.reshape(-1, 1) * atom_mask[nbr_inds]
  pair_ok = both_real * (nbr_inds != n_atoms)
  charge_products = charges.reshape(-1, 1) * charges[nbr_inds]

  def _pair_energy(shielded_dist):
    # tiny offset guards against division by zero
    return c1c * charge_products / (shielded_dist + 1e-20)

  pair_energy = safe_mask(pair_ok, _pair_energy, hulp2_mat, 0.0)
  tapered_energy = pair_energy * tapered_dists * pair_ok
  return high_precision_sum(tapered_energy) / 2.0
def calculate_charge_energy(
  species: Array, charges: Array, idempotential: Array, electronegativity: Array
):
  """Return the summed per-atom charge energy.

  Each atom contributes `23.02 * (chi * q + eta * q**2)`, where `chi` and
  `eta` are the electronegativity and idempotential of its type.
  """
  chi = electronegativity[species]
  eta = idempotential[species]
  per_atom = 23.02 * (chi * charges + eta * jnp.square(charges))
  return high_precision_sum(per_atom)
def calculate_acks2_coulomb_pot(
  nbr_inds: Array, atom_mask: Array, effpot: Array, bond_softness: Array
):
  """Return the ACKS2 pair term built from effective-potential differences.

  For every neighbor slot the squared difference between the atom's and the
  neighbor's `effpot` is weighted by `bond_softness` and by `-0.25 * 23.02`.
  `atom_mask` is accepted but not used here.
  """
  potential_gap = effpot.reshape(-1, 1) - effpot[nbr_inds]
  pair_terms = (-0.25 * 23.02) * bond_softness * potential_gap * potential_gap
  return high_precision_sum(pair_terms)
def calculate_charge_energy_acks2(charges: Array, effpot: Array):
  """Return the sum over atoms of `23.02 * charge * effpot`."""
  per_atom = 23.02 * charges * effpot
  return high_precision_sum(per_atom)
def calculate_vdw_pot(
  species: Array,
  far_nbr_mask: Array,
  nbr_inds: Array,
  dists: Array,
  tapered_dists: Array,
  force_field: ForceField,
):
  """Sum the tapered van der Waals energy over the far neighbor list.

  Distances are shielded with the per-pair `vop` term before entering the
  two-exponential pair expression. Masked slots contribute zero, and the
  total is halved.
  """
  shielding = force_field.vdw_shiedling
  # (neighbor type, center type) index pair used for every 2D lookup
  pair_types = (species[nbr_inds], species.reshape(-1, 1))

  vop_half_pow = jnp.power(force_field.vop.reshape(-1, 1), shielding / 2.0)
  inv_gamma_w = 1.0 / (vop_half_pow * vop_half_pow.transpose())
  gamma_term = inv_gamma_w[pair_types]

  # select the required pair parameters
  p1 = force_field.p1co[pair_types]
  p2 = force_field.p2co[pair_types]
  p3 = force_field.p3co[pair_types]

  # only raise strictly positive distances to the shielding power
  dist_pow = safe_mask(dists > 0, lambda r: r**shielding, dists, 0.0)
  shielded_dist = jnp.power(dist_pow + gamma_term, (1.0 / shielding))

  exponent = p3 * (1.0 - shielded_dist / p1)
  pair_energy = p2 * (jnp.exp(exponent) - 2.0 * jnp.exp(0.5 * exponent))
  masked_energy = pair_energy * tapered_dists * far_nbr_mask
  return high_precision_sum(masked_energy) / 2.0
def calculate_bo(
  nbr_inds: Array,
  nbr_dist: Array,
  species: Array,
  species_AN: Array,
  force_field: ForceField,
):
  """Return only the corrected bond orders for a dense neighbor list.

  Usage: the neighbor list is updated/allocated first, then its indices and
  distances are passed here. Carbon-oxygen pairs (atomic numbers 6 and 8)
  are flagged as triple-bond candidates before delegating to
  `calculate_covbon_pot`.
  """
  n_atoms = len(species)
  z_center = species_AN.reshape(-1, 1)
  z_neighbor = species_AN[nbr_inds]
  # O: 8, C: 6
  oxygen_carbon = jnp.logical_and(z_center == 8, z_neighbor == 6)
  carbon_oxygen = jnp.logical_and(z_center == 6, z_neighbor == 8)
  triple_bond = jnp.logical_or(oxygen_carbon, carbon_oxygen)

  valid_slots = nbr_inds != n_atoms
  _, bo, _, _, _ = calculate_covbon_pot(
    nbr_inds, nbr_dist, valid_slots, species, triple_bond, force_field
  )
  return bo
def calculate_covbon_pot(
  nbr_inds: Array,
  nbr_dist: Array,
  nbr_mask: Array,
  species: Array,
  triple_bond: Array,
  force_field: ForceField,
  tapered_reaxff: bool = False,
):
  """Compute bond orders and the covalent bond energy.

  Args:
    nbr_inds: Per-atom close-neighbor indices.
    nbr_dist: Distances matching `nbr_inds`; non-positive entries are masked.
    nbr_mask: Boolean mask of valid neighbor slots.
    species: Atom types, one per atom.
    triple_bond: Per-slot flag selecting pairs that receive the triple-bond
      stabilisation term.
    force_field: ReaxFF parameters.
    tapered_reaxff: If True, the bond-order cutoff is applied through
      `taper_inc` instead of a plain subtraction.

  Returns:
    `[cov_pot, bo, bopi, bopi2, abo]`: the bond energy, the corrected total /
    pi / double-pi bond orders per neighbor slot, and the per-atom sum of
    corrected bond orders.
  """
  N = len(species)
  nbr_mask = nbr_mask & (nbr_dist > 0)
  neigh_types = species[nbr_inds]
  atom_inds = jnp.arange(N).reshape(-1, 1)
  species = species.reshape(-1, 1)
  # save the chosen dtype
  dtype = nbr_dist.dtype
  # symm = (atom_inds == nbr_inds).astype(dtype) + 1
  # symm = 1.0 / symm
  # since we store the close nbr list full, we later divide the summation by 2
  # to compensate double counting, the self bonds are not double counted
  # so they will be multiplied by 0.5 as expected
  symm = 1.0

  # Pair parameters, looked up by (neighbor type, center type)
  my_rob1 = force_field.rob1[neigh_types, species]
  my_rob2 = force_field.rob2[neigh_types, species]
  my_rob3 = force_field.rob3[neigh_types, species]
  my_ptp = force_field.ptp[neigh_types, species]
  my_pdp = force_field.pdp[neigh_types, species]
  my_popi = force_field.popi[neigh_types, species]
  my_pdo = force_field.pdo[neigh_types, species]
  my_bop1 = force_field.bop1[neigh_types, species]
  my_bop2 = force_field.bop2[neigh_types, species]
  my_de1 = force_field.de1[neigh_types, species]
  my_de2 = force_field.de2[neigh_types, species]
  my_de3 = force_field.de3[neigh_types, species]
  my_psp = force_field.psp[neigh_types, species]
  my_psi = force_field.psi[neigh_types, species]

  # TODO: tempo fix, due to numerical problems in this function
  # use double precision then cast it back to the original type
  nbr_dist = nbr_dist.astype(jnp.float64)
  # Distance ratios; a non-positive reference radius disables that
  # contribution (placeholder 1e-7, masked again below)
  rhulp = safe_mask(
    my_rob1 > 0, lambda x: nbr_dist / (x + 1e-10), my_rob1, 1e-7
  )
  rhulp2 = safe_mask(
    my_rob2 > 0, lambda x: nbr_dist / (x + 1e-10), my_rob2, 1e-7
  )
  rhulp3 = safe_mask(
    my_rob3 > 0, lambda x: nbr_dist / (x + 1e-10), my_rob3, 1e-7
  )

  rh2p = rhulp2**my_ptp
  ehulpp = jnp.exp(my_pdp * rh2p)
  rh2pp = rhulp3**my_popi
  ehulppp = jnp.exp(my_pdo * rh2pp)
  rh2 = rhulp**my_bop2
  ehulp = (1 + force_field.cutoff) * jnp.exp(my_bop1 * rh2)

  mask1 = (my_rob1 > 0) & nbr_mask
  mask2 = (my_rob2 > 0) & nbr_mask
  mask3 = (my_rob3 > 0) & nbr_mask
  full_mask = mask1 | mask2 | mask3
  ehulp = safe_mask(mask1, lambda x: x, ehulp, 0)
  ehulpp = safe_mask(mask2, lambda x: x, ehulpp, 0)
  ehulppp = safe_mask(mask3, lambda x: x, ehulppp, 0)

  # Uncorrected bond orders: total, pi and double-pi parts
  bor = ehulp + ehulpp + ehulppp
  bopi = ehulpp
  bopi2 = ehulppp
  if tapered_reaxff:
    bo = taper_inc(bor, force_field.cutoff, 4.0 * force_field.cutoff) * (
      bor - force_field.cutoff
    )
  else:
    bo = bor - force_field.cutoff
  bo = jnp.where(bo <= 0, 0.0, bo)
  abo = jnp.sum(bo, axis=1)

  # Apply the bond-order corrections, then recompute the per-atom totals
  bo, bopi, bopi2 = calculate_boncor_pot(
    nbr_inds, nbr_mask, species.flatten(), bo, bopi, bopi2, abo, force_field
  )
  abo = jnp.sum(bo * nbr_mask, axis=1)

  # Remainder of the bond order after removing the pi parts, clipped at zero
  bosia = bo - bopi - bopi2
  bosia = jnp.clip(bosia, 0, float('inf'))
  de1h = symm * my_de1
  de2h = symm * my_de2
  de3h = symm * my_de3
  # add 1e-20 so that ln(a) is not nan
  bopo1 = safe_mask((bosia != 0), lambda x: (x + 1e-20) ** my_psp, bosia, 0)
  exphu1 = jnp.exp(my_psi * (1.0 - bopo1))
  ebh = -de1h * bosia * exphu1 - de2h * bopi - de3h * bopi2
  ebh = jnp.where(bo <= 0, 0.0, ebh)

  # Stabilisation terminal triple bond in CO
  ba = (bo - 2.5) * (bo - 2.5)
  exphu = jnp.exp(-force_field.trip_stab8 * ba)
  abo_j2 = abo[nbr_inds]
  abo_j1 = abo[atom_inds]
  obo_a = abo_j1 - bo
  obo_b = abo_j2 - bo
  exphua1 = jnp.exp(-force_field.trip_stab4 * obo_a)
  exphub1 = jnp.exp(-force_field.trip_stab4 * obo_b)
  my_aval = force_field.aval[species] + force_field.aval[neigh_types]
  # Only bonds with order >= 1 keep the triple-bond flag
  triple_bond = jnp.where(bo < 1.0, 0.0, triple_bond)
  ovoab = abo_j1 + abo_j2 - my_aval
  exphuov = jnp.exp(force_field.trip_stab5 * ovoab)
  hulpov = 1.0 / (1.0 + 25.0 * exphuov)
  estriph = force_field.trip_stab11 * exphu * hulpov * (exphua1 + exphub1)

  eb = ebh + estriph * triple_bond
  eb = safe_mask(full_mask, lambda x: x, eb, 0)
  # Halved: see the note on `symm` above (full neighbor list)
  cov_pot = high_precision_sum(eb) / 2.0
  # cast the arrays back to the original dtype
  cov_pot = cov_pot.astype(dtype)
  # bo = bo.astype(dtype)
  # bopi = bopi.astype(dtype)
  # bopi2 = bopi2.astype(dtype)
  # abo = abo.astype(dtype)
  # NOTE(review): `symm` is recomputed here but never used afterwards, since
  # the self-bond corrections below are commented out.
  symm = (atom_inds == nbr_inds).astype(dtype) + 1
  symm = 1.0 / symm
  # to correct for self bonds, multiply by 0.5
  # bo = bo * symm
  # bopi = bopi * symm
  # bopi2 = bopi2 * symm
  return [cov_pot, bo, bopi, bopi2, abo]
def calculate_boncor_pot(
  nbr_inds: Array,
  nbr_mask: Array,
  species: Array,
  bo: Array,
  bopi: Array,
  bopi2: Array,
  abo: Array,
  force_field: ForceField,
):
  """Apply the bond-order correction factors.

  Args:
    nbr_inds: Per-atom close-neighbor indices.
    nbr_mask: Boolean mask of valid neighbor slots.
    species: Atom types, one per atom.
    bo, bopi, bopi2: Uncorrected total / pi / double-pi bond orders per slot.
    abo: Per-atom sum of uncorrected bond orders.
    force_field: ReaxFF parameters.

  Returns:
    The corrected `(bo, bopi, bopi2)`. Entries that are masked or not above
    the threshold are set to 0.
  """
  neigh_types = species[nbr_inds]
  species = species.reshape(-1, 1)
  # j1 = center atom (column vector), j2 = neighbor atom
  abo_j2 = abo[nbr_inds]
  abo_j1 = abo.reshape(-1, 1)
  aval_j2 = force_field.aval[neigh_types]
  aval_j1 = force_field.aval[species]

  # Pair parameters formed from the per-type values via safe_sqrt of the
  # product
  vp131 = safe_sqrt(force_field.bo131[species] * force_field.bo131[neigh_types])
  vp132 = safe_sqrt(force_field.bo132[species] * force_field.bo132[neigh_types])
  vp133 = safe_sqrt(force_field.bo133[species] * force_field.bo133[neigh_types])
  my_ovc = force_field.ovc[neigh_types, species]

  # First correction factor, built from `abo - aval` of both atoms
  ov_j1 = abo_j1 - aval_j1
  ov_j2 = abo_j2 - aval_j2
  exp11 = jnp.exp(-force_field.over_coord1 * ov_j1)
  exp21 = jnp.exp(-force_field.over_coord1 * ov_j2)
  exphu1 = jnp.exp(-force_field.over_coord2 * ov_j1)
  exphu2 = jnp.exp(-force_field.over_coord2 * ov_j2)
  exphu12 = exphu1 + exphu2
  ovcor = -(1.0 / force_field.over_coord2) * jnp.log(0.50 * exphu12)
  huli = aval_j1 + exp11 + exp21
  hulj = aval_j2 + exp11 + exp21
  corr1 = huli / (huli + ovcor)
  corr2 = hulj / (hulj + ovcor)
  corrtot = 0.50 * (corr1 + corr2)
  # The correction is skipped (factor 1) for pairs whose `ovc` is <= 0.001
  corrtot = jnp.where(my_ovc > 0.001, corrtot, 1.0)

  my_v13cor = force_field.v13cor[neigh_types, species]
  # update vval3 based on amas value
  vval3 = jnp.where(
    force_field.amas < 21.0, force_field.valf, force_field.vval3
  )
  vval3_j1 = vval3[species]
  vval3_j2 = vval3[neigh_types]
  ov_j11 = abo_j1 - vval3_j1
  ov_j22 = abo_j2 - vval3_j2
  cor1 = vp131 * bo * bo - ov_j11
  cor2 = vp131 * bo * bo - ov_j22
  exphu3 = jnp.exp(-vp132 * cor1 + vp133)
  exphu4 = jnp.exp(-vp132 * cor2 + vp133)
  bocor1 = 1.0 / (1.0 + exphu3)
  bocor2 = 1.0 / (1.0 + exphu4)
  # These factors are skipped (factor 1) for pairs whose `v13cor` is <= 0.001
  bocor1 = jnp.where(my_v13cor > 0.001, bocor1, 1.0)
  bocor2 = jnp.where(my_v13cor > 0.001, bocor2, 1.0)

  bo = bo * corrtot * bocor1 * bocor2
  threshold = 0.0  # fortran threshold: 1e-10
  bo = safe_mask(nbr_mask & (bo > threshold), lambda x: x, bo, 0)
  # The pi parts are scaled by the square of the first correction factor
  corrtot2 = corrtot * corrtot
  bopi = bopi * corrtot2 * bocor1 * bocor2
  bopi2 = bopi2 * corrtot2 * bocor1 * bocor2
  bopi = safe_mask(nbr_mask & (bopi > threshold), lambda x: x, bopi, 0)
  bopi2 = safe_mask(nbr_mask & (bopi2 > threshold), lambda x: x, bopi2, 0)
  return bo, bopi, bopi2
def smooth_lone_pair_casting(
  number, p_lambda=0.9999, l1=-1.3, l2=-0.3, r1=0.3, r2=1.3
):
  """Smooth piecewise replacement for integer casting.

  Below `l1` the left branch is used, between `l1` and `l2` it is blended
  with `taper`, between `l2` and `r1` the result is 0, between `r1` and `r2`
  the right branch is blended with `taper_inc`, and above `r2` the right
  branch is used unchanged.
  """
  phase = 2 * jnp.pi * number
  wrapped = (1 / jnp.pi) * (
    jnp.arctan(p_lambda * jnp.sin(phase) / (p_lambda * jnp.cos(phase) - 1))
  )
  right_branch = number - 1 / 2 - wrapped
  left_branch = number + 1 / 2 - wrapped

  # Build the piecewise result from the rightmost interval outwards; each
  # later `where` takes precedence over the earlier ones.
  out = jnp.where(
    number <= r2, right_branch * taper_inc(number, r1, r2), right_branch
  )
  out = jnp.where(number < r1, 0, out)
  out = jnp.where(number < l2, left_branch * taper(number, l1, l2), out)
  out = jnp.where(number < l1, left_branch, out)
  return out
def calculate_lonpar_pot(
  species: Array, atom_mask: Array, abo: Array, force_field: ForceField
):
  """Compute the lone-pair energy and the per-atom lone-pair count.

  Returns:
    `[elp, vlp]`: the summed lone-pair energy over unmasked atoms and the
    per-atom lone-pair values, both cast to the dtype of `abo`.
  """
  out_dtype = abo.dtype
  stlp = force_field.stlp[species]

  # Determine number of lone pairs on atoms
  optimal_pairs = 0.5 * (stlp - force_field.aval[species])
  excess = abo - stlp
  # truncation toward zero through the int32 round-trip
  half_excess = (excess / 2.0).astype(jnp.int32).astype(out_dtype)
  remainder = excess - 2.0 * half_excess
  shifted = 2.0 + remainder
  vlp = jnp.exp(-force_field.par_16 * shifted * shifted) - half_excess

  # Calculate lone pair energy
  deviation = optimal_pairs - vlp
  switch = 1.0 / (1.0 + jnp.exp(-75.0 * deviation))
  per_atom = force_field.vlp1[species] * deviation * switch
  per_atom = safe_mask(atom_mask, lambda x: x, per_atom, 0)

  elp = high_precision_sum(per_atom).astype(out_dtype)
  vlp = vlp.astype(out_dtype)
  return [elp, vlp]
def calculate_ovcor_pot(
  species: Array,
  atoms_AN: Array,
  atom_mask: Array,
  nbr_inds: Array,
  nbr_dists: Array,
  nbr_mask: Array,
  bo: Array,
  bopi: Array,
  bopi2: Array,
  abo: Array,
  vlp: Array,
  force_field: ForceField,
):
  """Compute the combined over- and under-coordination energy.

  Args:
    species: Atom types, one per atom.
    atoms_AN: Atomic numbers; only referenced by the disabled C2 correction.
    atom_mask: Boolean mask, False for padded atoms.
    nbr_inds: Per-atom close-neighbor indices.
    nbr_dists: Close-neighbor distances; only their dtype is used here.
    nbr_mask: Valid-slot mask; only referenced by the disabled C2 correction.
    bo, bopi, bopi2: Corrected bond orders per neighbor slot.
    abo: Per-atom sum of bond orders.
    vlp: Per-atom lone-pair values from the lone-pair term.
    force_field: ReaxFF parameters.

  Returns:
    The summed energy over unmasked atoms, cast to the dtype of `nbr_dists`.
  """
  my_stlp = force_field.stlp[species]
  my_aval = force_field.aval[species]
  my_amas = force_field.amas[species]
  my_valp1 = force_field.valp1[species]
  my_vovun = force_field.vovun[species]
  neigh_types = species[nbr_inds]
  # this function is numerically sensitive so use double precision
  prev_type = nbr_dists.dtype
  # bo = bo.astype(jnp.float64)
  # bopi = bopi.astype(jnp.float64)
  # bopi2 = bopi2.astype(jnp.float64)
  # abo = abo.astype(jnp.float64)

  # For atoms with amas > 21 the lone-pair deviation below is forced to
  # zero (dfvl = 0)
  vlptemp = jnp.where(my_amas > 21.0, 0.50 * (my_stlp - my_aval), vlp)
  dfvl = jnp.where(my_amas > 21.0, 0.0, 1.0)

  # Calculate overcoordination energy
  # Valency is corrected for lone pairs
  voptlp = 0.50 * (my_stlp - my_aval)
  vlph = voptlp - vlptemp
  diffvlph = dfvl * vlph
  diffvlp2 = dfvl.reshape(-1, 1) * vlph[nbr_inds]

  # Determine coordination neighboring atoms
  part_1 = bopi + bopi2
  part_2 = abo[nbr_inds] - force_field.aval[neigh_types] - diffvlp2
  sumov = jnp.sum(part_1 * part_2, axis=1)
  mult_vov_de1 = force_field.vover * force_field.de1
  my_mult_vov_de1 = mult_vov_de1[species.reshape(-1, 1), neigh_types]
  sumov2 = jnp.sum(my_mult_vov_de1 * bo, axis=1)

  # Gradient nan issue fix
  exphu1 = jnp.exp(force_field.par_32 * sumov)
  vho = 1.0 / (1.0 + force_field.par_33 * exphu1)
  diffvlp = diffvlph * vho
  vov1 = abo - my_aval - diffvlp
  # to solve the nan issue
  exphuo = jnp.exp(my_vovun * vov1)
  hulpo = 1.0 / (1.0 + exphuo)
  # the 1e-08 offset keeps the denominator away from zero
  hulpp = 1.0 / (vov1 + my_aval + 1e-08)
  eah = sumov2 * hulpp * hulpo * vov1
  ea = high_precision_sum(eah * atom_mask)

  # Calculate undercoordination energy
  # Gradient nan issue fix
  exphu2 = jnp.exp(force_field.par_10 * sumov)
  vuhu1 = 1.0 + force_field.par_9 * exphu2
  hulpu2 = 1.0 / vuhu1
  exphu3 = -jnp.exp(force_field.par_7 * vov1)
  hulpu3 = -(1.0 + exphu3)
  dise2 = my_valp1
  # Gradient nan issue fix
  exphuu = jnp.exp(-my_vovun * vov1)
  hulpu = 1.0 / (1.0 + exphuu)
  eahu = dise2 * hulpu * hulpu2 * hulpu3
  # No undercoordination contribution for atoms with negative valp1
  eahu = jnp.where(my_valp1 < 0, 0, eahu)
  eahu = safe_mask(atom_mask, lambda x: x, eahu, 0)
  ea = ea + high_precision_sum(eahu * atom_mask)
  # cast the result back to the original type
  ea = ea.astype(prev_type)

  # Correction for C2 PART effecting (eplh)
  # TODO: Most FFs do not activate this part, so I should use lax.cond
  # to decide if the computation is needed, commented off for now
  """
  par6_mask = jnp.abs(force_field.par_6) > 0.001
  src_C_mask = atoms_AN == 6
  dst_C_mask = atoms_AN[nbr_inds] == 6
  C_C_bonds_mask = src_C_mask.reshape(-1,1) & dst_C_mask
  C_C_bonds_mask = C_C_bonds_mask & nbr_mask & par6_mask
  vov4 = abo - my_aval
  vov4 = vov4[nbr_inds]
  vov3 = bo - vov4 - 0.040 * (vov4 ** 4)
  vov3_mask = vov3 > 3.0
  elph = force_field.par_6 * (vov3 -3.0)**2
  elph = elph * (vov3_mask & C_C_bonds_mask)
  c2_corr = high_precision_sum(elph)
  ea = ea + c2_corr
  """
  return ea
def calculate_valency_pot(
  species: Array,
  body_3_inds: Array,
  body_3_angles: Array,
  body_3_mask: Array,
  nbr_inds: Array,
  vlp: Array,
  bo: Array,
  bopi: Array,
  bopi2: Array,
  abo: Array,
  force_field: ForceField,
  tapered_reaxff: bool = False,
):
  """Compute the valence-angle energy, its penalty and conjugation terms.

  Args:
    species: Atom types, one per atom.
    body_3_inds: One row per angle: global index of the center atom followed
      by the local (per-center) neighbor-list positions of the two end atoms.
    body_3_angles: Angle value per row of `body_3_inds`; compared against an
      equilibrium angle expressed in radians.
    body_3_mask: Boolean mask of valid angle rows.
    nbr_inds: Per-atom close-neighbor indices.
    vlp: Per-atom lone-pair values from the lone-pair term.
    bo, bopi, bopi2: Corrected bond orders per neighbor slot.
    abo: Per-atom sum of bond orders.
    force_field: ReaxFF parameters.
    tapered_reaxff: If True, the bond-order cutoff is applied through
      `taper_inc` instead of a plain subtraction.

  Returns:
    `[total_pot, total_penalty, total_conj]`, each cast to the dtype of `bo`.
  """
  prev_type = bo.dtype
  center = body_3_inds[:, 0]
  neigh1_lcl = body_3_inds[:, 1]
  neigh2_lcl = body_3_inds[:, 2]
  # Translate local neighbor positions into global atom indices
  neigh1_glb = nbr_inds[center, neigh1_lcl]
  neigh2_glb = nbr_inds[center, neigh2_lcl]
  cent_types = species[center]
  neigh1_types = species[neigh1_glb]
  neigh2_types = species[neigh2_glb]
  val_angles = body_3_angles
  boa = bo[center, neigh1_lcl]
  bob = bo[center, neigh2_lcl]
  if tapered_reaxff:
    boa = taper_inc(boa, force_field.cutoff2, 4.0 * force_field.cutoff2) * (
      boa - force_field.cutoff2
    )
    bob = taper_inc(bob, force_field.cutoff2, 4.0 * force_field.cutoff2) * (
      bob - force_field.cutoff2
    )
    complete_mask = body_3_mask
  else:
    # Fortran comment: Scott Habershon recommendation March 2009
    mask = jnp.where(boa * bob < 0.00001, 0, 1)
    complete_mask = mask * body_3_mask
    boa = boa - force_field.cutoff2
    bob = bob - force_field.cutoff2
  complete_mask = complete_mask & (boa > 0) & (bob > 0)
  # thresholding
  boa = jnp.clip(boa, 0, float('inf'))
  bob = jnp.clip(bob, 0, float('inf'))

  # calculate SBO term
  # calculate sbo2 and vmbo for every atom in the sim.sys.
  sbo2 = jnp.sum(bopi, axis=1) + jnp.sum(bopi2, axis=1)
  vmbo = jnp.prod(
    jnp.exp(-(bo**8)), dtype=jnp.float64, axis=1
  )  # .astype(prev_type)
  my_abo = abo[center]
  # calculate for every atom
  exbo = abo - force_field.valf[species]
  my_exbo = exbo[center]
  # TODO: (REVISE LATER) cast the data to double to solve nan issue in division
  my_exbo = jnp.array(my_exbo, dtype=jnp.float64)
  my_vkac = force_field.vkac[neigh1_types, cent_types, neigh2_types]
  evboadj = 1.0  # why?
  # to solve the nan issue, clip the values
  expun = jnp.exp(-my_vkac * my_exbo)
  expun2 = jnp.exp(force_field.val_par15 * my_exbo)
  htun1 = 2.0 + expun2
  htun2 = 1.0 + expun + expun2
  my_vval4 = force_field.vval4[cent_types]
  evboadj2 = my_vval4 - (my_vval4 - 1.0) * (htun1 / htun2)
  evboadj2 = jnp.array(evboadj2, dtype=prev_type)

  # calculate for every atom
  exlp1 = abo - force_field.stlp[species]
  exlp2 = 2.0 * ((exlp1 / 2.0).astype(jnp.int32))  # integer casting
  # exlp2 = 2.0 * smooth_lone_pair_casting(exlp1/2.0)
  exlp = exlp1 - exlp2
  vlpadj = jnp.where(exlp < 0.0, vlp, 0.0)  # vlp comes from lone pair
  # calculate for every atom
  sbo2 = sbo2 + (1 - vmbo) * (-exbo - force_field.val_par34 * vlpadj)
  sbo2 = jnp.clip(sbo2, 0, 2.0)
  # add 1e-15 so that ln(a) is not nan
  # First pass transforms the entries below 1, second pass those >= 1
  sbo2 = vectorized_cond(
    sbo2 < 1,
    lambda x: (x + 1e-15) ** force_field.val_par17,
    lambda x: sbo2,
    sbo2,
  )
  sbo2 = vectorized_cond(
    sbo2 >= 1,
    lambda x: 2.0 - (2.0 - x + 1e-15) ** force_field.val_par17,
    lambda x: sbo2,
    sbo2,
  )
  expsbo = jnp.exp(-force_field.val_par18 * (2.0 - sbo2))
  my_expsbo = expsbo[center]

  # Equilibrium angle: built in degrees, then converted to radians
  thba = force_field.th0[neigh1_types, cent_types, neigh2_types]
  thetao = 180.0 - thba * (1.0 - my_expsbo)
  thetao = thetao * dgrrdn
  thdif = thetao - val_angles
  thdi2 = thdif * thdif
  my_vka = force_field.vka[neigh1_types, cent_types, neigh2_types]
  my_vka3 = force_field.vka3[neigh1_types, cent_types, neigh2_types]
  exphu = my_vka * jnp.exp(-my_vka3 * thdi2)
  exphu2 = my_vka - exphu
  # To avoid linear Me-H-Me angles (6/6/06)
  exphu2 = jnp.where(my_vka < 0.0, exphu2 - my_vka, exphu2)
  my_vval2 = force_field.vval2[neigh1_types, cent_types, neigh2_types]
  # add 1e-20 so that ln(a) is not nan
  boap = (boa + 1e-20) ** my_vval2
  bobp = (bob + 1e-20) ** my_vval2
  my_vval1 = force_field.vval1[cent_types]
  exa = jnp.exp(-my_vval1 * boap)
  exb = jnp.exp(-my_vval1 * bobp)
  exa2 = 1.0 - exa
  exb2 = 1.0 - exb
  evh = evboadj2 * evboadj * exa2 * exb2 * exphu2
  evh = safe_mask(complete_mask, lambda x: x, evh, 0)
  total_pot = high_precision_sum(evh * complete_mask).astype(prev_type)

  # Calculate penalty for two double bonds in valency angle
  exbo = abo - force_field.aval[species]
  expov = jnp.exp(force_field.val_par22 * exbo)
  expov2 = jnp.exp(-force_field.val_par21 * exbo)
  htov1 = 2.0 + expov2
  htov2 = 1.0 + expov + expov2
  ecsboadj = htov1 / htov2
  ecsboadj = jnp.array(ecsboadj, dtype=prev_type)
  my_ecsboadj = ecsboadj[center]  # for the center atom
  my_vkap = force_field.vkap[neigh1_types, cent_types, neigh2_types]
  exphu1 = jnp.exp(-force_field.val_par20 * (boa - 2.0) * (boa - 2.0))
  exphu2 = jnp.exp(-force_field.val_par20 * (bob - 2.0) * (bob - 2.0))
  epenh = my_vkap * my_ecsboadj * exphu1 * exphu2
  epenh = safe_mask(complete_mask, lambda x: x, epenh, 0)
  total_penalty = high_precision_sum(epenh).astype(prev_type)

  # Calculate valency angle conjugation energy
  abo_i = abo[neigh1_glb]
  abo_k = abo[neigh2_glb]  # (i,j,k) will give abo for k
  unda = abo_i - boa
  ovb = my_abo - force_field.vval3[cent_types]
  undc = abo_k - bob
  ba = (boa - 1.50) * (boa - 1.50)
  bb = (bob - 1.50) * (bob - 1.50)
  exphua = jnp.exp(-force_field.val_par31 * ba)
  exphub = jnp.exp(-force_field.val_par31 * bb)
  exphuua = jnp.exp(-force_field.val_par39 * unda * unda)
  exphuob = jnp.exp(force_field.val_par3 * ovb)
  # the division below is done in float64, then cast back
  exphuob = jnp.array(exphuob, dtype=jnp.float64)
  exphuuc = jnp.exp(-force_field.val_par39 * undc * undc)
  hulpob = 1.0 / (1.0 + exphuob)
  hulpob = jnp.array(hulpob, dtype=prev_type)
  my_vka8 = force_field.vka8[neigh1_types, cent_types, neigh2_types]
  ecoah = my_vka8 * exphua * exphub * exphuua * exphuuc * hulpob
  ecoah = safe_mask(complete_mask, lambda x: x, ecoah, 0)
  total_conj = high_precision_sum(ecoah).astype(prev_type)
  return [total_pot, total_penalty, total_conj]
def calculate_torsion_pot(
  species: Array,
  body_4_inds: Array,
  body_4_angles: Array,
  body_4_mask: Array,
  nbr_inds: Array,
  bo: Array,
  bopi: Array,
  abo: Array,
  force_field: ForceField,
  tapered_reaxff: bool = False,
  tors_2013: bool = False,
):
  """Compute the torsion energy and the four-body conjugation energy.

  Args:
    species: Atom types, one per atom.
    body_4_inds: One row per torsion: global index of the first center atom,
      then local neighbor positions of the left atom and of the second
      center (both relative to the first center), then the local position
      of the right atom relative to the second center.
    body_4_angles: Index 0 holds `hsin` (sinhd * sinhe) and index 1 holds
      `arg` for every torsion.
    body_4_mask: Boolean mask of valid torsion rows.
    nbr_inds: Per-atom close-neighbor indices.
    bo, bopi: Corrected total and pi bond orders per neighbor slot.
    abo: Per-atom sum of bond orders.
    force_field: ReaxFF parameters.
    tapered_reaxff: If True, the bond-order cutoff is applied through
      `taper_inc` instead of a plain subtraction.
    tors_2013: If True, use the squared-bond-order damping form.

  Returns:
    `[tors_pot, conj_pot]`, each cast to the dtype of `hsin`.
  """
  hsin = body_4_angles[0]  # hsin = sinhd * sinhe
  arg = body_4_angles[1]
  out_dtype = hsin.dtype

  # Resolve the four atoms of every torsion to global indices
  center1_glb = body_4_inds[:, 0]
  left_lcl = body_4_inds[:, 1]  # local to center1
  center2_lcl = body_4_inds[:, 2]  # local to center1
  right_lcl = body_4_inds[:, 3]  # local to center2
  left_glb = nbr_inds[center1_glb, left_lcl]
  center2_glb = nbr_inds[center1_glb, center2_lcl]
  right_glb = nbr_inds[center2_glb, right_lcl]

  # One shared index tuple for all four-body parameter tables
  type_quad = (
    species[left_glb],
    species[center1_glb],
    species[center2_glb],
    species[right_glb],
  )
  my_v1 = force_field.v1[type_quad]
  my_v2 = force_field.v2[type_quad]
  my_v3 = force_field.v3[type_quad]
  my_v4 = force_field.v4[type_quad]
  my_vconj = force_field.vconj[type_quad]

  # Adjustment built from `abo - valf` of the two center atoms
  over = abo - force_field.valf[species]
  over_sum = over[center1_glb] + over[center2_glb]
  exp_pos = jnp.exp(force_field.par_26 * over_sum)
  exp_neg = jnp.exp(-force_field.par_25 * over_sum)
  etboadj = (2.0 + exp_neg) / (1.0 + exp_pos + exp_neg)
  etboadj = jnp.array(etboadj, dtype=out_dtype)

  bo2t = 2.0 - bopi[center1_glb, center2_lcl] - etboadj
  bocor2 = jnp.exp(my_v4 * (bo2t * bo2t))
  arg2 = arg * arg
  ethhulp = (
    0.5 * my_v1 * (1.0 + arg)
    + my_v2 * bocor2 * (1.0 - arg2)
    + my_v3 * (0.5 + 2.0 * arg2 * arg - 1.5 * arg)
  )

  boa = bo[center1_glb, left_lcl]
  bob = bo[center1_glb, center2_lcl]
  boc = bo[center2_glb, right_lcl]
  # The product test uses the raw bond orders, before the cutoff shift
  product_mask = jnp.where(boa * bob * boc > force_field.cutoff2, 1, 0)
  complete_mask = body_4_mask * product_mask

  def _past_cutoff(order):
    shifted = order - force_field.cutoff2
    if tapered_reaxff:
      return (
        taper_inc(order, force_field.cutoff2, 4.0 * force_field.cutoff2)
        * shifted
      )
    return shifted

  boa = _past_cutoff(boa)
  bob = _past_cutoff(bob)
  boc = _past_cutoff(boc)
  all_positive = jnp.where((boa > 0) & (bob > 0) & (boc > 0), 1, 0)
  complete_mask = all_positive * complete_mask

  def _damping(order):
    if tors_2013:
      return jnp.exp(-2 * force_field.par_24 * order**2)
    return jnp.exp(-force_field.par_24 * order)

  bocor4 = (
    (1.0 - _damping(boa)) * (1.0 - _damping(bob)) * (1.0 - _damping(boc))
  )
  eth = hsin * ethhulp * bocor4
  eth = safe_mask(complete_mask, lambda x: x, eth, 0)
  tors_pot = high_precision_sum(eth).astype(out_dtype)

  # calculate conjugation pot
  def _conj_weight(order):
    dev_sq = (order - 1.50) * (order - 1.50)
    return jnp.exp(-force_field.par_28 * dev_sq)

  sbo = _conj_weight(boa) * _conj_weight(bob) * _conj_weight(boc)
  arghu0 = (arg2 - 1.0) * hsin  # hsin = sinhd*sinhe
  ecoh = (my_vconj * (arghu0 + 1.0)) * sbo
  ecoh = safe_mask(complete_mask, lambda x: x, ecoh, 0)
  conj_pot = high_precision_sum(ecoh).astype(out_dtype)
  return [tors_pot, conj_pot]
def calculate_hb_pot(
  species: Array,
  hbond_inds: Array,
  hbond_angles: Array,
  hbond_mask: Array,
  close_nbr_inds: Array,
  far_nbr_inds: Array,
  bo: Array,
  force_field: ForceField,
  tapered_reaxff: bool = False,
):
  """Compute the total hydrogen-bond energy.

  Args:
    species: Atom type per atom, indexed by global atom index.
    hbond_inds: Integer array with one row per candidate interaction. The
      columns are: donor index (global), local acceptor index in the close
      neighbor list, and local index in the far neighbor list.
    hbond_angles: Row 0 holds the angle (passed to `jnp.sin`, so radians) and
      row 1 the distance of each candidate interaction.
    hbond_mask: Boolean mask selecting the valid rows of `hbond_inds`.
    close_nbr_inds: Close neighbor table mapping `(atom, local slot)` to a
      global index.
    far_nbr_inds: Far neighbor table mapping `(atom, local slot)` to a global
      index.
    bo: Bond orders indexed as `[atom, local close-neighbor slot]`.
    force_field: ReaxFF parameters (`rhb`, `dehb`, `vhb1`, `vhb2` are read,
      indexed by the types of close neighbor, center and far neighbor).
    tapered_reaxff: If True, smoothly switch the interaction on above the
      bond-order threshold and off towards the distance cutoff.

  Returns:
    The summed hydrogen-bond energy, cast to the dtype of the distances.
  """
  hb_cutoff = 7.5  # distance at and beyond which nothing contributes
  bo_cutoff = 0.01  # bond orders at or below this are ignored

  # inds: donor ind, local acceptor ind (close neigh.), local ind_2 (far neigh)
  angles = hbond_angles[0, :]
  dists = hbond_angles[1, :]
  prev_type = dists.dtype

  glb_center = hbond_inds[:, 0]
  lcl_close_nbr = hbond_inds[:, 1]
  glb_close_nbr = close_nbr_inds[glb_center, lcl_close_nbr]
  lcl_far_nbr = hbond_inds[:, 2]
  glb_far_nbr = far_nbr_inds[glb_center, lcl_far_nbr]

  cent_types = species[glb_center]
  close_nbr_types = species[glb_close_nbr]
  far_nbr_types = species[glb_far_nbr]

  my_rhb = force_field.rhb[close_nbr_types, cent_types, far_nbr_types]
  my_dehb = force_field.dehb[close_nbr_types, cent_types, far_nbr_types]
  my_vhb1 = force_field.vhb1[close_nbr_types, cent_types, far_nbr_types]
  my_vhb2 = force_field.vhb2[close_nbr_types, cent_types, far_nbr_types]

  bo = bo.astype(prev_type)
  boa = bo[glb_center, lcl_close_nbr]

  if tapered_reaxff:
    boa_mult = taper_inc(boa, bo_cutoff, 4.0 * bo_cutoff)
    # The distance switch must act on the distance. Feeding the bond order
    # here would leave the multiplier at 1 because bond orders lie far below
    # the lower taper radius of 0.9 * hb_cutoff.
    dist_mult = taper(dists, 0.9 * hb_cutoff, hb_cutoff)
  else:
    boa_mult = 1.0
    dist_mult = 1.0

  # Zeroing sub-threshold bond orders makes (1 - exp(-vhb1 * boa)) vanish for
  # them. In tapered mode this agrees with boa_mult, which is zero there too.
  boa = jnp.where(boa > bo_cutoff, boa, 0.0)
  hbond_mask = hbond_mask & (dists < hb_cutoff) & (dists > 0.0)

  my_rhb = my_rhb + 1e-10
  dists = dists + 1e-10  # to not get divide by zero

  rhu1 = my_rhb / dists
  rhu2 = dists / my_rhb
  exphu1 = jnp.exp(-my_vhb1 * boa)
  exphu2 = jnp.exp(-my_vhb2 * (rhu1 + rhu2 - 2.0))

  ehbh = (
    (
      (1.0 - exphu1)
      * my_dehb
      * exphu2
      * jnp.power(jnp.sin((angles + 1e-10) / 2.0), 4)
    )
    * boa_mult
    * dist_mult
  )
  ehbh = safe_mask(hbond_mask, lambda x: x, ehbh, 0)
  hb_pot = high_precision_sum(ehbh).astype(prev_type)
  return hb_pot
def taper(value, low_tap_rad, up_tap_rad):
  """
  Decreasing switching function built from a 7th-order polynomial.

  Returns 1 for `value` below `low_tap_rad`, 0 for `value` at or above
  `up_tap_rad`, and the polynomial in between.
  """
  # Rebase the interval so that it starts at zero.
  offset = value - low_tap_rad
  hi = up_tap_rad - low_tap_rad
  lo = 0.0

  off2 = offset * offset
  off3 = off2 * offset

  lo2 = lo * lo
  lo3 = lo2 * lo
  hi2 = hi * hi
  hi3 = hi2 * hi

  norm = (hi - lo) ** 7.0

  # General two-sided coefficients. Because `lo` is zero after rebasing,
  # every term carrying a power of `lo` vanishes (c3, c2 and c1 are zero).
  c7 = 20.0
  c6 = -70.0 * (lo + hi)
  c5 = 84.0 * (lo2 + 3.0 * lo * hi + hi2)
  c4 = -35.0 * (lo3 + 9.0 * lo2 * hi + 9.0 * lo * hi2 + hi3)
  c3 = 140.0 * (lo3 * hi + 3.0 * lo2 * hi2 + lo * hi3)
  c2 = -210.0 * (lo3 * hi2 + lo2 * hi3)
  c1 = 140.0 * lo3 * hi3
  c0 = (
    -35.0 * lo3 * hi2 * hi2
    + 21.0 * lo2 * hi3 * hi2
    - 7.0 * lo * hi3 * hi3
    + hi3 * hi3 * hi
  )

  # NOTE(review): the linear term is subtracted while the others are added;
  # c1 is identically zero here, so the sign has no effect on the result.
  poly = (
    c7 * off3 * off3 * offset
    + c6 * off3 * off3
    + c5 * off3 * off2
    + c4 * off2 * off2
    + c3 * off3
    + c2 * off2
    - c1 * offset
    + c0
  ) / norm

  inside_or_beyond = jnp.where(offset < hi, poly, 0.0)
  return jnp.where(offset < lo, 1.0, inside_or_beyond)
def taper_inc(dist, low_tap_rad=0, up_tap_rad=10):
  """
  Increasing switching function, the complement of `taper`.

  Returns 0 for `dist` below `low_tap_rad`, 1 for `dist` at or above
  `up_tap_rad`, and a smooth rise in between.
  """
  falling = taper(dist, low_tap_rad, up_tap_rad)
  return 1 - falling