junrushao created an issue (apache/tvm-ffi#553)

## Summary

This RFC is heavily inspired by `tvm::Op` and derives its design from it.

Add first-class enum support to TVM-FFI. An enum is a registered Object type 
whose instances are **named, frozen singletons** — the same model as `tvm::Op`, 
but generalized as a decorator that integrates with `@py_class` / `@c_class`.

Two forms are supported:
- **Attribute-carrying enum** (primary): fields declared as annotations, 
entries declared via `ClassVar[Self] = entry(kwarg=val)`
- **Simple integer enum**: `ClassVar[Self] = entry(int)` with an implicit 
`value: int` field

Both forms support `tvm::Op`-style extensible per-variant attributes 
(`def_attr`), cross-language access by string name (`Cls.get("name")` / 
`Cls::Get("name")`), and distributed C++ entry registration across files.

## Motivation

TVM-FFI provides `@py_class` and `@c_class` for defining cross-language 
dataclasses, but has no equivalent for enumerations. Today, enums like 
`DataTypeCode` and `DLDeviceType` are manually defined as Python `IntEnum` 
subclasses with no FFI registration — invisible to other languages and to the 
reflection system.

TVM's `tvm::Op` demonstrates the pattern we want: a string-keyed singleton 
registry where each entry is an Object, extensible attributes can be attached 
from anywhere, and lookup is by name (`Op::Get("nn.relu")`). This RFC 
generalizes that pattern into a first-class enum abstraction.

## Design

### Core Model

An enum is a registered Object type whose instances are named, frozen 
singletons:

```
  ┌─────────────────────────────────────────────┐
  │  Activation  (type_key = "nn.Activation")   │
  │  ─────────────────────────────────────────   │
  │  fields:  approximate: bool, inplace: bool   │  ← same as @py_class
  │                                              │
  │  entries (singleton instances):              │
  │    "ReLU"  → Activation(False, False)        │
  │    "GeLU"  → Activation(True,  False)        │
  │    "SiLU"  → Activation(False, True)         │
  │                                              │
  │  ext attrs (def_attr columns):              │
  │    "f_compute"  → {"ReLU": fn, "GeLU": fn}  │
  │    "f_gradient" → {"ReLU": fn}              │
  └─────────────────────────────────────────────┘
```

- Each entry is a frozen `ObjectRef`, stored as a **TypeAttr** on the enum's 
`type_index`.
- Compared by identity (`is` / pointer equality).
- Crossed through FFI as `ObjectRef` — no special wire encoding.

### Python API: `@py_enum`

#### Attribute-carrying enum (primary form)

```python
from __future__ import annotations

from typing import ClassVar

from tvm_ffi.dataclasses import py_enum, entry

@py_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool = False
    inplace: bool = False

    ReLU: ClassVar[Activation] = entry()
    GeLU: ClassVar[Activation] = entry(approximate=True)
    SiLU: ClassVar[Activation] = entry(inplace=True)
```

Like `@py_class`, the decorator accepts an optional `type_key` as the first 
positional argument, or auto-generates it from `{module}.{qualname}` when used 
bare:

```python
@py_enum                          # type_key = "{module}.Activation"
class Activation(EnumObject): ...

@py_enum("nn.Activation")         # explicit type_key
class Activation(EnumObject): ...

@py_enum("nn.Activation", eq=True)  # with py_class options
class Activation(EnumObject): ...
```

Fields are declared via annotations (exactly like `@py_class`). Entries are 
declared via `entry()` with a `ClassVar[Self]` annotation, which accepts 
keyword arguments matching the field names:

```python
Activation.ReLU                  # frozen singleton (Activation instance)
Activation.ReLU.approximate      # False — direct field access
Activation.GeLU.approximate      # True
Activation.get("ReLU")           # same object as Activation.ReLU
Activation.ReLU is Activation.get("ReLU")  # True
```

#### Simple integer enum

When all entries use bare `int` values in `entry()`, `@py_enum` creates a 
single implicit `value: int` field:

```python
@py_enum("nn.Precision")
class Precision(EnumObject):
    FP32: ClassVar[Precision] = entry(0)
    FP16: ClassVar[Precision] = entry(1)
    BF16: ClassVar[Precision] = entry(2)
```

This is sugar for:

```python
@py_enum("nn.Precision")
class Precision(EnumObject):
    value: int

    FP32: ClassVar[Precision] = entry(value=0)
    FP16: ClassVar[Precision] = entry(value=1)
    BF16: ClassVar[Precision] = entry(value=2)
```

```python
Precision.FP32.value      # 0
Precision.get("FP16")     # Precision.FP16 singleton
```
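
The desugaring above can be sketched in plain Python. This is only an illustration, not the proposed implementation: `normalize_entry` is a hypothetical helper name, and rejecting a mix of positional and keyword arguments is an assumption that the RFC does not spell out.

```python
from typing import Any, Dict, Tuple


def normalize_entry(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Map entry(0) to {'value': 0}; keyword entries pass through unchanged."""
    if not args:
        return dict(kwargs)
    # Assumption: mixing positional and keyword arguments is an error.
    if kwargs or len(args) != 1 or not isinstance(args[0], int):
        raise TypeError("expected a single positional int or keyword arguments only")
    return {"value": args[0]}


print(normalize_entry((0,), {}))                    # {'value': 0}
print(normalize_entry((), {"approximate": True}))   # {'approximate': True}
```

With a helper like this, the decorator would only ever need to handle the keyword form when it creates the singletons.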

#### Type checker compatibility

The `ClassVar[Self]` pattern works natively with mypy, pyright, and ty — no 
stubs, no plugins:

1. **`from __future__ import annotations`** (already required by TVM-FFI) makes 
`Activation` in `ClassVar[Activation]` a string — the forward self-reference 
resolves.
2. **`entry()` returns `Any`** — `Any` is assignable to `ClassVar[Activation]`, 
so no type mismatch.
3. **Type checkers see `ReLU: ClassVar[Activation]`** — so `Activation.ReLU` 
has type `Activation`.
4. **`@py_class` already skips `ClassVar` annotations** — entries are not 
registered as fields.

```python
Activation.ReLU                   # Activation            ✓
Activation.ReLU.approximate       # bool                  ✓
Activation.get("ReLU")            # Activation (via Self) ✓
Activation.entries()              # dict[str, Activation]  ✓
```

The `get()`, `entries()`, and `def_attr()` methods are typed on the 
`EnumObject` base class using `Self`:

```python
class EnumObject(Object):
    """Base for @py_enum and @c_enum types."""

    @classmethod
    def get(cls, name: str) -> Self: ...

    @classmethod
    def entries(cls) -> dict[str, Self]: ...

    @classmethod
    def def_attr(cls, name: str) -> EnumAttrMap[Self]: ...
```

#### The `entry()` sentinel

`entry()` returns a lightweight sentinel that `@py_enum` processes after the 
class body executes:

```python
from typing import Any


class _EnumEntry:
    """Sentinel returned by entry(). Processed by @py_enum."""
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.args = args      # positional: entry(0) for simple int
        self.kwargs = kwargs   # keyword: entry(approximate=True)

def entry(*args: Any, **kwargs: Any) -> Any:
    return _EnumEntry(*args, **kwargs)
```

`@py_enum` scans the class `__dict__` for `_EnumEntry` instances, separates 
them from field annotations, and creates the singleton instances after the 
underlying `@py_class` registration finishes.

#### How `@py_enum` works internally

```
@py_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool = False       ← field annotations
    inplace: bool = False

    ReLU: ClassVar[Activation] = entry()           ← entry (ClassVar, skipped by @py_class)
    GeLU: ClassVar[Activation] = entry(approximate=True)
```

1. **Separate fields from entries.** Scan class `__dict__` for `_EnumEntry` 
instances. `ClassVar` annotations are already skipped by `@py_class`, so 
entries don't interfere with field registration.

2. **Register as `@py_class`.** Delegate to `@py_class(frozen=True)` for type 
registration, field registration, and dunder installation. The enum class gets 
a `type_key`, `type_index`, and reflection metadata like any other `@py_class`. 
If all entries are positional-int (`entry(0)`), inject an implicit `value: int` 
field first.

3. **Create singleton instances.** For each entry, instantiate the class with 
the specified kwargs (filling in defaults from field declarations), then freeze 
the instance.

4. **Register entries via TypeAttr.** Store the full `name -> singleton` 
mapping as a TypeAttr on the enum's `type_index`:
   ```python
   core._type_register_attr(type_index, "__enum_entries__", entries_map)
   ```

5. **Set class attributes.** For each `(name, singleton)` pair:
   - `cls.ReLU = singleton`
   - `singleton.__entry_name__ = "ReLU"`
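
The steps above can be approximated without any FFI machinery. The sketch below is a toy under stated assumptions: `toy_enum`, `ToyEnumBase`, and `_Entry` are hypothetical names, step 2 (type registration) is skipped entirely, a class-level dict stands in for the `__enum_entries__` TypeAttr, and the entry name is assigned before freezing so the freeze check stays simple.

```python
from typing import Any, ClassVar, Dict


class _Entry:
    """Stand-in for the RFC's _EnumEntry sentinel (keyword form only)."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


def entry(**kwargs: Any) -> Any:
    return _Entry(**kwargs)


class ToyEnumBase:
    """Hypothetical base class: freezing plus lookup by name."""

    def __setattr__(self, key: str, value: Any) -> None:
        if getattr(self, "_frozen", False):
            raise AttributeError(f"cannot set {key!r}: enum entries are frozen")
        object.__setattr__(self, key, value)

    @classmethod
    def get(cls, name: str) -> Any:
        return cls._entries[name]


def toy_enum(cls: type) -> type:
    # Step 1: separate entries (sentinels) from fields (annotated defaults).
    pending = {k: v for k, v in vars(cls).items() if isinstance(v, _Entry)}
    defaults = {
        k: vars(cls)[k]
        for k in getattr(cls, "__annotations__", {})
        if k in vars(cls) and k not in pending
    }
    # Step 2 (FFI type registration) is skipped in this toy.
    entries: Dict[str, Any] = {}
    for name, sentinel in pending.items():
        unknown = set(sentinel.kwargs) - set(defaults)
        if unknown:
            raise TypeError(f"{name}: unknown fields {sorted(unknown)}")
        # Step 3: instantiate with defaults overridden by the entry's kwargs.
        obj = cls()
        for field, value in {**defaults, **sentinel.kwargs}.items():
            setattr(obj, field, value)
        obj.__entry_name__ = name  # assigned before freezing in this toy
        obj._frozen = True
        entries[name] = obj
    # Step 4: a class-level dict stands in for the "__enum_entries__" TypeAttr.
    cls._entries = entries
    # Step 5: replace each sentinel on the class with its singleton.
    for name, obj in entries.items():
        setattr(cls, name, obj)
    return cls


@toy_enum
class Activation(ToyEnumBase):
    approximate: bool = False
    inplace: bool = False

    # Quoted so the annotation does not need to resolve at runtime.
    ReLU: "ClassVar[Activation]" = entry()
    GeLU: "ClassVar[Activation]" = entry(approximate=True)


print(Activation.GeLU.approximate)                   # True
print(Activation.get("ReLU") is Activation.ReLU)     # True
```

Because the sentinels are replaced on the class after construction, attribute access and `get()` return the same object, which is what makes identity comparison safe.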

#### Query API

```python
# By class attribute
act = Activation.ReLU

# By string name (Op::Get pattern)
act = Activation.get("ReLU")

# Iteration
for name, member in Activation.entries().items():
    print(name, member.approximate)

# Identity comparison (singletons)
assert conv.activation is Activation.ReLU
```

#### Usage in `@py_class` fields

```python
@py_class
class Conv2D(Object):
    activation: Activation
    channels: int
```

`TypeSchema.from_annotation()` detects registered enum types, which are ordinary 
`@py_class` types that happen to have entries. Field getters and setters work 
normally because each entry is a proper `ObjectRef`.

### C++ API: `c_enum`

On the C++ side, enums follow the `tvm::Op` registration pattern: the type is a 
normal Object class, and entries are registered from static init blocks across 
multiple files.

**There are no C++ `enum class` literals.** Variants are accessed by string 
name, like `Op::Get("nn.relu")`.

#### Define the Object type

The enum's Object class is a normal C++ class with fields:

```cpp
#include <tvm/ffi/tvm_ffi.h>

namespace ffi = tvm::ffi;

class ActivationObj : public ffi::Object {
 public:
  bool approximate = false;
  bool inplace = false;

  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(
      "nn.Activation", ActivationObj, ffi::Object);
};

class Activation : public ffi::ObjectRef {
 public:
  /*! \brief Look up a registered entry by name. */
  static Activation Get(const std::string& name);

  TVM_FFI_DEFINE_OBJECT_REF_METHODS(
      Activation, ffi::ObjectRef, ActivationObj);
};
```

#### Register fields and mark as enum

```cpp
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;

  // Register type + fields (like ObjectDef for @c_class)
  refl::ObjectDef<ActivationObj>()
      .as_enum()                                        // marks as enum type
      .def_ro("approximate", &ActivationObj::approximate)
      .def_ro("inplace", &ActivationObj::inplace);
}
```

#### Register entries (distributed across files)

```cpp
// file: activations_core.cc
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumEntry<ActivationObj>("ReLU");                       // all defaults
  refl::EnumEntry<ActivationObj>("GeLU").set("approximate", true);
}

// file: activations_extra.cc (different file, different library)
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumEntry<ActivationObj>("SiLU").set("inplace", true);
}
```

Each `EnumEntry` is a RAII builder. Its destructor:

1. Creates a singleton `ActivationObj` instance (via `make_object`).
2. Sets the specified fields (all others use class defaults).
3. Freezes the instance (marks immutable).
4. Registers via TypeAttr (read-modify-write):
   ```cpp
   // Read existing entries Map from TypeAttr
   auto entries = TypeAttrColumn("__enum_entries__")[type_index]
                      .cast<Map<String, ObjectRef>>();
   // Add new entry
   entries.Set(String(name_), singleton);
   // Write back
   TVMFFITypeRegisterAttr(type_index, "__enum_entries__", entries);
   ```

#### Query

```cpp
// By name (the primary access pattern — like Op::Get)
Activation relu = Activation::Get("ReLU");
relu->approximate;   // false
relu->inplace;       // false

Activation gelu = Activation::Get("GeLU");
gelu->approximate;   // true

// Identity comparison
assert(relu.same_as(Activation::Get("ReLU")));
```

`Activation::Get(name)` reads the entries Map from TypeAttr and does a Map 
lookup.

#### Extensible attributes (C++)

```cpp
// file: compute_registry.cc
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumAttrDef<ActivationObj>("f_compute")
      .set("ReLU", relu_compute)
      .set("GeLU", gelu_compute);
}

// file: gradient_registry.cc (different file)
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::EnumAttrDef<ActivationObj>("f_gradient")
      .set("ReLU", relu_gradient);
}
```

Query:

```cpp
EnumAttrMap<ActivationObj> f_compute("f_compute");
Function fn = f_compute[relu];
```

### `@c_enum`: Python wrapper for C++-defined enums

Analogous to `@c_class`. Wraps a C++-registered enum type, discovers entries, 
and installs them as class attributes:

```python
from tvm_ffi.dataclasses import c_enum

@c_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool
    inplace: bool
```

`@c_enum(type_key)`:

1. Calls `@c_class(type_key, init=False)` — registers the Python class against 
the existing C++ type.
2. Reads the `"__enum_entries__"` TypeAttr for the type's `type_index` to 
discover all registered entries.
3. For each entry, sets it as a class attribute (`cls.ReLU = singleton`).
4. Inherits `cls.get(name)`, `cls.entries()`, and `cls.def_attr()` from 
`EnumObject`.

Entries registered later (e.g., from a dynamically loaded module) are not 
installed as class attributes automatically. Call `Activation._refresh_entries()` 
to pick them up, or use `Activation.get("name")`, which always reads from the 
live TypeAttr.
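
The difference between cached class attributes and the live lookup can be illustrated with a toy. Here `LIVE_ENTRIES` is a hypothetical stand-in for the `__enum_entries__` TypeAttr and `ToyWrapper` for a `@c_enum` class; nothing below calls the real TVM-FFI API.

```python
from typing import Any, Dict

# Hypothetical stand-in for the live "__enum_entries__" TypeAttr.
LIVE_ENTRIES: Dict[str, Any] = {"ReLU": object()}


class ToyWrapper:
    @classmethod
    def get(cls, name: str) -> Any:
        # Always consults the live table.
        return LIVE_ENTRIES[name]

    @classmethod
    def _refresh_entries(cls) -> None:
        # Copies the current table onto the class as attributes.
        for name, singleton in LIVE_ENTRIES.items():
            setattr(cls, name, singleton)


ToyWrapper._refresh_entries()        # done once when the class is wrapped
LIVE_ENTRIES["SiLU"] = object()      # a late registration, e.g. a loaded module

print(hasattr(ToyWrapper, "SiLU"))                       # False
print(ToyWrapper.get("SiLU") is LIVE_ENTRIES["SiLU"])    # True
ToyWrapper._refresh_entries()
print(hasattr(ToyWrapper, "SiLU"))                       # True
```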

### Extensible Attributes (`def_attr`)

Both `@py_enum` and `@c_enum` support `tvm::Op`-style extensible per-variant 
attribute columns, registrable from anywhere:

```python
f_compute = Activation.def_attr("f_compute")

# Register from anywhere
f_compute[Activation.ReLU] = relu_fn
f_compute[Activation.GeLU] = gelu_fn

# Decorator form
@f_compute.register(Activation.SiLU)
def silu_compute(x):
    return x * sigmoid(x)

# Query
fn = f_compute[conv.activation]
```

`EnumAttrMap` is a thin wrapper around a `Map<String, Any>` stored as a 
TypeAttr:

```python
class EnumAttrMap(Generic[E]):
    """Per-variant attribute column. Analogous to tvm::OpAttrMap."""

    def __init__(self, enum_cls: type[E], attr_name: str) -> None:
        self._enum_cls = enum_cls
        self._attr_name = attr_name
        self._type_index = enum_cls.__tvm_ffi_type_info__.type_index
        self._attr_key = f"__enum_attr_{attr_name}__"
        # Load existing column (may be registered from C++)
        self._table: dict[str, Any] = dict(
            core._type_get_attr(self._type_index, self._attr_key) or {}
        )

    def __getitem__(self, variant: E) -> Any:
        name = variant.__entry_name__
        if name not in self._table:
            raise KeyError(f"{self._attr_name} not set for {name!r}")
        return self._table[name]

    def __setitem__(self, variant: E, value: Any) -> None:
        name = variant.__entry_name__
        self._table[name] = value
        core._type_register_attr(self._type_index, self._attr_key, self._table)

    def register(self, variant: E):
        def deco(fn):
            self[variant] = fn
            return fn
        return deco

    def get(self, variant: E, default=None):
        name = variant.__entry_name__
        return self._table.get(name, default)
```

### Backing Store: TypeAttr

All enum metadata is stored using the **existing TypeAttr C APIs**, keyed by 
the enum's `type_index`:

| TypeAttr name | Value type | Description |
|---|---|---|
| `__enum_entries__` | `Map<String, ObjectRef>` | All entries by name |
| `__enum_attr_{col}__` | `Map<String, Any>` | Extensible attr column |

C APIs used (all existing):

```c
// Write: store per-type attribute
int TVMFFITypeRegisterAttr(int32_t type_index,
                           const TVMFFIByteArray* attr_name,
                           const TVMFFIAny* attr_value);

// Read: get attribute column for O(1) lookup by type_index
const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(
    const TVMFFIByteArray* attr_name);
```

Both C++ and Python access enum state through these C APIs. The C++ builder 
classes (`EnumEntry`, `EnumAttrDef`) and the Python decorator (`@py_enum`) are 
convenience layers on top.

**Note:** Distributed entry registration (multiple `EnumEntry` calls from 
different files) requires `TVMFFITypeRegisterAttr` to support **override** 
(re-registering the same `(type_index, attr_name)` pair with an updated Map). 
If the current implementation does not support override, a small fix to the 
existing API implementation is needed — not a new API surface.
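
To see why override matters, consider a toy in-memory store. This is a sketch only: `ToyTypeAttrStore` and `add_entry` are hypothetical names, the type index `128` is arbitrary, and the `allow_override` flag merely models the two possible behaviors rather than describing what `TVMFFITypeRegisterAttr` does today.

```python
from typing import Any, Dict, Optional


class ToyTypeAttrStore:
    """Hypothetical in-memory stand-in for the TypeAttr table."""

    def __init__(self, allow_override: bool) -> None:
        self.allow_override = allow_override
        self._columns: Dict[str, Dict[int, Any]] = {}

    def get(self, type_index: int, attr_name: str) -> Optional[Any]:
        return self._columns.get(attr_name, {}).get(type_index)

    def register(self, type_index: int, attr_name: str, value: Any) -> None:
        column = self._columns.setdefault(attr_name, {})
        if type_index in column and not self.allow_override:
            raise RuntimeError(f"{attr_name!r} already registered for type {type_index}")
        column[type_index] = value


def add_entry(store: ToyTypeAttrStore, type_index: int, name: str, singleton: Any) -> None:
    # Read the existing map, add one entry, write the whole map back.
    entries = dict(store.get(type_index, "__enum_entries__") or {})
    entries[name] = singleton
    store.register(type_index, "__enum_entries__", entries)


store = ToyTypeAttrStore(allow_override=True)
add_entry(store, 128, "ReLU", object())   # e.g. from activations_core.cc
add_entry(store, 128, "SiLU", object())   # e.g. from activations_extra.cc
print(sorted(store.get(128, "__enum_entries__")))   # ['ReLU', 'SiLU']
```

With `allow_override=False`, the second `add_entry` call raises, which is the failure mode the note above is concerned with.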

### Cross-Language Interop

#### Wire format

Enum entries are `ObjectRef` singletons. They cross FFI as normal object 
references — the same pointer/handle on both sides within a process.

- No special type code needed.
- No `int64_t` encoding/decoding — the Object system handles it.
- Singleton identity is preserved: C++ and Python share the same underlying 
`Object*`.

#### Serialization

For JSON/binary serialization across processes, an enum entry serializes as its 
`(type_key, entry_name)` pair. The deserializer reconstructs via 
`EnumCls.get(entry_name)`.
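
A round trip through JSON could look roughly like the following. The sketch is not the proposed wire schema: the JSON keys `type_key` and `entry_name`, the `ToyEntry` class, and the module-level `REGISTRY` are all assumptions made for illustration.

```python
import json
from typing import Dict


class ToyEntry:
    """Hypothetical enum entry that knows its type_key and name."""

    def __init__(self, type_key: str, name: str) -> None:
        self.type_key = type_key
        self.name = name


# type_key -> entry name -> singleton (stand-in for the per-type entries map)
REGISTRY: Dict[str, Dict[str, ToyEntry]] = {}


def register(type_key: str, name: str) -> ToyEntry:
    singleton = ToyEntry(type_key, name)
    REGISTRY.setdefault(type_key, {})[name] = singleton
    return singleton


def dump_entry(e: ToyEntry) -> str:
    # Only the identifying pair is written, never the field values.
    return json.dumps({"type_key": e.type_key, "entry_name": e.name})


def load_entry(payload: str) -> ToyEntry:
    data = json.loads(payload)
    # Lookup by name returns the receiver's own singleton.
    return REGISTRY[data["type_key"]][data["entry_name"]]


relu = register("nn.Activation", "ReLU")
payload = dump_entry(relu)
print(load_entry(payload) is relu)   # True
```

Serializing only the pair means the receiving process must already have the entry registered; an unknown name surfaces as a lookup error rather than a silently reconstructed object.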

## Full Example

```cpp
// ── C++ side ─────────────────────────────────────────────────────

class ActivationObj : public ffi::Object {
 public:
  bool approximate = false;
  bool inplace = false;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(
      "nn.Activation", ActivationObj, ffi::Object);
};

class Activation : public ffi::ObjectRef {
 public:
  static Activation Get(const std::string& name);
  TVM_FFI_DEFINE_OBJECT_REF_METHODS(
      Activation, ffi::ObjectRef, ActivationObj);
};

// Register type + fields
TVM_FFI_STATIC_INIT_BLOCK() {
  refl::ObjectDef<ActivationObj>()
      .as_enum()
      .def_ro("approximate", &ActivationObj::approximate)
      .def_ro("inplace", &ActivationObj::inplace);
  // Register entries
  refl::EnumEntry<ActivationObj>("ReLU");
  refl::EnumEntry<ActivationObj>("GeLU").set("approximate", true);
}

// Attach extensible attribute (from another file)
TVM_FFI_STATIC_INIT_BLOCK() {
  refl::EnumAttrDef<ActivationObj>("f_compute")
      .set("ReLU", relu_fn)
      .set("GeLU", gelu_fn);
}

// Query
Activation relu = Activation::Get("ReLU");
EnumAttrMap<ActivationObj> f_compute("f_compute");
Function fn = f_compute[relu];
```

```python
# ── Python side (wrapping C++ enum) ──────────────────────────────

@c_enum("nn.Activation")
class Activation(EnumObject):
    approximate: bool
    inplace: bool

# Entries are auto-discovered from C++ registration:
Activation.ReLU.approximate      # False
Activation.GeLU.approximate      # True
Activation.get("ReLU")           # same singleton

# Query C++-registered extensible attrs
f_compute = Activation.def_attr("f_compute")
fn = f_compute[Activation.ReLU]

# Extend with Python-defined attrs
f_gradient = Activation.def_attr("f_gradient")
f_gradient[Activation.ReLU] = relu_grad_fn

# Use as @py_class field
@py_class
class Conv2D(Object):
    activation: Activation
    channels: int

conv = Conv2D(activation=Activation.ReLU, channels=64)
assert conv.activation is Activation.ReLU
```

```python
# ── Pure Python enum (no C++ involvement) ────────────────────────

@py_enum("nn.Padding")
class Padding(EnumObject):
    top: int = 0
    bottom: int = 0
    left: int = 0
    right: int = 0

    Zero: ClassVar[Padding] = entry()
    Same: ClassVar[Padding] = entry(top=1, bottom=1, left=1, right=1)
    Custom: ClassVar[Padding] = entry(top=2, bottom=2, left=3, right=3)

# Simple integer enum
@py_enum("nn.Precision")
class Precision(EnumObject):
    FP32: ClassVar[Precision] = entry(0)
    FP16: ClassVar[Precision] = entry(1)
    BF16: ClassVar[Precision] = entry(2)

Precision.FP32.value    # 0
Precision.get("FP16")   # Precision.FP16 singleton
```

## Implementation Plan

1. **`EnumObject` base class** — typed base with `get()`, `entries()`, 
`def_attr()` using `Self`.
2. **`entry()` sentinel and `@py_enum` decorator** — Python-side enum support, 
building on `@py_class`.
3. **Verify `TVMFFITypeRegisterAttr` override support** — ensure 
re-registration works for distributed entry registration; fix if needed.
4. **`ObjectDef::as_enum()` and `EnumEntry<T>`** — C++ RAII builders for 
distributed entry registration via TypeAttr.
5. **`@c_enum` decorator** — Python wrapper for C++-defined enums (mirrors 
`@c_class`).
6. **`EnumAttrDef` / `EnumAttrMap`** — Extensible per-variant attribute columns 
(C++ and Python), backed by TypeAttr.
7. **`TypeSchema` integration** — Recognize enum types in `@py_class` field 
annotations.
8. **Tests** — C++, Python, and cross-language round-trip tests.

## Comparison Summary

| Aspect | `@py_enum` (Python-defined) | `@c_enum` (C++-defined) |
|---|---|---|
| Type definition | `@py_enum class Foo(EnumObject)` | `ObjectDef<FooObj>().as_enum()` |
| Fields | Python annotations | `def_ro` / `def_rw` |
| Entry syntax | `X: ClassVar[Self] = entry(...)` | `EnumEntry<Obj>("name").set(...)` |
| Distributed entries | No (class body only) | Yes (any C++ file) |
| Python wrapper | N/A | `@c_enum("type_key")` |
| Wire format | `ObjectRef` (singleton) | `ObjectRef` (singleton) |
| Access by name | `Cls.get("name")` | `Cls::Get("name")` |
| Access by attr | `Cls.ReLU` | (via Python wrapper) |
| Ext attrs | `def_attr` / `EnumAttrMap` | `EnumAttrDef` / `EnumAttrMap` |
| Type checker | Native (`ClassVar[Self]`) | Native (via `@c_enum`) |
| Backing store | TypeAttr (`TVMFFITypeRegisterAttr`) | TypeAttr (`TVMFFITypeRegisterAttr`) |

