# Source code for jax_md.dataclasses
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for defining dataclasses that can be used with jax transformations.

This code was copied and adapted from
https://github.com/google/flax/blob/main/flax/struct.py.
Accessed on 04/29/2020.
"""
import dataclasses
from dataclasses import Field as _Field
from dataclasses import asdict as _asdict
from dataclasses import astuple as _astuple
from dataclasses import field as _field
from dataclasses import fields as _fields
from dataclasses import is_dataclass as _is_dataclass
from dataclasses import replace as _replace
from typing import Any, Callable, Optional, TypeVar, overload
import jax
# Public API of this module (the names exported by a star-import).
__all__ = (
    # Defined in this module.
    'dataclass',
    'static_field',
    'unpack',
    # Aliases of the standard-library ``dataclasses`` helpers, bound at the
    # bottom of this module.
    'replace',
    'asdict',
    'astuple',
    'is_dataclass',
    'fields',
    'field',
)
T = TypeVar('T', bound=type[Any])
@overload
def dataclass(clz: T, *, frozen: bool = True, **dataclass_kwargs: Any) -> T: ...
@overload
def dataclass(
*, frozen: bool = True, **dataclass_kwargs: Any
) -> Callable[[T], T]: ...
[docs]
def dataclass(
clz: T | None = None, *, frozen: bool = True, **dataclass_kwargs: Any
) -> T | Callable[[T], T]:
"""Create a class which can be passed to functional transformations.
Jax transformations such as `jax.jit` and `jax.grad` require objects that are
immutable and can be mapped over using the `jax.tree_util` methods.
The `dataclass` decorator makes it easy to define custom classes that can be
passed safely to Jax by relying on `jax.tree_util.register_dataclass`.
Args:
clz: the class that will be transformed by the decorator.
frozen: whether the resulting dataclass should be frozen. Defaults to True.
**dataclass_kwargs: additional keyword arguments forwarded to
`dataclasses.dataclass`.
Returns:
The new class.
"""
if 'frozen' in dataclass_kwargs:
requested_frozen = dataclass_kwargs.pop('frozen')
if requested_frozen != frozen:
raise TypeError(
"'frozen' must match the decorator argument when provided in dataclass_kwargs"
)
def decorate(target_clz: T) -> T:
data_clz = dataclasses.dataclass(frozen=frozen, **dataclass_kwargs)(
target_clz
)
registered_clz = jax.tree_util.register_dataclass(data_clz)
def _set(self, **kwargs):
return _replace(self, **kwargs)
setattr(registered_clz, 'set', _set)
return registered_clz
if clz is None:
return decorate
return decorate(clz)
[docs]
def static_field(
*, metadata: dict[str, Any] | None = None, **field_kwargs: Any
) -> _Field[Any]:
"""Create a field that is treated as static (non-pytree) by JAX."""
combined_metadata = dict(metadata or {})
combined_metadata.setdefault('static', True)
combined_metadata['pytree_node'] = False
return _field(metadata=combined_metadata, **field_kwargs)
def unpack(dc: Any) -> tuple[Any, ...]:
  """Return the values of ``dc``'s dataclass fields as a tuple, in field order.

  Unlike :func:`dataclasses.astuple`, nested dataclasses and containers are
  returned as-is rather than recursively converted, and each value is read
  with ``getattr`` so any custom attribute access on the class is honored.
  """
  names = [fld.name for fld in _fields(dc)]
  return tuple([getattr(dc, name) for name in names])
# Re-export the standard-library dataclass helpers under their usual public
# names, so callers can get them from this module instead of ``dataclasses``.
(replace, asdict, astuple, is_dataclass, fields, field) = (
    _replace,
    _asdict,
    _astuple,
    _is_dataclass,
    _fields,
    _field,
)