UMA (Universal Models for Atoms) — JAX-MD#
End-to-end walkthrough of the UMA model: pretrained inference, dataset routing, batched evaluation, molecular dynamics, and structure relaxation.
pip install jax jaxlib flax torch huggingface_hub
huggingface-cli login # accept license at https://huggingface.co/facebook/UMA
[1]:
import os
import jax
import jax.numpy as jnp
from jax import jit, random
import numpy as np
from jax_md._nn.uma import load_pretrained, UMAMoEBackbone
from jax_md._nn.uma.nn.embedding import dataset_names_to_indices
from jax_md._nn.uma.heads import MLPEnergyHead
# Documentation-build switch: true when the READTHEDOCS environment variable
# holds one of the accepted truthy spellings (compared case-insensitively).
ON_READTHEDOCS = os.environ.get('READTHEDOCS', '').lower() in (
    '1', 'true', 'yes',
)

# Every cell that needs the gated checkpoint is guarded by this flag.
RUN_PRETRAINED_UMA = not ON_READTHEDOCS

if ON_READTHEDOCS:
    print(
        'Skipping pretrained UMA execution on Read the Docs; '
        'the checkpoint is gated and requires HuggingFace authentication.'
    )
Skipping pretrained UMA execution on Read the Docs; the checkpoint is gated and requires HuggingFace authentication.
1. Load pretrained model#
[2]:
if RUN_PRETRAINED_UMA:
    # The loader result is unpacked into the backbone config, the backbone
    # parameters and the energy-head parameters.
    config, params, head_params = load_pretrained('uma-s-1p2')

    # Backbone and energy head are both sized from the checkpoint config.
    model = UMAMoEBackbone(config=config)
    head = MLPEnergyHead(
        hidden_channels=config.hidden_channels,
        sphere_channels=config.sphere_channels,
    )

    # Compiled backbone forward pass, reused by the inference cells below.
    predict = jit(model.apply)

    print(
        f"Model: {config.num_layers} layers, {config.sphere_channels} channels, "
        f"{config.num_experts} experts",
        f"Datasets: {config.dataset_list}",
        sep='\n',
    )
def build_edges(pos_np, cutoff):
    """Return all directed atom pairs closer than ``cutoff``.

    Intended for small non-periodic systems: the full N x N distance matrix
    is materialised and no periodic images are considered.

    Args:
        pos_np: Array-like of shape (N, D), one row of coordinates per atom.
        cutoff: Distance threshold. A pair is kept when its separation is
            strictly smaller than this value.

    Returns:
        ``np.int32`` array of shape (2, E). Row 0 holds the source index
        ``j`` and row 1 the destination index ``i`` of each edge, ordered by
        ``i`` and then by ``j``. Self-pairs are never included; an empty
        input yields an array of shape (2, 0).
    """
    pos = np.asarray(pos_np)
    if pos.size == 0:
        return np.zeros((2, 0), dtype=np.int32)

    # Pairwise separations in one vectorised pass instead of a Python
    # double loop calling the norm once per pair.
    deltas = pos[:, None, :] - pos[None, :, :]
    within = np.linalg.norm(deltas, axis=-1) < cutoff
    np.fill_diagonal(within, False)  # drop i == j

    # np.nonzero walks the mask row-major, i.e. by i then j, which is the
    # order the nested loops would have produced.
    dst, src = np.nonzero(within)
    return np.stack([src, dst]).astype(np.int32)
