junrushao created an issue (apache/tvm-ffi#553)
## Summary
This RFC is heavily inspired by `tvm::Op` and derives its design from it.
Add first-class enum support to TVM-FFI. An enum is a registered Object type
whose instances are **named, frozen singletons** — the same model as `tvm::Op`,
but generalized as a decorator that integrates with `@py_class` / `@c_class`.
Two forms are supported:
- **Attribute-carrying enum** (primary): fields declared as annotations,
entries declared via `ClassVar[Self] = entry(kwarg=val)`
- **Simple integer enum**: `ClassVar[Self] = entry(int)` with an implicit
`value: int` field
Both forms support `tvm::Op`-style extensible per-variant attributes
(`def_attr`), cross-language access by string name (`Cls.get("name")` /
`Cls::Get("name")`), and distributed C++ entry registration across files.
## Motivation
TVM-FFI provides `@py_class` and `@c_class` for defining cross-language
dataclasses, but has no equivalent for enumerations. Today, enums like
`DataTypeCode` and `DLDeviceType` are manually defined as Python `IntEnum`
subclasses with no FFI registration — invisible to other languages and to the
reflection system.
TVM's `tvm::Op` demonstrates the pattern we want: a string-keyed singleton
registry where each entry is an Object, extensible attributes can be attached
from anywhere, and lookup is by name (`Op::Get("nn.relu")`). This RFC
generalizes that pattern into a first-class enum abstraction.
## Design
### Core Model
An enum is a registered Object type whose instances are named, frozen
singletons:
```
┌─────────────────────────────────────────────┐
│ Activation (type_key = "nn.Activation")     │
│ ─────────────────────────────────────────── │
│ fields: approximate: bool, inplace: bool    │ ← same as @py_class
│                                             │
│ entries (singleton instances):              │
│   "ReLU" → Activation(False, False)         │
│   "GeLU" → Activation(True, False)          │
│   "SiLU" → Activation(False, True)          │
│                                             │
│ ext attrs (def_attr columns):               │
│   "f_compute" → {"ReLU": fn, "GeLU": fn}    │
│   "f_gradient" → {"ReLU": fn}               │
└─────────────────────────────────────────────┘
```
- Each entry is a frozen `ObjectRef`, stored as a **TypeAttr** on the enum's
`type_index`.
- Compared by identity (`is` / pointer equality).
- Crossed through FFI as `ObjectRef` — no special wire encoding.
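The model above can be approximated in plain Python, without TVM-FFI, to make the "named, frozen singleton" idea concrete. This is only a sketch: the `register` and `get` helpers and the module-level `_ENTRIES` dict are hypothetical stand-ins for the TypeAttr-backed registry, not part of the proposed API.

```python
from __future__ import annotations

from dataclasses import dataclass


# Hypothetical stand-in for an enum Object type; eq=False keeps identity comparison.
@dataclass(frozen=True, eq=False)
class Activation:
    name: str
    approximate: bool = False
    inplace: bool = False


# Hypothetical registry: entry name -> singleton instance.
_ENTRIES: dict[str, Activation] = {}


def register(name: str, **fields: bool) -> Activation:
    """Create the singleton for `name` exactly once and record it."""
    if name in _ENTRIES:
        raise ValueError(f"duplicate entry {name!r}")
    _ENTRIES[name] = Activation(name, **fields)
    return _ENTRIES[name]


def get(name: str) -> Activation:
    """Look up an entry by its string name."""
    return _ENTRIES[name]


relu = register("ReLU")
gelu = register("GeLU", approximate=True)
```

Because lookups return the stored instance, `get("ReLU") is relu` holds, and a second instance built with the same field values is a different, unequal object.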
### Python API: `@py_enum`
#### Attribute-carrying enum (primary form)
```python
from __future__ import annotations
from typing import ClassVar
from tvm_ffi.dataclasses import py_enum, entry
@py_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool = False
    inplace: bool = False

    ReLU: ClassVar[Activation] = entry()
    GeLU: ClassVar[Activation] = entry(approximate=True)
    SiLU: ClassVar[Activation] = entry(inplace=True)
```
Like `@py_class`, the decorator accepts an optional `type_key` as the first
positional argument, or auto-generates it from `{module}.{qualname}` when used
bare:
```python
@py_enum # type_key = "{module}.Activation"
class Activation(EnumObject): ...
@py_enum("nn.Activation") # explicit type_key
class Activation(EnumObject): ...
@py_enum("nn.Activation", eq=True) # with py_class options
class Activation(EnumObject): ...
```
Fields are declared via annotations, exactly like `@py_class`. Entries are
declared with a `ClassVar[Self]` annotation and an `entry()` call whose keyword
arguments match the field names. Each entry becomes a class attribute and is
also reachable by name:
```python
Activation.ReLU # frozen singleton (Activation instance)
Activation.ReLU.approximate # False — direct field access
Activation.GeLU.approximate # True
Activation.get("ReLU") # same object as Activation.ReLU
Activation.ReLU is Activation.get("ReLU") # True
```
#### Simple integer enum
When every entry passes a single positional `int` to `entry()`, `@py_enum`
creates a single implicit `value: int` field:
```python
@py_enum("nn.Precision")
class Precision(EnumObject):
    FP32: ClassVar[Precision] = entry(0)
    FP16: ClassVar[Precision] = entry(1)
    BF16: ClassVar[Precision] = entry(2)
```
This is sugar for:
```python
@py_enum("nn.Precision")
class Precision(EnumObject):
    value: int

    FP32: ClassVar[Precision] = entry(value=0)
    FP16: ClassVar[Precision] = entry(value=1)
    BF16: ClassVar[Precision] = entry(value=2)
```
```python
Precision.FP32.value # 0
Precision.get("FP16") # Precision.FP16 singleton
```
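One way to picture the desugaring is a small normalization step that maps the positional form onto the keyword form before the instance is constructed. The helper below is a hypothetical illustration (the name `normalize_entry` and its error message are assumptions); the RFC does not specify how the decorator implements this.

```python
from __future__ import annotations

from typing import Any


def normalize_entry(args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
    """Map entry(0) to {"value": 0}; pass keyword-only entries through unchanged."""
    if args:
        # Assumption: the positional form takes exactly one int and no keywords.
        if kwargs or len(args) != 1 or not isinstance(args[0], int):
            raise TypeError("positional entry() takes exactly one int")
        return {"value": args[0]}
    return dict(kwargs)


simple = normalize_entry((1,), {})                      # entry(1)
attributed = normalize_entry((), {"approximate": True})  # entry(approximate=True)
```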
#### Type checker compatibility
The `ClassVar[Self]` pattern works natively with mypy, pyright, and ty — no
stubs, no plugins:
1. **`from __future__ import annotations`** (already required by TVM-FFI) makes
`Activation` in `ClassVar[Activation]` a string — the forward self-reference
resolves.
2. **`entry()` returns `Any`** — `Any` is assignable to `ClassVar[Activation]`,
so no type mismatch.
3. **Type checkers see `ReLU: ClassVar[Activation]`** — so `Activation.ReLU`
has type `Activation`.
4. **`@py_class` already skips `ClassVar` annotations** — entries are not
registered as fields.
```python
Activation.ReLU # Activation ✓
Activation.ReLU.approximate # bool ✓
Activation.get("ReLU") # Activation (via Self) ✓
Activation.entries() # dict[str, Activation] ✓
```
The `get()`, `entries()`, and `def_attr()` methods are typed on the
`EnumObject` base class using `Self`:
```python
class EnumObject(Object):
    """Base for @py_enum and @c_enum types."""

    @classmethod
    def get(cls, name: str) -> Self: ...

    @classmethod
    def entries(cls) -> dict[str, Self]: ...

    @classmethod
    def def_attr(cls, name: str) -> EnumAttrMap[Self]: ...
```
#### The `entry()` sentinel
`entry()` returns a lightweight sentinel that `@py_enum` processes after the
class body executes:
```python
from typing import Any


class _EnumEntry:
    """Sentinel returned by entry(). Processed by @py_enum."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.args = args      # positional: entry(0) for simple int
        self.kwargs = kwargs  # keyword: entry(approximate=True)


def entry(*args: Any, **kwargs: Any) -> Any:
    return _EnumEntry(*args, **kwargs)
```
`@py_enum` scans the class `__dict__` for `_EnumEntry` instances, separates
them from field annotations, and creates the singleton instances after the
underlying `@py_class` registration finishes.
#### How `@py_enum` works internally
```
@py_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool = False             ← field annotations
    inplace: bool = False
    ReLU: ClassVar[Activation] = entry()  ← entry (ClassVar, skipped by @py_class)
    GeLU: ClassVar[Activation] = entry(approximate=True)
```
1. **Separate fields from entries.** Scan class `__dict__` for `_EnumEntry`
instances. `ClassVar` annotations are already skipped by `@py_class`, so
entries don't interfere with field registration.
2. **Register as `@py_class`.** Delegate to `@py_class(frozen=True)` for type
registration, field registration, and dunder installation. The enum class gets
a `type_key`, `type_index`, and reflection metadata like any other `@py_class`.
If all entries are positional-int (`entry(0)`), inject an implicit `value: int`
field first.
3. **Create singleton instances.** For each entry, instantiate the class with
the specified kwargs (filling in defaults from field declarations), then freeze
the instance.
4. **Register entries via TypeAttr.** Store the full `name -> singleton`
mapping as a TypeAttr on the enum's `type_index`:
```python
core._type_register_attr(type_index, "__enum_entries__", entries_map)
```
5. **Set class attributes.** For each `(name, singleton)` pair:
- `cls.ReLU = singleton`
- `singleton.__entry_name__ = "ReLU"`
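The five steps can be mimicked with the standard library to see how they fit together. In this sketch `dataclasses.dataclass(frozen=True)` stands in for `@py_class(frozen=True)`, and a plain class attribute stands in for the TypeAttr write; `toy_enum` and `_Entry` are hypothetical names, and only the keyword form of `entry()` is modeled. The forward reference is quoted so the sketch does not depend on `from __future__ import annotations`.

```python
from dataclasses import dataclass
from typing import Any, ClassVar


class _Entry:
    """Sentinel mirroring the RFC's _EnumEntry (keyword form only)."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


def entry(**kwargs: Any) -> Any:
    return _Entry(**kwargs)


def toy_enum(cls: type) -> type:
    # Step 1: collect the sentinels found in the class body.
    pending = {k: v for k, v in vars(cls).items() if isinstance(v, _Entry)}
    # Step 2: stand-in for @py_class(frozen=True); dataclass skips ClassVar names.
    cls = dataclass(frozen=True, eq=False)(cls)
    entries = {}
    for name, sentinel in pending.items():
        # Step 3: instantiate with the given kwargs; defaults fill in the rest.
        singleton = cls(**sentinel.kwargs)
        # Step 5 (part): record the entry name, bypassing the frozen __setattr__.
        object.__setattr__(singleton, "__entry_name__", name)
        entries[name] = singleton
        # Step 5 (part): replace the sentinel with the singleton on the class.
        setattr(cls, name, singleton)
    # Step 4: stand-in for storing the map as the "__enum_entries__" TypeAttr.
    cls.__enum_entries__ = entries
    return cls


@toy_enum
class Activation:
    approximate: bool = False
    inplace: bool = False
    ReLU: ClassVar["Activation"] = entry()
    GeLU: ClassVar["Activation"] = entry(approximate=True)
```

After decoration, `Activation.ReLU` and `Activation.__enum_entries__["ReLU"]` refer to the same instance, which is the property the real decorator would need to preserve.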
#### Query API
```python
# By class attribute
act = Activation.ReLU
# By string name (Op::Get pattern)
act = Activation.get("ReLU")
# Iteration
for name, member in Activation.entries().items():
    print(name, member.approximate)
# Identity comparison (singletons)
assert conv.activation is Activation.ReLU
```
#### Usage in `@py_class` fields
```python
@py_class
class Conv2D(Object):
    activation: Activation
    channels: int
```
`TypeSchema.from_annotation()` detects registered enum types (they are
`@py_class` types that happen to have entries) and the field getter/setter
works normally since the enum is a proper `ObjectRef`.
### C++ API: `c_enum`
On the C++ side, enums follow the `tvm::Op` registration pattern: the type is a
normal Object class, and entries are registered from static init blocks across
multiple files.
**There are no C++ `enum class` literals.** Variants are accessed by string
name, like `Op::Get("nn.relu")`.
#### Define the Object type
The enum's Object class is a normal C++ class with fields:
```cpp
#include <tvm/ffi/tvm_ffi.h>
namespace ffi = tvm::ffi;
class ActivationObj : public ffi::Object {
 public:
  bool approximate = false;
  bool inplace = false;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(
      "nn.Activation", ActivationObj, ffi::Object);
};

class Activation : public ffi::ObjectRef {
 public:
  /*! \brief Look up a registered entry by name. */
  static Activation Get(const std::string& name);
  TVM_FFI_DEFINE_OBJECT_REF_METHODS(
      Activation, ffi::ObjectRef, ActivationObj);
};
```
#### Register fields and mark as enum
```cpp
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Register type + fields (like ObjectDef for @c_class)
  refl::ObjectDef<ActivationObj>()
      .as_enum()  // marks as enum type
      .def_ro("approximate", &ActivationObj::approximate)
      .def_ro("inplace", &ActivationObj::inplace);
}
```
#### Register entries (distributed across files)
```cpp
// file: activations_core.cc
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumEntry<ActivationObj>("ReLU");  // all defaults
  refl::EnumEntry<ActivationObj>("GeLU").set("approximate", true);
}

// file: activations_extra.cc (different file, different library)
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumEntry<ActivationObj>("SiLU").set("inplace", true);
}
```
Each `EnumEntry` is an RAII builder. Its destructor:
1. Creates a singleton `ActivationObj` instance (via `make_object`).
2. Sets the specified fields (all others use class defaults).
3. Freezes the instance (marks immutable).
4. Registers via TypeAttr (read-modify-write):
```cpp
// Read existing entries Map from TypeAttr
auto entries = TypeAttrColumn("__enum_entries__")[type_index]
                   .cast<Map<String, ObjectRef>>();
// Add new entry
entries.Set(String(name_), singleton);
// Write back
TVMFFITypeRegisterAttr(type_index, "__enum_entries__", entries);
```
#### Query
```cpp
// By name (the primary access pattern — like Op::Get)
Activation relu = Activation::Get("ReLU");
relu->approximate; // false
relu->inplace; // false
Activation gelu = Activation::Get("GeLU");
gelu->approximate; // true
// Identity comparison
assert(relu.same_as(Activation::Get("ReLU")));
```
`Activation::Get(name)` reads the entries Map from TypeAttr and does a Map
lookup.
#### Extensible attributes (C++)
```cpp
// file: compute_registry.cc
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumAttrDef<ActivationObj>("f_compute")
      .set("ReLU", relu_compute)
      .set("GeLU", gelu_compute);
}

// file: gradient_registry.cc (different file)
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumAttrDef<ActivationObj>("f_gradient")
      .set("ReLU", relu_gradient);
}
```
Query:
```cpp
EnumAttrMap<ActivationObj> f_compute("f_compute");
Function fn = f_compute[relu];
```
### `@c_enum`: Python wrapper for C++-defined enums
Analogous to `@c_class`. Wraps a C++-registered enum type, discovers entries,
and installs them as class attributes:
```python
from tvm_ffi.dataclasses import c_enum
@c_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool
    inplace: bool
```
`@c_enum(type_key)`:
1. Calls `@c_class(type_key, init=False)` — registers the Python class against
the existing C++ type.
2. Reads the `"__enum_entries__"` TypeAttr for the type's `type_index` to
discover all registered entries.
3. For each entry, sets it as a class attribute (`cls.ReLU = singleton`).
4. Inherits `cls.get(name)`, `cls.entries()`, and `cls.def_attr()` from
`EnumObject`.
Entries registered after decoration (e.g., from a dynamically loaded module) do
not appear as class attributes automatically. Call
`Activation._refresh_entries()` to install them, or use
`Activation.get("name")`, which always reads from the live TypeAttr.
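The difference between the attribute snapshot and the live lookup can be illustrated with a toy model. Here `LIVE_ENTRIES` is a hypothetical in-memory stand-in for the `"__enum_entries__"` TypeAttr, and the class only imitates the behavior described above; it is not the proposed implementation.

```python
# Hypothetical stand-in for the live "__enum_entries__" TypeAttr.
LIVE_ENTRIES = {"ReLU": object(), "GeLU": object()}


class Activation:
    @classmethod
    def get(cls, name):
        """Always consult the live table."""
        return LIVE_ENTRIES[name]

    @classmethod
    def _refresh_entries(cls):
        """Copy the current table onto the class as attributes."""
        for name, singleton in LIVE_ENTRIES.items():
            setattr(cls, name, singleton)


# Snapshot taken once, as the decorator would do at decoration time.
Activation._refresh_entries()

# Simulates a library loaded later that registers one more entry.
LIVE_ENTRIES["SiLU"] = object()

visible_before = hasattr(Activation, "SiLU")  # attribute snapshot is stale
via_get = Activation.get("SiLU")              # live lookup still works
Activation._refresh_entries()
visible_after = hasattr(Activation, "SiLU")
```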
### Extensible Attributes (`def_attr`)
Both `@py_enum` and `@c_enum` support `tvm::Op`-style extensible per-variant
attribute columns, registrable from anywhere:
```python
f_compute = Activation.def_attr("f_compute")
# Register from anywhere
f_compute[Activation.ReLU] = relu_fn
f_compute[Activation.GeLU] = gelu_fn
# Decorator form
@f_compute.register(Activation.SiLU)
def silu_compute(x):
    return x * sigmoid(x)
# Query
fn = f_compute[conv.activation]
```
`EnumAttrMap` is a thin wrapper around a `Map<String, Any>` stored as a
TypeAttr:
```python
class EnumAttrMap(Generic[E]):
    """Per-variant attribute column. Analogous to tvm::OpAttrMap."""

    def __init__(self, enum_cls: type[E], attr_name: str) -> None:
        self._enum_cls = enum_cls
        self._attr_name = attr_name
        self._type_index = enum_cls.__tvm_ffi_type_info__.type_index
        self._attr_key = f"__enum_attr_{attr_name}__"
        # Load existing column (may be registered from C++)
        self._table: dict[str, Any] = dict(
            core._type_get_attr(self._type_index, self._attr_key) or {}
        )

    def __getitem__(self, variant: E) -> Any:
        name = variant.__entry_name__
        if name not in self._table:
            raise KeyError(f"{self._attr_name} not set for {name!r}")
        return self._table[name]

    def __setitem__(self, variant: E, value: Any) -> None:
        name = variant.__entry_name__
        self._table[name] = value
        core._type_register_attr(self._type_index, self._attr_key, self._table)

    def register(self, variant: E):
        def deco(fn):
            self[variant] = fn
            return fn
        return deco

    def get(self, variant: E, default=None):
        name = variant.__entry_name__
        return self._table.get(name, default)
```
### Backing Store: TypeAttr
All enum metadata is stored using the **existing TypeAttr C APIs**, keyed by
the enum's `type_index`:
| TypeAttr name | Value type | Description |
|---|---|---|
| `__enum_entries__` | `Map<String, ObjectRef>` | All entries by name |
| `__enum_attr_{col}__` | `Map<String, Any>` | Extensible attr column |
C APIs used (all existing):
```c
// Write: store per-type attribute
int TVMFFITypeRegisterAttr(int32_t type_index,
                           const TVMFFIByteArray* attr_name,
                           const TVMFFIAny* attr_value);
// Read: get attribute column for O(1) lookup by type_index
const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(
    const TVMFFIByteArray* attr_name);
```
Both C++ and Python access enum state through these C APIs. The C++ builder
classes (`EnumEntry`, `EnumAttrDef`) and the Python decorator (`@py_enum`) are
convenience layers on top.
**Note:** Distributed entry registration (multiple `EnumEntry` calls from
different files) requires `TVMFFITypeRegisterAttr` to support **override**
(re-registering the same `(type_index, attr_name)` pair with an updated Map).
If the current implementation does not support override, a small fix to the
existing API implementation is needed — not a new API surface.
### Cross-Language Interop
#### Wire format
Enum entries are `ObjectRef` singletons. They cross FFI as normal object
references — the same pointer/handle on both sides within a process.
- No special type code needed.
- No `int64_t` encoding/decoding — the Object system handles it.
- Singleton identity is preserved: C++ and Python share the same underlying
`Object*`.
#### Serialization
For JSON/binary serialization across processes, an enum entry serializes as its
`(type_key, entry_name)` pair. The deserializer reconstructs via
`EnumCls.get(entry_name)`.
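A round trip through JSON might look like the following sketch. The JSON field names, the `_REGISTRY` dict, and the `dumps`/`loads` helpers are assumptions made for illustration; the RFC only fixes the idea that the `(type_key, entry_name)` pair is what gets written and that lookup by name restores the singleton.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True, eq=False)
class Entry:
    """Hypothetical stand-in for an enum singleton."""
    type_key: str
    entry_name: str


# Hypothetical registry keyed by (type_key, entry_name).
_REGISTRY = {}


def register(type_key, entry_name):
    singleton = Entry(type_key, entry_name)
    _REGISTRY[(type_key, entry_name)] = singleton
    return singleton


def dumps(member):
    """Serialize an entry as its (type_key, entry_name) pair."""
    return json.dumps({"type_key": member.type_key, "entry_name": member.entry_name})


def loads(payload):
    """Reconstruct by name lookup, so the existing singleton is returned."""
    data = json.loads(payload)
    return _REGISTRY[(data["type_key"], data["entry_name"])]


relu = register("nn.Activation", "ReLU")
restored = loads(dumps(relu))
```

Since `loads` goes through the registry rather than constructing a new object, `restored is relu` holds within a process that has the entry registered.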
## Full Example
```cpp
// ── C++ side ─────────────────────────────────────────────────────
#include <tvm/ffi/tvm_ffi.h>
namespace ffi = tvm::ffi;
namespace refl = tvm::ffi::reflection;

class ActivationObj : public ffi::Object {
 public:
  bool approximate = false;
  bool inplace = false;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(
      "nn.Activation", ActivationObj, ffi::Object);
};

class Activation : public ffi::ObjectRef {
 public:
  static Activation Get(const std::string& name);
  TVM_FFI_DEFINE_OBJECT_REF_METHODS(
      Activation, ffi::ObjectRef, ActivationObj);
};

// Register type + fields
TVM_FFI_STATIC_INIT_BLOCK() {
  refl::ObjectDef<ActivationObj>()
      .as_enum()
      .def_ro("approximate", &ActivationObj::approximate)
      .def_ro("inplace", &ActivationObj::inplace);
  // Register entries
  refl::EnumEntry<ActivationObj>("ReLU");
  refl::EnumEntry<ActivationObj>("GeLU").set("approximate", true);
}

// Attach extensible attribute (from another file)
TVM_FFI_STATIC_INIT_BLOCK() {
  refl::EnumAttrDef<ActivationObj>("f_compute")
      .set("ReLU", relu_fn)
      .set("GeLU", gelu_fn);
}
// Query
Activation relu = Activation::Get("ReLU");
EnumAttrMap<ActivationObj> f_compute("f_compute");
Function fn = f_compute[relu];
```
```python
# ── Python side (wrapping C++ enum) ──────────────────────────────
@c_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool
    inplace: bool
# Entries are auto-discovered from C++ registration:
Activation.ReLU.approximate # False
Activation.GeLU.approximate # True
Activation.get("ReLU") # same singleton
# Query C++-registered extensible attrs
f_compute = Activation.def_attr("f_compute")
fn = f_compute[Activation.ReLU]
# Extend with Python-defined attrs
f_gradient = Activation.def_attr("f_gradient")
f_gradient[Activation.ReLU] = relu_grad_fn
# Use as @py_class field
@py_class
class Conv2D(Object):
    activation: Activation
    channels: int
conv = Conv2D(activation=Activation.ReLU, channels=64)
assert conv.activation is Activation.ReLU
```
```python
# ── Pure Python enum (no C++ involvement) ────────────────────────
@py_enum("nn.Padding")
class Padding(EnumObject):
    top: int = 0
    bottom: int = 0
    left: int = 0
    right: int = 0

    Zero: ClassVar[Padding] = entry()
    Same: ClassVar[Padding] = entry(top=1, bottom=1, left=1, right=1)
    Custom: ClassVar[Padding] = entry(top=2, bottom=2, left=3, right=3)
# Simple integer enum
@py_enum("nn.Precision")
class Precision(EnumObject):
    FP32: ClassVar[Precision] = entry(0)
    FP16: ClassVar[Precision] = entry(1)
    BF16: ClassVar[Precision] = entry(2)
Precision.FP32.value # 0
Precision.get("FP16") # Precision.FP16 singleton
```
## Implementation Plan
1. **`EnumObject` base class** — typed base with `get()`, `entries()`,
`def_attr()` using `Self`.
2. **`entry()` sentinel and `@py_enum` decorator** — Python-side enum support,
building on `@py_class`.
3. **Verify `TVMFFITypeRegisterAttr` override support** — ensure
re-registration works for distributed entry registration; fix if needed.
4. **`ObjectDef::as_enum()` and `EnumEntry<T>`** — C++ RAII builders for
distributed entry registration via TypeAttr.
5. **`@c_enum` decorator** — Python wrapper for C++-defined enums (mirrors
`@c_class`).
6. **`EnumAttrDef` / `EnumAttrMap`** — Extensible per-variant attribute columns
(C++ and Python), backed by TypeAttr.
7. **`TypeSchema` integration** — Recognize enum types in `@py_class` field
annotations.
8. **Tests** — C++, Python, and cross-language round-trip tests.
## Summary Table
| Aspect | `@py_enum` (Python-defined) | `@c_enum` (C++-defined) |
|---|---|---|
| Type definition | `@py_enum class Foo(EnumObject)` | `ObjectDef<FooObj>().as_enum()` |
| Fields | Python annotations | `def_ro` / `def_rw` |
| Entry syntax | `X: ClassVar[Self] = entry(...)` | `EnumEntry<Obj>("name").set(...)` |
| Distributed entries | No (class body only) | Yes (any C++ file) |
| Python wrapper | N/A | `@c_enum("type_key")` |
| Wire format | `ObjectRef` (singleton) | `ObjectRef` (singleton) |
| Access by name | `Cls.get("name")` | `Cls::Get("name")` |
| Access by attr | `Cls.ReLU` | (via Python wrapper) |
| Ext attrs | `def_attr` / `EnumAttrMap` | `EnumAttrDef` / `EnumAttrMap` |
| Type checker | Native (`ClassVar[Self]`) | Native (via `@c_enum`) |
| Backing store | TypeAttr (`TVMFFITypeRegisterAttr`) | TypeAttr (`TVMFFITypeRegisterAttr`) |