jax_md.dataclasses module

Utilities for defining dataclasses that can be used with jax transformations.

This code was copied and adapted from google/flax.

Accessed on 04/29/2020.

jax_md.dataclasses.dataclass(clz=None, *, frozen=True, **dataclass_kwargs)
Overloads:
  • clz (T), frozen (bool), dataclass_kwargs (Any) → T

  • frozen (bool), dataclass_kwargs (Any) → Callable[[T], T]

Create a class which can be passed to functional transformations.

JAX transformations such as jax.jit and jax.grad require arguments that are immutable and that can be flattened and mapped over with the jax.tree_util functions (that is, pytrees).

The dataclass decorator makes it easy to define custom classes that can be passed safely to JAX by registering them with jax.tree_util.register_dataclass.
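The sketch below shows roughly the kind of registration the decorator automates, written directly against jax.tree_util.register_dataclass. The Spring class and energy function are hypothetical, and the decorator's actual implementation may differ in detail. Array fields are declared as pytree data, while a static field is stored in the tree structure.

```python
import dataclasses

import jax
import jax.numpy as jnp


# A plain frozen dataclass. Without registration, JAX would treat each
# instance as a single opaque leaf (and jax.jit would reject it).
@dataclasses.dataclass(frozen=True)
class Spring:
    position: jnp.ndarray  # array data that JAX should trace
    k: float               # spring constant, kept as static metadata


# Register Spring as a pytree: `position` becomes a leaf, while `k` is
# stored in the treedef (so changing it triggers recompilation under jit).
jax.tree_util.register_dataclass(Spring, data_fields=["position"], meta_fields=["k"])


@jax.jit
def energy(s: Spring) -> jnp.ndarray:
    # Inside jit, s.position is a tracer but s.k is a concrete Python float.
    return 0.5 * s.k * jnp.sum(s.position ** 2)


s = Spring(position=jnp.array([1.0, 2.0]), k=2.0)
print(jax.tree_util.tree_leaves(s))  # a single leaf: the position array
print(energy(s))                     # 0.5 * 2.0 * (1 + 4) = 5.0
```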

Parameters:
  • clz (Optional[TypeVar(T, bound=type[Any])]) – the class that will be transformed by the decorator.

  • frozen (bool) – whether the resulting dataclass should be frozen. Defaults to True.

  • **dataclass_kwargs (Any) – additional keyword arguments forwarded to dataclasses.dataclass.

Returns:

The new class.

jax_md.dataclasses.static_field(*, metadata=None, **field_kwargs)

Create a field that is treated as static (non-pytree) by JAX.

The return type is annotated as Any (as typeshed does for dataclasses.field) so that the sentinel Field object type-checks against any annotation in a class body.

Return type:

Any

jax_md.dataclasses.unpack(dc)

Return a tuple of dataclass attribute values.

This is a lightweight alternative to dataclasses.astuple() that does not recurse into nested dataclasses and respects custom attribute access defined on the dataclass.

Return type:

tuple[Any, ...]