2. Single-system inference (energy + forces)#
[3]:
if RUN_PRETRAINED_UMA:
    # Four Cu atoms placed on the sites of one FCC cell with edge length a.
    a = 3.615
    cu_pos = np.array([
        [0, 0, 0], [a/2, a/2, 0], [a/2, 0, a/2], [0, a/2, a/2],
    ], dtype=np.float32)
    cu_Z = jnp.full((4,), 29, dtype=jnp.int32)
    cu_batch = jnp.zeros_like(cu_Z)  # all atoms belong to system 0
    cu_ds = dataset_names_to_indices(['omat'], config.dataset_list)

    # Edges come from the NumPy positions; afterwards cu_pos is rebound to a
    # JAX array so it can be differentiated.
    cu_ei = jnp.array(build_edges(cu_pos, config.cutoff))
    cu_pos = jnp.array(cu_pos)
    cu_ev = cu_pos[cu_ei[0]] - cu_pos[cu_ei[1]]

    # One charge and one spin entry, because there is a single system.
    charge0 = jnp.zeros(1, dtype=jnp.int32)
    spin0 = jnp.zeros(1, dtype=jnp.int32)

    emb = predict(
        params, cu_pos, cu_Z, cu_batch, cu_ei, cu_ev, charge0, spin0, cu_ds,
    )
    result = head.apply(head_params, emb['node_embedding'], cu_batch, 1)
    print(f"Cu FCC energy: {float(result['energy'][0]):.6f} eV")

    def cu_energy(pos):
        """Return the scalar head energy of the Cu system at ``pos``.

        Edge vectors are recomputed from ``pos`` so that the result is
        differentiable with respect to the positions.
        """
        src, dst = cu_ei[0], cu_ei[1]
        edge_vec = pos[src] - pos[dst]
        embedding = model.apply(
            params, pos, cu_Z, cu_batch, cu_ei, edge_vec,
            charge0, spin0, cu_ds,
        )
        head_out = head.apply(
            head_params, embedding['node_embedding'], cu_batch, 1,
        )
        return head_out['energy'][0]

    # Forces are the negative gradient of the energy w.r.t. positions.
    energy_grad = jax.grad(cu_energy)
    forces = -energy_grad(cu_pos)
    print(f"Max |force|: {float(jnp.max(jnp.abs(forces))):.6f} eV/A")
3. Dataset routing — same atoms, different DFT levels#
[4]:
if RUN_PRETRAINED_UMA:
    print("Cu FCC across datasets:")
    # Same atoms and edges each time; only the dataset index passed to the
    # backbone changes.
    for ds_name in config.dataset_list:
        ds = dataset_names_to_indices([ds_name], config.dataset_list)
        out = predict(
            params, cu_pos, cu_Z, cu_batch, cu_ei, cu_ev, charge0, spin0, ds,
        )
        # Mean over the index-0 slice of the second embedding axis, which
        # the printed label calls l=0.
        l0 = float(jnp.mean(out['node_embedding'][:, 0, :]))
        print(f"{ds_name:5s}: l=0 mean: {l0:+.6f}")
4. Batched inference — multiple systems in one call#
[5]:
if RUN_PRETRAINED_UMA:
    h2o_pos = np.array([[0, 0, .12], [0, .76, -.47], [0, -.76, -.47]],
                       dtype=np.float32)
    h2o_Z = [8, 1, 1]

    # Sizes are derived from the data so that species, batch ids and the
    # edge offset cannot drift apart from the positions.
    n_cu, n_h2o = len(cu_pos), len(h2o_pos)

    # Concatenate both systems: Cu atoms first, then H2O.
    all_pos_np = np.concatenate([np.asarray(cu_pos), h2o_pos])
    all_pos = jnp.array(all_pos_np)
    all_Z = jnp.concatenate([cu_Z, jnp.array(h2o_Z, dtype=jnp.int32)])
    batch = jnp.array([0] * n_cu + [1] * n_h2o, dtype=jnp.int32)

    # Edges are built per system (no cross-system pairs); H2O indices are
    # shifted past the Cu atoms to address the concatenated arrays.
    ei_cu = build_edges(np.asarray(cu_pos), config.cutoff)
    ei_h2o = build_edges(h2o_pos, config.cutoff) + n_cu
    ei = jnp.array(np.concatenate([ei_cu, ei_h2o], axis=1))
    ev = all_pos[ei[0]] - all_pos[ei[1]]

    # One dataset, charge and spin entry per system. The omat entry keeps
    # spin 0 as in the single-system Cu call; the omol entry uses spin 1,
    # the value the charge/spin section passes for neutral H2O.
    ds_idx = dataset_names_to_indices(['omat', 'omol'], config.dataset_list)
    charges = jnp.array([0, 0], dtype=jnp.int32)
    spins = jnp.array([0, 1], dtype=jnp.int32)

    out = predict(params, all_pos, all_Z, batch, ei, ev,
                  charges, spins, ds_idx)
    result = head.apply(head_params, out['node_embedding'], batch, 2)
    print(f"Batched energies: Cu4={float(result['energy'][0]):.4f}, "
          f"H2O={float(result['energy'][1]):.4f} eV")
