jax_md.dataclasses module
Utilities for defining dataclasses that can be used with JAX transformations.
This code was copied and adapted from google/flax.
Accessed on 04/29/2020.
jax_md.dataclasses.dataclass(clz=None, *, frozen=True, **dataclass_kwargs)
- Overloads:
clz (T), frozen (bool), dataclass_kwargs (Any) → T
frozen (bool), dataclass_kwargs (Any) → Callable[[T], T]
Create a class which can be passed to functional transformations.
JAX transformations such as jax.jit and jax.grad require arguments that are
immutable and that can be flattened and rebuilt by the jax.tree_util functions
(that is, pytrees). The dataclass decorator makes it easy to define custom
classes that can be passed safely to JAX, by registering them with
jax.tree_util.register_dataclass.
- Parameters:
clz (Optional[TypeVar(T, bound=type[Any])]) – the class that will be transformed by the decorator.
frozen (bool) – whether the resulting dataclass should be frozen. Defaults to True.
**dataclass_kwargs (Any) – additional keyword arguments forwarded to
dataclasses.dataclass.
- Returns:
The new class, or, when clz is omitted, a decorator that produces it.
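The decorator's behavior can be illustrated with a minimal sketch built directly on jax.tree_util.register_dataclass. This is not the jax_md implementation: `tree_dataclass` is a hypothetical name, and marking static fields with a `"static"` metadata key is an assumption made for the sketch. It supports both overload forms (bare `@tree_dataclass` and `@tree_dataclass(frozen=False)`) and shows a frozen instance passing through jax.jit and jax.grad.

```python
import dataclasses
from typing import Any, Optional

import jax
import jax.numpy as jnp


def tree_dataclass(clz: Optional[type] = None, *, frozen: bool = True, **dataclass_kwargs: Any):
  """Hypothetical sketch: a dataclass decorator that registers the class as a pytree."""

  def wrap(cls: type) -> type:
    # Build an ordinary dataclass, frozen by default.
    cls = dataclasses.dataclass(frozen=frozen, **dataclass_kwargs)(cls)
    fields = dataclasses.fields(cls)
    # Assumed convention: fields tagged "static" become treedef metadata,
    # every other field becomes a pytree child (arrays end up as leaves).
    meta = [f.name for f in fields if f.metadata.get("static", False)]
    data = [f.name for f in fields if not f.metadata.get("static", False)]
    jax.tree_util.register_dataclass(cls, data_fields=data, meta_fields=meta)
    return cls

  # Mirror the two overloads: @tree_dataclass and @tree_dataclass(frozen=False).
  return wrap if clz is None else wrap(clz)


@tree_dataclass
class State:
  position: jnp.ndarray
  velocity: jnp.ndarray


@jax.jit
def drift(state: State, dt: float) -> State:
  # Frozen instances are "updated" by constructing a new instance.
  return dataclasses.replace(state, position=state.position + dt * state.velocity)


def kinetic(state: State) -> jnp.ndarray:
  return 0.5 * jnp.sum(state.velocity ** 2)


s = State(position=jnp.zeros(3), velocity=jnp.ones(3))
s2 = drift(s, 0.1)        # a new State; position is about [0.1, 0.1, 0.1]
g = jax.grad(kinetic)(s)  # the gradient is itself a State


@tree_dataclass(frozen=False)
class Scratch:
  buffer: jnp.ndarray


sc = Scratch(buffer=jnp.zeros(2))
sc.buffer = jnp.ones(2)  # allowed because frozen=False
```

Because the class is frozen by default, updates go through dataclasses.replace, and jax.grad returns its result with the same State structure as the input.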
jax_md.dataclasses.static_field(*, metadata=None, **field_kwargs)
Create a field that is treated as static (non-pytree) by JAX.
The return type is annotated as Any (as in typeshed’s stub for dataclasses.field)
so that the sentinel Field object type-checks against any annotation in a class body.
- Return type:
Any
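To illustrate what "static" means in practice, the sketch below uses a hypothetical `static_field_sketch` helper that tags a field with an assumed `"static"` metadata key and then registers that field as treedef metadata through jax.tree_util.register_dataclass. It is not jax_md's code. A static field is not a leaf, so inside jax.jit it arrives as a plain Python value, and changing it forces a retrace while changing leaf values does not.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


def static_field_sketch(**field_kwargs: Any) -> Any:
  """Hypothetical helper: mark a dataclass field as static via its metadata."""
  return dataclasses.field(metadata={"static": True}, **field_kwargs)


@dataclasses.dataclass(frozen=True)
class Params:
  sigma: jnp.ndarray
  repeats: int = static_field_sketch(default=1)


# Split fields by the assumed "static" marker and register Params as a pytree.
_fields = dataclasses.fields(Params)
jax.tree_util.register_dataclass(
    Params,
    data_fields=[f.name for f in _fields if not f.metadata.get("static", False)],
    meta_fields=[f.name for f in _fields if f.metadata.get("static", False)],
)

traces = []


@jax.jit
def tiled(p: Params) -> jnp.ndarray:
  traces.append(p.repeats)  # Python side effect: runs only while tracing
  # The static field is a plain int here, so it can determine an output shape.
  return jnp.tile(p.sigma, p.repeats)


a = tiled(Params(jnp.array([1.0, 2.0])))             # first call: traces
b = tiled(Params(jnp.array([3.0, 4.0])))             # new leaf values: cached
c = tiled(Params(jnp.array([1.0, 2.0]), repeats=2))  # new static value: retraces
```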
jax_md.dataclasses.unpack(dc)
Return a tuple of dataclass attribute values.
This is a lightweight alternative to dataclasses.astuple(): it does not recurse
into nested dataclasses or containers, and it respects custom attribute access
defined on the dataclass.
- Return type:
tuple[Any, ...]
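The difference from dataclasses.astuple() can be sketched with the standard library alone. `unpack_sketch` is a hypothetical shallow version that reads each field with getattr in declaration order; it is meant to mirror the behavior described above, not to reproduce jax_md's source.

```python
import dataclasses
from typing import Any


def unpack_sketch(dc: Any) -> tuple:
  """Hypothetical shallow unpack: one getattr per field, no recursion or copying."""
  return tuple(getattr(dc, f.name) for f in dataclasses.fields(dc))


@dataclasses.dataclass(frozen=True)
class Inner:
  a: int


@dataclasses.dataclass(frozen=True)
class Outer:
  inner: Inner
  b: list


o = Outer(inner=Inner(1), b=[1, 2])

shallow = unpack_sketch(o)     # (Inner(a=1), [1, 2]): nested values kept as-is
deep = dataclasses.astuple(o)  # ((1,), [1, 2]): astuple recurses and copies
```

Shallow unpacking suits destructuring, e.g. `inner, b = unpack_sketch(o)`, without the cost of a recursive copy.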