5. Charge and spin (omol task)#
[6]:
if RUN_PRETRAINED_UMA:
    # Geometry, edges, species, batch ids and the task index are identical
    # for every charge/spin variant, so they are built once outside the loop.
    h2o_jnp = jnp.array(h2o_pos)
    h2o_ei = jnp.array(build_edges(h2o_pos, config.cutoff))
    h2o_ev = h2o_jnp[h2o_ei[0]] - h2o_jnp[h2o_ei[1]]
    h2o_species = jnp.array([8, 1, 1], dtype=jnp.int32)
    h2o_batch = jnp.zeros(3, dtype=jnp.int32)
    omol = dataset_names_to_indices(['omol'], config.dataset_list)

    print("H2O charge/spin variants:")
    # Only the total charge q and the spin value s differ between calls.
    for label, q, s in [('neutral', 0, 1), ('cation', 1, 2), ('anion', -1, 2)]:
        out = predict(params, h2o_jnp, h2o_species, h2o_batch, h2o_ei, h2o_ev,
                      jnp.array([q], dtype=jnp.int32),
                      jnp.array([s], dtype=jnp.int32), omol)
        l0 = float(out['node_embedding'][:, 0, :].mean())
        print(f"{label:8s} (q={q:+d}, s={s}): l=0 mean: {l0:+.6f}")
6. Molecular dynamics (NVE, periodic Si)#
[7]:
from jax_md import space, energy, simulate, quantity
if RUN_PRETRAINED_UMA:
    # Eight Si atoms: fractional coordinates scaled by the cubic cell edge.
    a_si = 5.43
    si_basis = jnp.array([
        [0, 0, 0], [.5, .5, 0], [.5, 0, .5], [0, .5, .5],
        [.25, .25, .25], [.75, .75, .25], [.75, .25, .75], [.25, .75, .75],
    ]) * a_si
    si_Z = jnp.array([14] * 8, dtype=jnp.int32)

    displacement_fn, shift_fn = space.periodic(a_si)
    neighbor_fn, init_fn, energy_fn = energy.uma_neighbor_list(
        displacement_fn, a_si, checkpoint_path='uma-s-1p2', atoms=si_Z,
    )
    key = random.PRNGKey(0)
    nbrs = neighbor_fn.allocate(si_basis)
    md_params = init_fn(key, si_basis, nbrs)

    def nve_energy(R, neighbor, **kw):
        """Potential energy at positions ``R`` with the parameters bound.

        Extra keyword arguments forwarded by the integrator are accepted
        and ignored.
        """
        return energy_fn(md_params, R, neighbor)

    init_nve, apply_nve = simulate.nve(nve_energy, shift_fn, dt=0.001)
    apply_nve = jit(apply_nve)
    key, subkey = random.split(key)
    state = init_nve(subkey, si_basis, kT=0.1, neighbor=nbrs)

    print(f"NVE on {len(si_Z)} Si atoms:")
    for step in range(50):
        nbrs = nbrs.update(state.position)
        # An overflowed neighbor list silently drops pairs; rebuild it with
        # a fresh allocation before it is used for the next step.
        if nbrs.did_buffer_overflow:
            nbrs = neighbor_fn.allocate(state.position)
        state = apply_nve(state, neighbor=nbrs)
        if step % 25 == 0:
            # kinetic_energy must be told explicitly (by keyword) whether
            # it is given momenta or velocities.
            KE = float(quantity.kinetic_energy(
                momentum=state.momentum, mass=state.mass,
            ))
            PE = float(nve_energy(state.position, nbrs))
            print(f"step {step:3d}: KE: {KE:.6f} PE: {PE:.6f} total: {KE+PE:.6f}")
    # For NVT, use simulate.nvt_nose_hoover(nve_energy, shift_fn, dt, kT, tau)
7. Structure relaxation (FIRE)#
[8]:
from jax_md import minimize
from jax_md._nn.uma.model import UMAConfig
if RUN_PRETRAINED_UMA:
    # A small model described by an explicit config (no checkpoint path is
    # given); its parameters are produced by init_fn_relax below.
    cfg = UMAConfig(
        sphere_channels=32, lmax=2, mmax=2, num_layers=1,
        hidden_channels=32, cutoff=5.0, edge_channels=32,
        num_distance_basis=64, use_dataset_embedding=False,
    )
    # NOTE(review): cutoff (5.0) exceeds half the box edge (5.43 / 2);
    # confirm the neighbor list treats periodic images as intended.
    disp_fn, shift_fn_relax = space.periodic(a_si)
    nbr_fn, init_fn_relax, e_fn = energy.uma_neighbor_list(
        disp_fn, a_si, cfg=cfg, atoms=si_Z,
    )

    # Independent keys for the position noise and the parameter init, so
    # that the two random draws are not correlated through a shared key.
    key = random.PRNGKey(42)
    key, noise_key, init_key = random.split(key, 3)
    perturbed = si_basis + random.normal(noise_key, si_basis.shape) * 0.1
    nbrs_relax = nbr_fn.allocate(perturbed)
    relax_params = init_fn_relax(init_key, perturbed, nbrs_relax)

    def relax_energy(R, **kw):
        """Energy at positions ``R``; the neighbor list is refreshed for
        ``R`` on every call. Extra keyword arguments are ignored."""
        return e_fn(relax_params, R, nbrs_relax.update(R))

    fire_init, fire_apply = minimize.fire_descent(
        relax_energy, shift_fn_relax, dt_start=0.1, dt_max=0.4,
    )
    fire_apply = jit(fire_apply)
    # Built and compiled once; the loop below only evaluates it.
    force_fn = jit(lambda R: -jax.grad(relax_energy)(R))
    fire_state = fire_init(perturbed)

    print(f"FIRE relaxation ({len(si_Z)} Si atoms):")
    for step in range(100):
        fire_state = fire_apply(fire_state)
        # Convergence is only monitored every 20th step.
        if step % 20 == 0:
            F = force_fn(fire_state.position)
            fmax = float(jnp.max(jnp.abs(F)))
            print(f"step {step:3d}: fmax: {fmax:.6f}")
            if fmax < 0.01:
                print(f"Converged at step {step}")
                break
8. Checkpoint conversion (requires torch)#
load_pretrained (section 1) performs these steps automatically; this section walks through them individually: download the checkpoint, inspect it, convert it to JAX parameters, and save the result as NumPy arrays so it can be reloaded without torch.
[9]:
if not RUN_PRETRAINED_UMA:
    print("Skipping checkpoint conversion demo on Read the Docs")
else:
    # Best-effort demo: any failure is reported and the notebook continues.
    try:
        from jax_md._nn.uma.pretrained import (
            download_pretrained, convert_checkpoint, print_conversion_report,
            load_checkpoint_raw, extract_config, PRETRAINED_MODELS,
        )
        print("Available pretrained models:")
        for name, info in PRETRAINED_MODELS.items():
            print(f" {name}: {info['description']}")

        # Download from HuggingFace
        ckpt_path = download_pretrained('uma-s-1p2')

        # Inspect raw checkpoint
        ckpt = load_checkpoint_raw(ckpt_path)
        raw_cfg = extract_config(ckpt)
        print(f"Checkpoint config: {raw_cfg.num_layers} layers, "
              f"{raw_cfg.sphere_channels} ch, lmax={raw_cfg.lmax}")
        print(f"State dict: {len(ckpt.model_state_dict)} parameters")

        # Convert to JAX params (preserves all MoE experts)
        cfg_conv, jax_params, metadata = convert_checkpoint(ckpt_path, use_ema=True)
        print_conversion_report(metadata)

        # Save as numpy for torch-free loading
        save_dir = os.path.expanduser('~/.cache/fairchem/uma_jax')
        os.makedirs(save_dir, exist_ok=True)
        flat = jax.tree.leaves_with_path(jax_params)
        # Each leaf is stored under its tree path joined with '/': dict
        # entries contribute their key, sequence entries their index.
        param_dict = {
            '/'.join(
                str(p.key) if hasattr(p, 'key') else str(p.idx) for p in path
            ): np.array(v)
            for path, v in flat
        }
        npz_path = os.path.join(save_dir, 'uma-s-1p2_jax.npz')
        np.savez_compressed(npz_path, **param_dict)
        print(f"Saved {len(param_dict)} arrays to {npz_path}")

        # Reload without torch; the context manager closes the archive's
        # file handle once the member list has been read.
        with np.load(npz_path) as loaded:
            n_loaded = len(loaded.files)
        print(f"Reloaded {n_loaded} arrays (no torch needed)")
    except ImportError:
        print("torch not installed — skipping checkpoint conversion demo")
    except Exception as e:
        print(f"Checkpoint conversion skipped: {e}")
Skipping checkpoint conversion demo on Read the Docs