This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fff3b4bf0d [Relax][Frontend][TFLite] Support StableHLO region-based ops and multi-subgraph models (#19587)
fff3b4bf0d is described below
commit fff3b4bf0d82cded1c397b07706daf265c441ed0
Author: HoYi <[email protected]>
AuthorDate: Thu May 21 12:37:39 2026 +0800
[Relax][Frontend][TFLite] Support StableHLO region-based ops and multi-subgraph models (#19587)
## Summary
This PR adds Relax TFLite frontend support for 10 additional StableHLO builtin
operators from #19519 item I, building on the 29 ops merged in PR #19536.
The first 5 ops are direct single-subgraph converters: `CBRT`, `REMAINDER`,
`DYNAMIC_UPDATE_SLICE`, `DOT_GENERAL`, and `CONVOLUTION`. The remaining 5 ops
are region/subgraph-based: `REDUCE`, `REDUCE_WINDOW`, `SORT`, `SCATTER`, and
`COMPOSITE`. To support these, the TFLite frontend now accepts multi-subgraph
models but still converts only `Subgraphs(0)` into the Relax main function.
Each region subgraph is consumed by the converter of the op that references it.
Relates to #19519.
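
The following sketch illustrates the subgraph convention described above by modeling a TFLite model as a plain list of subgraphs. `ToySubgraph`, `resolve_region_body`, and `main_function_inputs` are hypothetical stand-ins, not the `tflite` flatbuffer API, and `NotImplementedError` stands in for `tvm.error.OpNotImplemented`. The sketch follows the rules that `_get_stablehlo_simple_body_op` and `_input_type()` apply in the diff: index 0 is the main graph, and a region body must be a non-main, single-op subgraph.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySubgraph:
    """Hypothetical stand-in for a TFLite SubGraph table (not the tflite API)."""

    inputs: List[int]
    outputs: List[int]
    operators: List[str] = field(default_factory=list)


def resolve_region_body(subgraphs, body_index, parent_op, input_count):
    """Hypothetical helper: return the single op name of a region body subgraph.

    Index 0 is the Relax main function, so a valid body index satisfies
    1 <= body_index < len(subgraphs).
    """
    if body_index <= 0 or body_index >= len(subgraphs):
        raise NotImplementedError(f"{parent_op} requires a valid non-main body subgraph")
    body = subgraphs[body_index]
    if len(body.inputs) != input_count or len(body.outputs) != 1 or len(body.operators) != 1:
        raise NotImplementedError(f"{parent_op} only supports single-op body subgraphs")
    return body.operators[0]


def main_function_inputs(subgraphs):
    """Hypothetical helper: only subgraph 0 contributes Relax main parameters."""
    return list(subgraphs[0].inputs)


# Main graph (index 0) with one graph input, plus a two-input ADD reducer body (index 1).
model = [
    ToySubgraph(inputs=[0], outputs=[2], operators=["STABLEHLO_REDUCE"]),
    ToySubgraph(inputs=[0, 1], outputs=[2], operators=["STABLEHLO_ADD"]),
]
print(resolve_region_body(model, 1, "STABLEHLO_REDUCE", 2))  # STABLEHLO_ADD
print(main_function_inputs(model))  # [0]: the body's parameters do not leak
```
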
## Changes
1. **Single-subgraph ops**
   - `CBRT`: sign-preserving composite expression
     `where(x < 0, -power(-x, 1/3), power(x, 1/3))`. Float dtypes only.
   - `REMAINDER`: truncating remainder via `x - y * trunc(x / y)`, matching
     StableHLO semantics (the sign follows the dividend). Float dtypes only.
   - `DYNAMIC_UPDATE_SLICE`: static start indices and static shapes only,
     lowered to `R.scatter_nd` with a coordinate grid generated via
     `np.indices`. Non-constant (runtime) start indices and out-of-bounds
     update ranges raise `OpNotImplemented`.
   - `DOT_GENERAL`: canonical 2D matmul subset (no batching dims,
     `lhs_contracting=[1]`, `rhs_contracting=[0]`), lowered to `R.matmul`.
   - `CONVOLUTION`: canonical 2D NHWC/HWIO subset with `BatchGroupCount=1`
     and `FeatureGroupCount=1`, lowered to `R.nn.conv2d`. Non-canonical
     dimension numbers and grouped/depthwise convolutions raise
     `OpNotImplemented`.
2. **Multi-subgraph infrastructure**
   - Loosen the `from_tflite()` assertion from `model.SubgraphsLength() == 1`
     to `model.SubgraphsLength() >= 1`. Only `Subgraphs(0)` is converted into
     the Relax main function.
   - Limit `_input_type()` to `Subgraphs(0)` inputs so that region parameters
     do not leak into the Relax main function's parameters.
   - Add a `_get_stablehlo_simple_body_op` helper that validates a region body
     subgraph and extracts its single operator.
   - Extend the test helper `_finish_tflite_model` with an `extra_subgraphs`
     parameter for constructing multi-subgraph TFLite flatbuffers.
3. **Region/subgraph ops**
   - `REDUCE`: single-op reducer body subgraph. Supports `ADD` → `R.sum`,
     `MAXIMUM` → `R.max`, `MINIMUM` → `R.min`, and `MULTIPLY` → `R.prod`.
     The init value must match the reducer's identity element.
   - `SORT`: single-op comparator body subgraph. `LT` maps to an ascending
     and `GT` to a descending `R.sort`. Stable sorts (`IsStable`) and
     `TOTALORDER` comparators raise `OpNotImplemented`.
   - `REDUCE_WINDOW`: 4D NHWC 2D-pooling subset with a `MAXIMUM` reducer and
     identity init, lowered to `R.nn.max_pool2d`. `BaseDilations` must all
     be 1.
   - `SCATTER`: single-op update-computation body subgraph. Supports
     `ADD`/`MAXIMUM`/`MINIMUM`/`MULTIPLY` → `R.scatter_nd` with the
     corresponding reduction mode. Only canonical point-update semantics (no
     update window dims) are supported.
   - `COMPOSITE`: inlines the decomposition subgraph through a recursive
     `OperatorConverter` with an isolated `ExprTable`, so decomposition tensor
     bindings cannot overwrite main-graph bindings. Only composites without
     `CompositeAttributes` are supported.
4. **Not included**
   - `STABLEHLO_RESHAPE`, `STABLEHLO_TRANSPOSE`, and `STABLEHLO_SLICE` are
     left to another contributor.
   - `WHILE`, `CUSTOM_CALL`, and `RNG_BIT_GENERATOR` are deferred to
     follow-up PRs.
5. **Bug fix**
   - Fixed the `DYNAMIC_UPDATE_SLICE` `scatter_nd` indices layout:
     `np.indices` returns shape `(rank, *update_shape)`, but `scatter_nd`
     expects `(*update_shape, rank)`. Added `np.moveaxis` to move the
     coordinate axis from the first to the last position.
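
The following NumPy-only sketch shows the `DYNAMIC_UPDATE_SLICE` index grid and the layout fix described above. `dynamic_update_slice_indices` and `scatter_nd_update` are hypothetical helpers. The second one emulates an `"update"`-mode `scatter_nd` with fancy indexing, on the assumption that each index row along the trailing axis holds one full coordinate. The result is checked against plain slice assignment.

```python
import numpy as np


def dynamic_update_slice_indices(update_shape, starts):
    """Hypothetical helper: scatter_nd-style indices for a static-start update."""
    grid = np.indices(update_shape, dtype=np.int64)  # shape (rank, *update_shape)
    for axis, start in enumerate(starts):
        grid[axis] += start  # shift each coordinate plane by its start offset
    # Move the coordinate axis last: (*update_shape, rank), as scatter_nd expects.
    return np.moveaxis(grid, 0, -1)


def scatter_nd_update(operand, indices, updates):
    """Hypothetical reference for 'update'-mode scatter_nd with full coordinates."""
    out = operand.copy()
    # Split the trailing coordinate axis into one index array per dimension.
    out[tuple(np.moveaxis(indices, -1, 0))] = updates
    return out


operand = np.zeros((4, 5), dtype=np.float32)
update = np.arange(6, dtype=np.float32).reshape(2, 3)
starts = [1, 2]

idx = dynamic_update_slice_indices(update.shape, starts)
print(idx.shape)  # (2, 3, 2)

result = scatter_nd_update(operand, idx, update)
expected = operand.copy()
expected[1:3, 2:5] = update  # the same update written as a slice assignment
print(np.array_equal(result, expected))  # True
```

Without the `np.moveaxis` call, the grid would have shape `(2, 2, 3)`, and its leading axis would be read as batch positions rather than as coordinates. That is the layout bug the fix addresses.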
## Testing
All tests use manually built minimal TFLite flatbuffers and check the converted
module with `tvm.ir.assert_structural_equal`. Region/subgraph tests construct the
smallest valid body/comparator/update subgraphs. Ops that use `BuiltinOptions2`
construct their options via the FlatBuffers schema API.
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py -k stablehlo -q
```
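
The structural tests compare graphs, not numeric outputs. As a TVM-independent sanity check of the elementwise formulas used for `CBRT` and `REMAINDER`, the following NumPy sketch (not part of the test suite) evaluates both formulas eagerly. It compares them against `np.cbrt` and `np.fmod`; `np.fmod` also gives a remainder that takes the sign of the dividend. The function names are hypothetical.

```python
import numpy as np


def cbrt_reference(x):
    """where(x < 0, -power(-x, 1/3), power(x, 1/3)), the CBRT formula above."""
    # Both branches are evaluated for every element, so power() of a negative
    # base yields NaN in lanes that np.where then discards; silence the warning.
    with np.errstate(invalid="ignore"):
        return np.where(x < 0, -np.power(-x, 1.0 / 3.0), np.power(x, 1.0 / 3.0))


def remainder_reference(x, y):
    """Truncating remainder x - y * trunc(x / y); the sign follows the dividend."""
    return x - y * np.trunc(x / y)


x = np.array([-27.0, -8.0, 0.0, 8.0, 27.0], dtype=np.float32)
print(np.allclose(cbrt_reference(x), np.cbrt(x)))  # True

a = np.array([5.5, -5.5, 5.5, -5.5], dtype=np.float32)
b = np.array([2.0, 2.0, -2.0, -2.0], dtype=np.float32)
print(remainder_reference(a, b).tolist())  # [1.5, -1.5, 1.5, -1.5]
print(np.allclose(remainder_reference(a, b), np.fmod(a, b)))  # True
```
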
## Result
- 39 StableHLO operators registered in the Relax TFLite frontend (29 from
  PR #19536 + 10 from this PR).
- 77 StableHLO test cases covering all registered ops, including
  structural-equality tests and unsupported/error-path checks:
  - `REMAINDER` truncating semantics
  - `DYNAMIC_UPDATE_SLICE` with dynamic starts and out-of-bounds starts
  - `DOT_GENERAL` with non-canonical contracting dimensions
  - `CONVOLUTION` with non-canonical dimension numbers and
    `FeatureGroupCount > 1`
  - `REDUCE` with an unsupported reducer and a non-identity init value
  - `SORT` with an unsupported comparator and stable sort
  - `REDUCE_WINDOW` with an unsupported reducer and base dilation
  - `SCATTER` with an unsupported reducer and update window dims
  - `COMPOSITE` with composite attributes and scope isolation
  - Multi-subgraph model with unused subgraphs
- All 77 StableHLO tests pass.
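
The non-identity init-value checks for `REDUCE` and `REDUCE_WINDOW` rely on a simple rule: the init scalar must equal the reducer's identity element. The following is a simplified sketch of that rule. `is_reducer_identity` is a hypothetical name, and unlike `_check_stablehlo_reduce_init` in the diff, it does not reject dynamic or non-scalar init tensors.

```python
import numpy as np


def is_reducer_identity(reducer_name, init):
    """Hypothetical helper: is `init` the identity element of the reducer?

    ADD -> 0, MULTIPLY -> 1, MAXIMUM -> -inf (floats) or the dtype minimum
    (ints), MINIMUM -> +inf (floats) or the dtype maximum (ints).
    """
    init = np.asarray(init)
    dtype, scalar = init.dtype, init.item()
    if reducer_name == "STABLEHLO_ADD":
        return bool(np.isclose(scalar, 0))
    if reducer_name == "STABLEHLO_MULTIPLY":
        return bool(np.isclose(scalar, 1))
    if reducer_name in ("STABLEHLO_MAXIMUM", "STABLEHLO_MINIMUM"):
        want_max = reducer_name == "STABLEHLO_MAXIMUM"
        if np.issubdtype(dtype, np.floating):
            return bool(np.isneginf(scalar) if want_max else np.isposinf(scalar))
        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            return scalar == (info.min if want_max else info.max)
        return False
    raise NotImplementedError(f"reducer {reducer_name} is not supported")


print(is_reducer_identity("STABLEHLO_ADD", np.float32(0)))  # True
print(is_reducer_identity("STABLEHLO_MAXIMUM", np.int32(-2**31)))  # True
print(is_reducer_identity("STABLEHLO_MINIMUM", np.float32(0)))  # False
```
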
## References
- Issue #19519 item I: StableHLO operators in TFLite
- PR #19536: First batch of 29 StableHLO ops
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 631 ++++++++++-
tests/python/relax/test_frontend_tflite.py | 1168 +++++++++++++++++++-
2 files changed, 1776 insertions(+), 23 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 145e953394..28b125eec0 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -244,15 +244,20 @@ class OperatorConverter:
"STABLEHLO_ADD": functools.partial(self._convert_stablehlo_binary, relax_op=_op.add),
"STABLEHLO_AND": self._convert_stablehlo_and,
"STABLEHLO_BROADCAST_IN_DIM": self._convert_stablehlo_broadcast_in_dim,
+ "STABLEHLO_CBRT": self._convert_stablehlo_cbrt,
"STABLEHLO_CLAMP": self._convert_stablehlo_clamp,
"STABLEHLO_COMPARE": self._convert_stablehlo_compare,
+ "STABLEHLO_COMPOSITE": self._convert_stablehlo_composite,
"STABLEHLO_CONCATENATE": self._convert_stablehlo_concatenate,
+ "STABLEHLO_CONVOLUTION": self._convert_stablehlo_convolution,
"STABLEHLO_CONVERT": self._convert_stablehlo_convert,
"STABLEHLO_COSINE": functools.partial(self._convert_stablehlo_unary, relax_op=_op.cos),
"STABLEHLO_DIVIDE": functools.partial(
self._convert_stablehlo_binary, relax_op=_op.divide
),
+ "STABLEHLO_DOT_GENERAL": self._convert_stablehlo_dot_general,
"STABLEHLO_DYNAMIC_SLICE": self._convert_stablehlo_dynamic_slice,
+ "STABLEHLO_DYNAMIC_UPDATE_SLICE": self._convert_stablehlo_dynamic_update_slice,
"STABLEHLO_EXPONENTIAL": functools.partial(
self._convert_stablehlo_unary, relax_op=_op.exp
),
@@ -280,13 +285,18 @@ class OperatorConverter:
"STABLEHLO_POWER": functools.partial(
self._convert_stablehlo_binary, relax_op=_op.power
),
+ "STABLEHLO_REDUCE": self._convert_stablehlo_reduce,
+ "STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window,
+ "STABLEHLO_REMAINDER": self._convert_stablehlo_remainder,
"STABLEHLO_RSQRT": functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt),
+ "STABLEHLO_SCATTER": self._convert_stablehlo_scatter,
"STABLEHLO_SELECT": functools.partial(
self._convert_stablehlo_ternary, relax_op=_op.where
),
"STABLEHLO_SHIFT_LEFT": functools.partial(
self._convert_stablehlo_binary, relax_op=_op.left_shift
),
+ "STABLEHLO_SORT": self._convert_stablehlo_sort,
"STABLEHLO_SUBTRACT": functools.partial(
self._convert_stablehlo_binary, relax_op=_op.subtract
),
@@ -1483,6 +1493,413 @@ class OperatorConverter:
result.Init(op_options.Bytes, op_options.Pos)
return result
+ def _get_static_tensor_shape(self, tensor, op_name):
+ """Return a statically-known TFLite tensor shape as Python ints."""
+ try:
+ return [int(dim) for dim in self.get_tensor_shape(tensor)]
+ except (TypeError, ValueError) as err:
+ raise tvm.error.OpNotImplemented(
+ f"{op_name} requires statically-known tensor shapes"
+ ) from err
+
+ def _get_stablehlo_i64_vector(self, vector, default):
+ """Convert an optional StableHLO int64 vector field to a Python int list."""
+ if vector is None or isinstance(vector, int):
+ return list(default)
+ return [int(v) for v in vector]
+
+ def _ensure_stablehlo_float_dtype(self, expr, op_name):
+ """Return expr dtype if the StableHLO subset supports it."""
+ dtype = expr.struct_info.dtype
+ if not dtype.startswith("float"):
+ raise tvm.error.OpNotImplemented(f"{op_name} with dtype {dtype} is not supported")
+ return dtype
+
+ def _convert_stablehlo_cbrt(self, op):
+ """Convert STABLEHLO_CBRT to a sign-preserving Relax expression."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ assert len(self.get_output_tensors(op)) == 1
+
+ data = self.get_tensor_expr(input_tensors[0])
+ dtype = self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CBRT")
+ zero = relax.const(0, dtype)
+ exponent = relax.const(1.0 / 3.0, dtype)
+
+ is_negative = self.bb.normalize(relax.op.less(data, zero))
+ negative_base = self.bb.normalize(relax.op.negative(data))
+ negative_root = self.bb.normalize(relax.op.power(negative_base, exponent))
+ negative_result = self.bb.normalize(relax.op.negative(negative_root))
+ positive_result = self.bb.normalize(relax.op.power(data, exponent))
+ return self.bb.normalize(relax.op.where(is_negative, negative_result, positive_result))
+
+ def _convert_stablehlo_remainder(self, op):
+ """Convert STABLEHLO_REMAINDER to truncating remainder for float tensors."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ lhs = self.get_tensor_expr(input_tensors[0])
+ rhs = self.get_tensor_expr(input_tensors[1])
+ self._ensure_stablehlo_float_dtype(lhs, "STABLEHLO_REMAINDER")
+ self._ensure_stablehlo_float_dtype(rhs, "STABLEHLO_REMAINDER")
+
+ quotient = self.bb.normalize(relax.op.divide(lhs, rhs))
+ truncated = self.bb.normalize(relax.op.trunc(quotient))
+ product = self.bb.normalize(relax.op.multiply(rhs, truncated))
+ return self.bb.normalize(relax.op.subtract(lhs, product))
+
+ def _get_stablehlo_simple_body_op(self, body_subgraph_index, parent_op_name, input_count):
+ """Return the single operator from a simple StableHLO body subgraph."""
+ if body_subgraph_index <= 0 or body_subgraph_index >= self.model.SubgraphsLength():
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} requires a valid non-main body subgraph"
+ )
+
+ body_subgraph = self.model.Subgraphs(body_subgraph_index)
+ if (
+ body_subgraph.InputsLength() != input_count
+ or body_subgraph.OutputsLength() != 1
+ or body_subgraph.OperatorsLength() != 1
+ ):
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} only supports single-op body subgraphs"
+ )
+
+ return body_subgraph.Operators(0)
+
+ def _check_stablehlo_reduce_init(
+ self, init_tensor, reducer_name, parent_op_name="STABLEHLO_REDUCE"
+ ):
+ """Validate that the StableHLO reduce init value matches the Relax identity."""
+ if self.has_expr(init_tensor.tensor_idx):
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} with dynamic init values is not supported"
+ )
+
+ init_value = np.asarray(self.get_tensor_value(init_tensor))
+ if init_value.shape not in [(), (1,)]:
+ raise tvm.error.OpNotImplemented(f"{parent_op_name} requires scalar init values")
+
+ dtype = init_value.dtype
+ scalar = init_value.item()
+ if reducer_name == "STABLEHLO_ADD":
+ is_identity = bool(np.isclose(scalar, 0))
+ elif reducer_name == "STABLEHLO_MULTIPLY":
+ is_identity = bool(np.isclose(scalar, 1))
+ elif reducer_name == "STABLEHLO_MAXIMUM":
+ if np.issubdtype(dtype, np.floating):
+ is_identity = bool(np.isneginf(scalar))
+ elif np.issubdtype(dtype, np.integer):
+ is_identity = scalar == np.iinfo(dtype).min
+ else:
+ is_identity = False
+ elif reducer_name == "STABLEHLO_MINIMUM":
+ if np.issubdtype(dtype, np.floating):
+ is_identity = bool(np.isposinf(scalar))
+ elif np.issubdtype(dtype, np.integer):
+ is_identity = scalar == np.iinfo(dtype).max
+ else:
+ is_identity = False
+ else:
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} reducer {reducer_name} is not supported"
+ )
+
+ if not is_identity:
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} init value must match the reducer identity"
+ )
+
+ def _convert_stablehlo_reduce(self, op):
+ """Convert the single-input STABLEHLO_REDUCE subset to Relax reductions."""
+ from tflite.StablehloReduceOptions import StablehloReduceOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloReduceOptions)
+ dimensions = self._get_stablehlo_i64_vector(opts.DimensionsAsNumpy(), [])
+ body_op = self._get_stablehlo_simple_body_op(
+ int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE", 2
+ )
+ reducer_name = self.get_op_code_str(body_op)
+
+ reducers = {
+ "STABLEHLO_ADD": relax.op.sum,
+ "STABLEHLO_MAXIMUM": relax.op.max,
+ "STABLEHLO_MINIMUM": relax.op.min,
+ "STABLEHLO_MULTIPLY": relax.op.prod,
+ }
+ if reducer_name not in reducers:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_REDUCE reducer {reducer_name} is not supported"
+ )
+
+ self._check_stablehlo_reduce_init(input_tensors[1], reducer_name)
+ data = self.get_tensor_expr(input_tensors[0])
+ return self.bb.normalize(reducers[reducer_name](data, axis=dimensions, keepdims=False))
+
+ def _convert_stablehlo_reduce_window(self, op):
+ """Convert the NHWC 2D max-pool STABLEHLO_REDUCE_WINDOW subset."""
+ from tflite.StablehloReduceWindowOptions import StablehloReduceWindowOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloReduceWindowOptions)
+ body_op = self._get_stablehlo_simple_body_op(
+ int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE_WINDOW", 2
+ )
+ reducer_name = self.get_op_code_str(body_op)
+ if reducer_name != "STABLEHLO_MAXIMUM":
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports MAXIMUM reducer windows"
+ )
+ self._check_stablehlo_reduce_init(input_tensors[1], reducer_name, "STABLEHLO_REDUCE_WINDOW")
+
+ data_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_REDUCE_WINDOW")
+ if len(data_shape) != 4:
+ raise tvm.error.OpNotImplemented("STABLEHLO_REDUCE_WINDOW only supports 4D input")
+
+ window_dimensions = self._get_stablehlo_i64_vector(opts.WindowDimensionsAsNumpy(), [])
+ window_strides = self._get_stablehlo_i64_vector(
+ opts.WindowStridesAsNumpy(), [1] * len(window_dimensions)
+ )
+ base_dilations = self._get_stablehlo_i64_vector(
+ opts.BaseDilationsAsNumpy(), [1] * len(window_dimensions)
+ )
+ window_dilations = self._get_stablehlo_i64_vector(
+ opts.WindowDilationsAsNumpy(), [1] * len(window_dimensions)
+ )
+ padding = self._get_stablehlo_i64_vector(
+ opts.PaddingAsNumpy(), [0] * (2 * len(window_dimensions))
+ )
+
+ if (
+ len(window_dimensions) != 4
+ or len(window_strides) != 4
+ or len(base_dilations) != 4
+ or len(window_dilations) != 4
+ or len(padding) != 8
+ ):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports rank-4 window attributes"
+ )
+ if window_dimensions[0] != 1 or window_dimensions[3] != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports pooling over spatial dimensions"
+ )
+ if window_strides[0] != 1 or window_strides[3] != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports unit batch/channel strides"
+ )
+ if base_dilations != [1, 1, 1, 1]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW with base dilation is not supported"
+ )
+ if padding[0] != 0 or padding[1] != 0 or padding[6] != 0 or padding[7] != 0:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports spatial padding"
+ )
+
+ data = self.get_tensor_expr(input_tensors[0])
+ return self.bb.normalize(
+ relax.op.nn.max_pool2d(
+ data,
+ pool_size=[window_dimensions[1], window_dimensions[2]],
+ strides=[window_strides[1], window_strides[2]],
+ padding=[padding[2], padding[4], padding[3], padding[5]],
+ dilation=[window_dilations[1], window_dilations[2]],
+ layout="NHWC",
+ out_layout="NHWC",
+ )
+ )
+
+ def _convert_stablehlo_scatter(self, op):
+ """Convert the canonical point-update STABLEHLO_SCATTER subset."""
+ from tflite.StablehloScatterOptions import StablehloScatterOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 3, "input tensors length should be 3"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloScatterOptions)
+ operand_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_SCATTER")
+ indices_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_SCATTER")
+ updates_shape = self._get_static_tensor_shape(input_tensors[2], "STABLEHLO_SCATTER")
+ operand_rank = len(operand_shape)
+ indices_rank = len(indices_shape)
+
+ update_window_dims = self._get_stablehlo_i64_vector(opts.UpdateWindowDimsAsNumpy(), [])
+ inserted_window_dims = self._get_stablehlo_i64_vector(opts.InsertedWindowDimsAsNumpy(), [])
+ scatter_dims_to_operand_dims = self._get_stablehlo_i64_vector(
+ opts.ScatterDimsToOperandDimsAsNumpy(), []
+ )
+ index_vector_dim = int(opts.IndexVectorDim())
+
+ if indices_rank == 0 or index_vector_dim != indices_rank - 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports trailing index-vector dimensions"
+ )
+ if update_window_dims:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports point updates without update windows"
+ )
+ if inserted_window_dims != list(range(operand_rank)):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports point updates for every operand dimension"
+ )
+ if scatter_dims_to_operand_dims != list(range(operand_rank)):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports canonical scatter-to-operand dimensions"
+ )
+ if indices_shape[-1] != operand_rank or updates_shape != indices_shape[:-1]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER requires point update shapes to match scatter indices"
+ )
+
+ body_op = self._get_stablehlo_simple_body_op(
+ int(opts.UpdateComputationSubgraphIndex()), "STABLEHLO_SCATTER", 2
+ )
+ reducer_name = self.get_op_code_str(body_op)
+ reductions = {
+ "STABLEHLO_ADD": "add",
+ "STABLEHLO_MAXIMUM": "max",
+ "STABLEHLO_MINIMUM": "min",
+ "STABLEHLO_MULTIPLY": "mul",
+ }
+ if reducer_name not in reductions:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_SCATTER reducer {reducer_name} is not supported"
+ )
+
+ operand = self.get_tensor_expr(input_tensors[0])
+ indices = self.get_tensor_expr(input_tensors[1])
+ updates = self.get_tensor_expr(input_tensors[2])
+ return self.bb.normalize(
+ relax.op.scatter_nd(operand, indices, updates, reductions[reducer_name])
+ )
+
+ def _convert_stablehlo_composite(self, op):
+ """Convert STABLEHLO_COMPOSITE by inlining a simple decomposition subgraph."""
+ from tflite.StableHLOCompositeOptions import StableHLOCompositeOptions
+
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ if len(output_tensors) != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_COMPOSITE only supports single-output decompositions"
+ )
+
+ opts = self._get_stablehlo_options(op, StableHLOCompositeOptions)
+ composite_name = opts.Name()
+ composite_name = (
+ composite_name.decode("utf-8") if composite_name is not None else "<unnamed>"
+ )
+ if opts.CompositeAttributesLength() != 0:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} with composite attributes is not supported"
+ )
+
+ decomposition_subgraph_index = int(opts.DecompositionSubgraphIndex())
+ if (
+ decomposition_subgraph_index <= 0
+ or decomposition_subgraph_index >= self.model.SubgraphsLength()
+ ):
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} requires a valid decomposition subgraph"
+ )
+ decomposition_subgraph = self.model.Subgraphs(decomposition_subgraph_index)
+ if decomposition_subgraph.InputsLength() != len(input_tensors):
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} decomposition input count mismatch"
+ )
+ if decomposition_subgraph.OutputsLength() != 1:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} only supports single-output decompositions"
+ )
+
+ decomposition_exp_tab = ExprTable()
+ decomposition_converter = OperatorConverter(
+ self.model, decomposition_subgraph, decomposition_exp_tab, self.bb
+ )
+ for decomposition_input_idx, composite_input in zip(
+ decomposition_subgraph.InputsAsNumpy(), input_tensors
+ ):
+ decomposition_input_name = get_tensor_name(
+ decomposition_subgraph, int(decomposition_input_idx)
+ )
+ decomposition_exp_tab.set_expr(
+ decomposition_input_name,
+ self.get_tensor_expr(composite_input),
+ force_override=True,
+ )
+
+ decomposition_converter.check_unsupported_ops()
+ decomposition_converter.convert_op_to_relax()
+ decomposition_output_idx = int(decomposition_subgraph.Outputs(0))
+ decomposition_output_tensor = decomposition_converter.get_tensors(
+ [decomposition_output_idx]
+ )[0]
+ for const_expr, value in decomposition_exp_tab.params.values():
+ param_name = f"_param_{self.exp_tab.const_ctr}"
+ self.exp_tab.const_ctr += 1
+ self.exp_tab.params[param_name] = (const_expr, value)
+ return decomposition_converter.get_tensor_expr(decomposition_output_tensor)
+
+ def _convert_stablehlo_sort(self, op):
+ """Convert the single-input STABLEHLO_SORT subset to Relax sort."""
+ from tflite.StablehloCompareOptions import StablehloCompareOptions
+ from tflite.StablehloComparisonDirection import StablehloComparisonDirection
+ from tflite.StablehloComparisonType import StablehloComparisonType
+ from tflite.StablehloSortOptions import StablehloSortOptions
+
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ if len(input_tensors) != 1 or len(output_tensors) != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SORT only supports single-input single-output sort"
+ )
+
+ opts = self._get_stablehlo_options(op, StablehloSortOptions)
+ if opts.IsStable():
+ raise tvm.error.OpNotImplemented("STABLEHLO_SORT stable sort is not supported")
+
+ body_op = self._get_stablehlo_simple_body_op(
+ int(opts.ComparatorSubgraphIndex()), "STABLEHLO_SORT", 2
+ )
+ comparator_name = self.get_op_code_str(body_op)
+ if comparator_name != "STABLEHLO_COMPARE":
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_SORT comparator {comparator_name} is not supported"
+ )
+
+ compare_opts = self._get_stablehlo_options(body_op, StablehloCompareOptions)
+ if (
+ compare_opts.CompareType()
+ == StablehloComparisonType.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER
+ ):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SORT with TOTALORDER comparator is not supported"
+ )
+
+ direction = compare_opts.ComparisonDirection()
+ _DIR = StablehloComparisonDirection
+ if direction == _DIR.STABLEHLO_COMPARISON_DIRECTION_LT:
+ descending = False
+ elif direction == _DIR.STABLEHLO_COMPARISON_DIRECTION_GT:
+ descending = True
+ else:
+ raise tvm.error.OpNotImplemented("STABLEHLO_SORT only supports LT or GT comparators")
+
+ data = self.get_tensor_expr(input_tensors[0])
+ return self.bb.normalize(
+ relax.op.sort(data, axis=int(opts.Dimension()), descending=descending)
+ )
+
def _convert_stablehlo_convert(self, op):
"""Convert STABLEHLO_CONVERT to Relax (astype).
@@ -1719,6 +2136,189 @@ class OperatorConverter:
return self.bb.normalize(relax.op.dynamic_strided_slice(operand, begin, end, strides))
+ def _convert_stablehlo_dynamic_update_slice(self, op):
+ """Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax for static starts."""
+ input_tensors = self.get_input_tensors(op)
+ # operand + update + N start-index scalars
+ assert len(input_tensors) >= 3, "input tensors length should be >= 3"
+ assert len(self.get_output_tensors(op)) == 1
+
+ operand_tensor = input_tensors[0]
+ update_tensor = input_tensors[1]
+ start_tensors = input_tensors[2:]
+
+ op_name = "STABLEHLO_DYNAMIC_UPDATE_SLICE"
+ operand_shape = self._get_static_tensor_shape(operand_tensor, op_name)
+ update_shape = self._get_static_tensor_shape(update_tensor, op_name)
+ rank = len(operand_shape)
+ if len(update_shape) != rank or len(start_tensors) != rank:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DYNAMIC_UPDATE_SLICE requires operand, update, "
+ "and start-index ranks to match"
+ )
+
+ if any(self.has_expr(t.tensor_idx) for t in start_tensors):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DYNAMIC_UPDATE_SLICE with dynamic start indices is not supported"
+ )
+
+ start_vals = [int(np.asarray(self.get_tensor_value(t)).item()) for t in start_tensors]
+ for start, size, dim in zip(start_vals, update_shape, operand_shape):
+ if start < 0 or start + size > dim:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DYNAMIC_UPDATE_SLICE with out-of-bounds update "
+ "indices is not supported"
+ )
+
+ update_indices = np.indices(update_shape, dtype=np.int64)
+ for axis, start in enumerate(start_vals):
+ update_indices[axis] += start
+ update_indices = np.moveaxis(update_indices, 0, -1)
+
+ operand = self.get_tensor_expr(operand_tensor)
+ update = self.get_tensor_expr(update_tensor)
+ indices = self.bb.normalize(relax.const(update_indices, dtype="int64"))
+ return self.bb.normalize(relax.op.scatter_nd(operand, indices, update, "update"))
+
+ def _convert_stablehlo_dot_general(self, op):
+ """Convert the canonical 2D STABLEHLO_DOT_GENERAL subset to Relax matmul."""
+ from tflite.StablehloDotGeneralOptions import StablehloDotGeneralOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloDotGeneralOptions)
+ lhs_batch_dims = self._get_stablehlo_i64_vector(opts.LhsBatchingDimensionsAsNumpy(), [])
+ rhs_batch_dims = self._get_stablehlo_i64_vector(opts.RhsBatchingDimensionsAsNumpy(), [])
+ lhs_contract_dims = self._get_stablehlo_i64_vector(
+ opts.LhsContractingDimensionsAsNumpy(), []
+ )
+ rhs_contract_dims = self._get_stablehlo_i64_vector(
+ opts.RhsContractingDimensionsAsNumpy(), []
+ )
+
+ lhs_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_DOT_GENERAL")
+ rhs_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_DOT_GENERAL")
+ if len(lhs_shape) != 2 or len(rhs_shape) != 2:
+ raise tvm.error.OpNotImplemented("STABLEHLO_DOT_GENERAL only supports 2D matmul")
+ if lhs_batch_dims or rhs_batch_dims:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DOT_GENERAL with batching dimensions is not supported"
+ )
+ if lhs_contract_dims != [1] or rhs_contract_dims != [0]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DOT_GENERAL only supports canonical contracting dimensions"
+ )
+
+ lhs = self.get_tensor_expr(input_tensors[0])
+ rhs = self.get_tensor_expr(input_tensors[1])
+ return self.bb.normalize(relax.op.matmul(lhs, rhs))
+
+ def _convert_stablehlo_convolution(self, op):
+ """Convert the canonical 2D NHWC/HWIO STABLEHLO_CONVOLUTION subset."""
+ from tflite.StablehloConvolutionOptions import StablehloConvolutionOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloConvolutionOptions)
+ input_spatial_dims = self._get_stablehlo_i64_vector(
+ opts.InputSpatialDimensionsAsNumpy(), []
+ )
+ kernel_spatial_dims = self._get_stablehlo_i64_vector(
+ opts.KernelSpatialDimensionsAsNumpy(), []
+ )
+ output_spatial_dims = self._get_stablehlo_i64_vector(
+ opts.OutputSpatialDimensionsAsNumpy(), []
+ )
+ if input_spatial_dims != [1, 2]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION only supports NHWC input layout"
+ )
+ if kernel_spatial_dims != [0, 1]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION only supports HWIO kernel layout"
+ )
+ if output_spatial_dims != [1, 2]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION only supports NHWC output layout"
+ )
+
+ if (
+ int(opts.InputBatchDimension()) != 0
+ or int(opts.InputFeatureDimension()) != 3
+ or int(opts.KernelInputFeatureDimension()) != 2
+ or int(opts.KernelOutputFeatureDimension()) != 3
+ or int(opts.OutputBatchDimension()) != 0
+ or int(opts.OutputFeatureDimension()) != 3
+ ):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION only supports canonical NHWC/HWIO dimension numbers"
+ )
+ if int(opts.BatchGroupCount()) != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION with batch_group_count > 1 is not supported"
+ )
+ if int(opts.FeatureGroupCount()) != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION with feature_group_count > 1 is not supported"
+ )
+
+ data_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_CONVOLUTION")
+ kernel_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_CONVOLUTION")
+ if len(data_shape) != 4 or len(kernel_shape) != 4:
+ raise tvm.error.OpNotImplemented("STABLEHLO_CONVOLUTION only supports 2D convolution")
+ if data_shape[3] != kernel_shape[2]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION input channels must match kernel input channels"
+ )
+
+ window_strides = self._get_stablehlo_i64_vector(opts.WindowStridesAsNumpy(), [1, 1])
+ padding = self._get_stablehlo_i64_vector(opts.PaddingAsNumpy(), [0, 0, 0, 0])
+ lhs_dilation = self._get_stablehlo_i64_vector(opts.LhsDilationAsNumpy(), [1, 1])
+ rhs_dilation = self._get_stablehlo_i64_vector(opts.RhsDilationAsNumpy(), [1, 1])
+ window_reversal = opts.WindowReversalAsNumpy()
+ window_reversal = (
+ [False, False] if window_reversal is None else [bool(v) for v in window_reversal]
+ )
+
+ if len(window_strides) != 2 or len(rhs_dilation) != 2:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION only supports two spatial dimensions"
+ )
+ if lhs_dilation != [1, 1]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION with lhs dilation is not supported"
+ )
+ if any(window_reversal):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION with window reversal is not supported"
+ )
+ if len(padding) != 4:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_CONVOLUTION only supports 2D low/high padding"
+ )
+
+ # StableHLO stores padding as [low_h, high_h, low_w, high_w].
+ relax_padding = [padding[0], padding[2], padding[1], padding[3]]
+ data = self.get_tensor_expr(input_tensors[0])
+ kernel = self.get_tensor_expr(input_tensors[1])
+ self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CONVOLUTION")
+ self._ensure_stablehlo_float_dtype(kernel, "STABLEHLO_CONVOLUTION")
+ return self.bb.normalize(
+ relax.op.nn.conv2d(
+ data,
+ kernel,
+ strides=window_strides,
+ padding=relax_padding,
+ dilation=rhs_dilation,
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ )
+ )
+
def _convert_stablehlo_gather(self, op):
"""Convert STABLEHLO_GATHER to Relax (take-equivalent subset only).
@@ -5528,19 +6128,18 @@ def _input_type(model):
assert subgraph_count > 0
shape_dict = {}
dtype_dict = {}
- for subgraph_index in range(subgraph_count):
- subgraph = model.Subgraphs(subgraph_index)
- inputs_count = subgraph.InputsLength()
- # TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models).
- for input_index in range(inputs_count):
- input_ = subgraph.Inputs(input_index)
- assert subgraph.TensorsLength() > input_
- tensor = subgraph.Tensors(input_)
- input_shape = tuple(tensor.ShapeAsNumpy())
- tensor_type = tensor.Type()
- input_name = get_tensor_name(subgraph, input_)
- shape_dict[input_name] = input_shape
- dtype_dict[input_name] = _decode_type(tensor_type)
+ subgraph = model.Subgraphs(0)
+ inputs_count = subgraph.InputsLength()
+ # TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models).
+ for input_index in range(inputs_count):
+ input_ = subgraph.Inputs(input_index)
+ assert subgraph.TensorsLength() > input_
+ tensor = subgraph.Tensors(input_)
+ input_shape = tuple(tensor.ShapeAsNumpy())
+ tensor_type = tensor.Type()
+ input_name = get_tensor_name(subgraph, input_)
+ shape_dict[input_name] = input_shape
+ dtype_dict[input_name] = _decode_type(tensor_type)
return shape_dict, dtype_dict
@@ -5652,8 +6251,10 @@ def from_tflite(
if dtype_dict is not None:
_dtype_dict.update(dtype_dict)
- # keep the same as tflite
- assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
+ # Only Subgraphs(0) is converted into Relax main. Additional subgraphs are
+ # region bodies referenced by specific TFLite ops and are consumed by those
+ # op converters as needed.
+ assert model.SubgraphsLength() >= 1, "TFLite model must contain at least one subgraph"
subgraph = model.Subgraphs(0)
# model inputs / outputs
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index bb2fb0bfa7..031c1553d8 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3677,6 +3677,9 @@ _tfl_dilate_options = _get_tflite_schema_module("DilateOptions")
# ── StableHLO BuiltinOptions2 schema modules ────────────────────────────
_tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions")
_tfl_stablehlo_bcast_opts = _get_tflite_schema_module("StablehloBroadcastInDimOptions")
+_tfl_stablehlo_composite_opts = _get_tflite_schema_module("StableHLOCompositeOptions")
+_tfl_stablehlo_conv_opts = _get_tflite_schema_module("StablehloConvolutionOptions")
+_tfl_stablehlo_dot_opts = _get_tflite_schema_module("StablehloDotGeneralOptions")
_tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions")
_tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions")
_tfl_stablehlo_comp_dir = _get_tflite_schema_module("StablehloComparisonDirection")
@@ -3684,6 +3687,10 @@ _tfl_stablehlo_comp_type = _get_tflite_schema_module("StablehloComparisonType")
_tfl_stablehlo_pad_opts = _get_tflite_schema_module("StablehloPadOptions")
_tfl_stablehlo_dyn_slice_opts = _get_tflite_schema_module("StablehloDynamicSliceOptions")
_tfl_stablehlo_gather_opts = _get_tflite_schema_module("StablehloGatherOptions")
+_tfl_stablehlo_reduce_opts = _get_tflite_schema_module("StablehloReduceOptions")
+_tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions")
+_tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions")
+_tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions")
_tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
_tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions")
_tfl_int32_vector = _get_tflite_schema_module("Int32Vector")
@@ -3721,6 +3728,20 @@ def _tflite_int32_vector(builder, start_vector_fn, values):
return builder.EndVector()
+def _tflite_int64_vector(builder, start_vector_fn, values):
+ start_vector_fn(builder, len(values))
+ for value in reversed(values):
+ builder.PrependInt64(value)
+ return builder.EndVector()
+
+
+def _tflite_bool_vector(builder, start_vector_fn, values):
+ start_vector_fn(builder, len(values))
+ for value in reversed(values):
+ builder.PrependBool(value)
+ return builder.EndVector()
+
+
def _tflite_offset_vector(builder, start_vector_fn, offsets):
start_vector_fn(builder, len(offsets))
for offset in reversed(offsets):
@@ -3834,12 +3855,15 @@ def _build_subgraph(builder, *, tensors, operators, inputs, outputs):
return _tfl_subgraph.SubGraphEnd(builder)
-def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers):
+def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers, extra_subgraphs=None):
+ all_subgraphs = [subgraph] + (extra_subgraphs or [])
buffers_vec = _tflite_offset_vector(builder, _tfl_model.ModelStartBuffersVector, buffers)
opcodes_vec = _tflite_offset_vector(
builder, _tfl_model.ModelStartOperatorCodesVector, operator_codes
)
- subgraphs_vec = _tflite_offset_vector(builder, _tfl_model.ModelStartSubgraphsVector, [subgraph])
+ subgraphs_vec = _tflite_offset_vector(
+ builder, _tfl_model.ModelStartSubgraphsVector, all_subgraphs
+ )
_tfl_model.ModelStart(builder)
_tfl_model.ModelAddBuffers(builder, buffers_vec)
@@ -3896,6 +3920,453 @@ def _build_stablehlo_model(*, builtin_name, input_count):
)
+def _build_stablehlo_model_with_unused_subgraph():
+ """Build a StableHLO model with an unused extra subgraph."""
+ builder = flatbuffers.Builder(1024)
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_ADD")
+
+ main_tensors = [_build_tensor(builder, buffer_idx, [2, 2]) for buffer_idx in range(3)]
+ main_op = _build_operator(builder, 0, [0, 1], [2])
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=[main_op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+
+ # Give the unused subgraph a conflicting input tensor name and different
+ # shape. from_tflite should infer the main function input shape only from
+ # Subgraphs(0).
+ extra_tensors = [_build_tensor(builder, buffer_idx, [4, 4]) for buffer_idx in range(3, 6)]
+ extra_op = _build_operator(builder, 0, [0, 1], [2])
+ extra_subgraph = _build_subgraph(
+ builder,
+ tensors=extra_tensors,
+ operators=[extra_op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+
+ operator_codes = [_build_operator_code(builder, builtin_op)]
+ buffers = [_build_buffer(builder) for _ in range(6)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[extra_subgraph],
+ operator_codes=operator_codes,
+ buffers=buffers,
+ )
+
+
+def _build_stablehlo_reduce_model(reducer_name, init_value):
+ """Build a single-input STABLEHLO_REDUCE model with a binary reducer body."""
+ builder = flatbuffers.Builder(1024)
+
+ dimensions_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStartDimensionsVector,
+ [1],
+ )
+ _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStart(builder)
+ _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddDimensions(builder, dimensions_vec)
+ _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddBodySubgraphIndex(builder, 1)
+ reduce_opts = _tfl_stablehlo_reduce_opts.StablehloReduceOptionsEnd(builder)
+
+ reduce_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE")
+ reducer_builtin = _get_stablehlo_builtin_operator(reducer_name)
+ reduce_code = _build_operator_code(builder, reduce_builtin)
+ reducer_code = _build_operator_code(builder, reducer_builtin)
+
+ main_tensors = [
+ _build_tensor(builder, 0, [2, 3]),
+ _build_tensor(builder, 1, []),
+ _build_tensor(builder, 2, [2]),
+ ]
+ reduce_op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloReduceOptions,
+ builtin_options2=reduce_opts,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=[reduce_op],
+ inputs=[0],
+ outputs=[2],
+ )
+
+ body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)]
+ reducer_op = _build_operator(builder, 1, [0, 1], [2])
+ body_subgraph = _build_subgraph(
+ builder,
+ tensors=body_tensors,
+ operators=[reducer_op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[body_subgraph],
+ operator_codes=[reduce_code, reducer_code],
+ buffers=buffers,
+ )
+
+
+def _build_stablehlo_sort_model(comparison_direction, is_stable=False):
+ """Build a single-input STABLEHLO_SORT model with a compare body."""
+ builder = flatbuffers.Builder(1024)
+
+ _tfl_stablehlo_sort_opts.StablehloSortOptionsStart(builder)
+ _tfl_stablehlo_sort_opts.StablehloSortOptionsAddDimension(builder, 1)
+ _tfl_stablehlo_sort_opts.StablehloSortOptionsAddIsStable(builder, is_stable)
+ _tfl_stablehlo_sort_opts.StablehloSortOptionsAddComparatorSubgraphIndex(builder, 1)
+ sort_opts = _tfl_stablehlo_sort_opts.StablehloSortOptionsEnd(builder)
+
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder)
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(
+ builder, comparison_direction
+ )
+ compare_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder)
+
+ sort_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SORT")
+ compare_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE")
+ sort_code = _build_operator_code(builder, sort_builtin)
+ compare_code = _build_operator_code(builder, compare_builtin)
+
+ main_tensors = [
+ _build_tensor(builder, 0, [2, 3]),
+ _build_tensor(builder, 1, [2, 3]),
+ ]
+ sort_op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1],
+ builtin_options2_type=_tfl_builtin_options2.StablehloSortOptions,
+ builtin_options2=sort_opts,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=[sort_op],
+ inputs=[0],
+ outputs=[1],
+ )
+
+ body_tensors = [
+ _build_tensor(builder, 2, []),
+ _build_tensor(builder, 3, []),
+ _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.BOOL),
+ ]
+ compare_op = _build_operator(
+ builder,
+ 1,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions,
+ builtin_options2=compare_opts,
+ )
+ body_subgraph = _build_subgraph(
+ builder,
+ tensors=body_tensors,
+ operators=[compare_op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+
+ buffers = [_build_buffer(builder) for _ in range(5)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[body_subgraph],
+ operator_codes=[sort_code, compare_code],
+ buffers=buffers,
+ )
+
+
+def _build_stablehlo_reduce_window_model(
+ reducer_name="STABLEHLO_MAXIMUM",
+ init_value=-np.inf,
+ base_dilations=None,
+):
+ """Build an NHWC 2D STABLEHLO_REDUCE_WINDOW model."""
+ builder = flatbuffers.Builder(1024)
+ if base_dilations is None:
+ base_dilations = [1, 1, 1, 1]
+
+ window_dimensions_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDimensionsVector,
+ [1, 2, 2, 1],
+ )
+ window_strides_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowStridesVector,
+ [1, 2, 2, 1],
+ )
+ base_dilations_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartBaseDilationsVector,
+ base_dilations,
+ )
+ window_dilations_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDilationsVector,
+ [1, 1, 1, 1],
+ )
+ padding_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartPaddingVector,
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ )
+
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStart(builder)
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDimensions(
+ builder, window_dimensions_vec
+ )
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowStrides(
+ builder, window_strides_vec
+ )
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBaseDilations(
+ builder, base_dilations_vec
+ )
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDilations(
+ builder, window_dilations_vec
+ )
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddPadding(builder, padding_vec)
+ _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBodySubgraphIndex(builder, 1)
+ reduce_window_opts = _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsEnd(builder)
+
+ reduce_window_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE_WINDOW")
+ reducer_builtin = _get_stablehlo_builtin_operator(reducer_name)
+ reduce_window_code = _build_operator_code(builder, reduce_window_builtin)
+ reducer_code = _build_operator_code(builder, reducer_builtin)
+
+ main_tensors = [
+ _build_tensor(builder, 0, [1, 4, 4, 1]),
+ _build_tensor(builder, 1, []),
+ _build_tensor(builder, 2, [1, 2, 2, 1]),
+ ]
+ reduce_window_op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloReduceWindowOptions,
+ builtin_options2=reduce_window_opts,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=[reduce_window_op],
+ inputs=[0],
+ outputs=[2],
+ )
+
+ body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)]
+ reducer_op = _build_operator(builder, 1, [0, 1], [2])
+ body_subgraph = _build_subgraph(
+ builder,
+ tensors=body_tensors,
+ operators=[reducer_op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ _build_buffer(builder),
+ ]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[body_subgraph],
+ operator_codes=[reduce_window_code, reducer_code],
+ buffers=buffers,
+ )
+
+
+def _build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_dims=None):
+ """Build a canonical point-update STABLEHLO_SCATTER model."""
+ builder = flatbuffers.Builder(1024)
+ if update_window_dims is None:
+ update_window_dims = []
+
+ update_window_dims_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartUpdateWindowDimsVector,
+ update_window_dims,
+ )
+ inserted_window_dims_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartInsertedWindowDimsVector,
+ [0],
+ )
+ scatter_dims_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartScatterDimsToOperandDimsVector,
+ [0],
+ )
+
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStart(builder)
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateWindowDims(
+ builder, update_window_dims_vec
+ )
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddInsertedWindowDims(
+ builder, inserted_window_dims_vec
+ )
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddScatterDimsToOperandDims(
+ builder, scatter_dims_vec
+ )
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddIndexVectorDim(builder, 1)
+ _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateComputationSubgraphIndex(builder, 1)
+ scatter_opts = _tfl_stablehlo_scatter_opts.StablehloScatterOptionsEnd(builder)
+
+ scatter_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SCATTER")
+ reducer_builtin = _get_stablehlo_builtin_operator(reducer_name)
+ scatter_code = _build_operator_code(builder, scatter_builtin)
+ reducer_code = _build_operator_code(builder, reducer_builtin)
+
+ main_tensors = [
+ _build_tensor(builder, 0, [4]),
+ _build_tensor(builder, 1, [2, 1], tensor_type=_tfl_tensor_type.INT32),
+ _build_tensor(builder, 2, [2]),
+ _build_tensor(builder, 3, [4]),
+ ]
+ scatter_op = _build_operator(
+ builder,
+ 0,
+ [0, 1, 2],
+ [3],
+ builtin_options2_type=_tfl_builtin_options2.StablehloScatterOptions,
+ builtin_options2=scatter_opts,
+ )
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=[scatter_op],
+ inputs=[0, 1, 2],
+ outputs=[3],
+ )
+
+ body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(4, 7)]
+ reducer_op = _build_operator(builder, 1, [0, 1], [2])
+ body_subgraph = _build_subgraph(
+ builder,
+ tensors=body_tensors,
+ operators=[reducer_op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+
+ buffers = [_build_buffer(builder) for _ in range(7)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[body_subgraph],
+ operator_codes=[scatter_code, reducer_code],
+ buffers=buffers,
+ )
+
+
+def _build_stablehlo_composite_model(with_attributes=False, use_main_input_after_composite=False):
+ """Build a STABLEHLO_COMPOSITE model that decomposes to STABLEHLO_NEGATE."""
+ builder = flatbuffers.Builder(1024)
+
+ name = builder.CreateString("test.negate")
+ attributes = None
+ if with_attributes:
+ _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStartCompositeAttributesVector(
+ builder, 1
+ )
+ builder.PrependUint8(1)
+ attributes = builder.EndVector()
+
+ _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStart(builder)
+ _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddName(builder, name)
+ _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddVersion(builder, 1)
+ _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddDecompositionSubgraphIndex(builder, 1)
+ if attributes is not None:
+ _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddCompositeAttributes(
+ builder, attributes
+ )
+ composite_opts = _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsEnd(builder)
+
+ composite_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPOSITE")
+ negate_builtin = _get_stablehlo_builtin_operator("STABLEHLO_NEGATE")
+ add_builtin = _get_stablehlo_builtin_operator("STABLEHLO_ADD")
+ composite_code = _build_operator_code(builder, composite_builtin)
+ negate_code = _build_operator_code(builder, negate_builtin)
+ add_code = _build_operator_code(builder, add_builtin)
+
+ main_tensors = [
+ _build_tensor(builder, 0, [2, 2]),
+ _build_tensor(builder, 1, [2, 2]),
+ _build_tensor(builder, 2, [2, 2]),
+ ]
+ composite_op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1],
+ builtin_options2_type=_tfl_builtin_options2.StableHLOCompositeOptions,
+ builtin_options2=composite_opts,
+ )
+ main_ops = [composite_op]
+ main_outputs = [1]
+ if use_main_input_after_composite:
+ main_ops.append(_build_operator(builder, 2, [0, 1], [2]))
+ main_outputs = [2]
+
+ main_subgraph = _build_subgraph(
+ builder,
+ tensors=main_tensors,
+ operators=main_ops,
+ inputs=[0],
+ outputs=main_outputs,
+ )
+
+ decomposition_tensors = [
+ _build_tensor(builder, 2, [2, 2]),
+ _build_tensor(builder, 3, [2, 2]),
+ ]
+ negate_op = _build_operator(builder, 1, [0], [1])
+ decomposition_subgraph = _build_subgraph(
+ builder,
+ tensors=decomposition_tensors,
+ operators=[negate_op],
+ inputs=[0],
+ outputs=[1],
+ )
+
+ buffers = [_build_buffer(builder) for _ in range(4)]
+ return _finish_tflite_model(
+ builder,
+ subgraph=main_subgraph,
+ extra_subgraphs=[decomposition_subgraph],
+ operator_codes=[composite_code, negate_code, add_code],
+ buffers=buffers,
+ )
+
+
def _build_stablehlo_typed_binary_model(*, builtin_name, tensor_type):
"""Build a minimal TFLite StableHLO binary model with the requested tensor type."""
builder = flatbuffers.Builder(1024)
@@ -3972,19 +4443,302 @@ def test_stablehlo_binary(builtin_name, relax_op):
@I.ir_module
class Expected:
@R.function
- def main(
- x: R.Tensor((2, 2), dtype="float32"),
- y: R.Tensor((2, 2), dtype="float32"),
- ) -> R.Tensor((2, 2), dtype="float32"):
- R.func_attr({"num_input": 2})
+ def main(
+ x: R.Tensor((2, 2), dtype="float32"),
+ y: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_model_with_unused_subgraph():
+ """TFLite StableHLO import ignores unused non-main subgraphs."""
+ mod = _load_model_from_buffer(_build_stablehlo_model_with_unused_subgraph())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 2), dtype="float32"),
+ y: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="float32") = R.add(x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+@pytest.mark.parametrize(
+ "reducer_name, init_value, relax_op",
+ [
+ ("STABLEHLO_ADD", 0.0, R.sum),
+ ("STABLEHLO_MAXIMUM", -np.inf, R.max),
+ ("STABLEHLO_MINIMUM", np.inf, R.min),
+ ("STABLEHLO_MULTIPLY", 1.0, R.prod),
+ ],
+)
+def test_stablehlo_reduce(reducer_name, init_value, relax_op):
+ """TFLite StableHLO REDUCE with simple binary reducer body subgraphs."""
+ mod = _load_model_from_buffer(_build_stablehlo_reduce_model(reducer_name, init_value))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2,), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2,), dtype="float32") = relax_op(x, axis=[1], keepdims=False)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
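In the expected IR, the Relax reductions (`R.sum`, `R.max`, `R.min`, `R.prod`) are called with only `axis` and `keepdims`, so the StableHLO init value has nowhere to go unless it is the reducer's identity. A NumPy sketch (not TVM code) of why the identities used in the test table are safe to drop and why the rejected case (ADD with `1.0`) is not; `_REDUCERS` and `stablehlo_reduce` are hypothetical names.

```python
import numpy as np

# Hypothetical stand-in for the body-op dispatch: reducer name -> (ufunc, identity).
_REDUCERS = {
    "STABLEHLO_ADD": (np.add, 0.0),
    "STABLEHLO_MAXIMUM": (np.maximum, -np.inf),
    "STABLEHLO_MINIMUM": (np.minimum, np.inf),
    "STABLEHLO_MULTIPLY": (np.multiply, 1.0),
}


def stablehlo_reduce(x, reducer_name, init_value, axis):
    """StableHLO-style reduce: init_value is folded in alongside the elements."""
    ufunc, _ = _REDUCERS[reducer_name]
    return ufunc.reduce(x, axis=axis, initial=init_value)


x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
for name, (ufunc, identity) in _REDUCERS.items():
    # With the identity as init, the result matches an init-free reduction.
    assert np.array_equal(stablehlo_reduce(x, name, identity, 1), ufunc.reduce(x, axis=1))

# A non-identity init (ADD with 1.0) shifts every output, so it cannot be dropped.
print(stablehlo_reduce(x, "STABLEHLO_ADD", 1.0, 1).tolist())  # [7.0, 16.0]
```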
+def test_stablehlo_reduce_unsupported_reducer():
+ """TFLite StableHLO REDUCE rejects unsupported body reducer ops."""
+ buf = _build_stablehlo_reduce_model("STABLEHLO_SUBTRACT", 0.0)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="reducer"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_reduce_non_identity_init_unsupported():
+ """TFLite StableHLO REDUCE rejects init values that Relax reductions cannot express."""
+ buf = _build_stablehlo_reduce_model("STABLEHLO_ADD", 1.0)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="init value"):
+ from_tflite(tflite_model)
+
+
+@pytest.mark.parametrize(
+ "comparison_direction, descending",
+ [
+ (
+ _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT,
+ False,
+ ),
+ (
+ _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT,
+ True,
+ ),
+ ],
+)
+def test_stablehlo_sort(comparison_direction, descending):
+ """TFLite StableHLO SORT with LT/GT scalar compare body subgraphs."""
+ mod = _load_model_from_buffer(_build_stablehlo_sort_model(comparison_direction))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2, 3), dtype="float32") = R.sort(x, axis=1, descending=descending)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
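The SORT converter maps the scalar compare body onto the `descending` flag of `R.sort`: an LT body yields ascending order, a GT body yields descending order, and other directions are rejected. A plain-Python sketch of that mapping, using string stand-ins for the `StablehloComparisonDirection` enum (an assumption for readability); it checks that sorting with the comparator itself agrees with the flag.

```python
import functools


def descending_for(direction):
    """Map a compare-body direction to an R.sort-style `descending` flag."""
    if direction == "LT":
        return False
    if direction == "GT":
        return True
    # Other directions (e.g. EQ) are rejected, matching the "LT or GT" error in the tests.
    raise NotImplementedError("sort comparator must be LT or GT")


def sort_with_comparator(values, direction):
    """Sort by using the compare body directly as a strict 'comes before' test."""
    before = {"LT": lambda a, b: a < b, "GT": lambda a, b: a > b}[direction]

    def cmp(a, b):
        if before(a, b):
            return -1
        if before(b, a):
            return 1
        return 0

    return sorted(values, key=functools.cmp_to_key(cmp))


row = [3.0, 1.0, 2.0]
for direction in ("LT", "GT"):
    assert sort_with_comparator(row, direction) == sorted(row, reverse=descending_for(direction))
```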
+def test_stablehlo_sort_unsupported_comparator():
+ """TFLite StableHLO SORT rejects non-ordering comparators."""
+ _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection
+ buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_EQ)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="LT or GT"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_sort_stable_unsupported():
+ """TFLite StableHLO SORT rejects stable sort until Relax exposes that contract."""
+ _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection
+ buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_LT, is_stable=True)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="stable sort"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_reduce_window_max_pool2d():
+ """TFLite StableHLO REDUCE_WINDOW max reducer lowers to NHWC max_pool2d."""
+ mod = _load_model_from_buffer(_build_stablehlo_reduce_window_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 4, 4, 1), dtype="float32"),
+ ) -> R.Tensor((1, 2, 2, 1), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 2, 2, 1), dtype="float32") = R.nn.max_pool2d(
+ x,
+ pool_size=[2, 2],
+ strides=[2, 2],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ ceil_mode=False,
+ layout="NHWC",
+ out_layout="NHWC",
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
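For the configuration this test builds (window and strides `[1, 2, 2, 1]`, zero padding, unit dilations, a MAXIMUM body with a `-inf` init), REDUCE_WINDOW behaves as an ordinary 2x2, stride-2 NHWC max pool. A NumPy reference sketch covering only that exact configuration, useful for checking the expected `(1, 2, 2, 1)` output; `reduce_window_max_2x2` is a hypothetical name.

```python
import numpy as np


def reduce_window_max_2x2(x, init_value=-np.inf):
    """Reference for window/strides [1, 2, 2, 1] on NHWC input, no padding or dilation."""
    n, h, w, c = x.shape
    if h % 2 or w % 2:
        raise ValueError("sketch assumes even spatial dims")
    # Split H and W into (blocks, 2) and reduce over each 2x2 block.
    windows = x.reshape(n, h // 2, 2, w // 2, 2, c)
    # The init value joins every window; -inf is the identity for max, so it is a no-op.
    return np.maximum(windows.max(axis=(2, 4)), init_value)


x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
y = reduce_window_max_2x2(x)
print(y.shape)                 # (1, 2, 2, 1)
print(y[0, :, :, 0].tolist())  # [[5.0, 7.0], [13.0, 15.0]]
```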
+def test_stablehlo_reduce_window_unsupported_reducer():
+ """TFLite StableHLO REDUCE_WINDOW rejects non-max reducers in the pool subset."""
+ buf = _build_stablehlo_reduce_window_model(reducer_name="STABLEHLO_ADD", init_value=0.0)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="MAXIMUM"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_reduce_window_base_dilation_unsupported():
+ """TFLite StableHLO REDUCE_WINDOW rejects base dilation in the pool subset."""
+ buf = _build_stablehlo_reduce_window_model(base_dilations=[1, 2, 1, 1])
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="base dilation"):
+ from_tflite(tflite_model)
+
+
+@pytest.mark.parametrize(
+ "reducer_name, reduction",
+ [
+ ("STABLEHLO_ADD", "add"),
+ ("STABLEHLO_MAXIMUM", "max"),
+ ("STABLEHLO_MINIMUM", "min"),
+ ("STABLEHLO_MULTIPLY", "mul"),
+ ],
+)
+def test_stablehlo_scatter(reducer_name, reduction):
+ """TFLite StableHLO SCATTER point updates lower to Relax scatter_nd."""
+ mod = _load_model_from_buffer(_build_stablehlo_scatter_model(reducer_name))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ operand: R.Tensor((4,), dtype="float32"),
+ indices: R.Tensor((2, 1), dtype="int32"),
+ updates: R.Tensor((2,), dtype="float32"),
+ ) -> R.Tensor((4,), dtype="float32"):
+ R.func_attr({"num_input": 3})
+ with R.dataflow():
+ gv: R.Tensor((4,), dtype="float32") = R.scatter_nd(
+ operand, indices, updates, reduction=reduction
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
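The scatter test covers only point updates: `inserted_window_dims=[0]`, `scatter_dims_to_operand_dims=[0]`, `index_vector_dim=1`, and no update window dims. Under those options each row of `indices` names exactly one operand element, and the body reducer decides how the update is combined with it. A NumPy sketch of those semantics (not the Relax `scatter_nd` lowering) with a hypothetical reducer table; the `(4,)` operand, `(2, 1)` indices, and `(2,)` updates match the shapes in the test model.

```python
import numpy as np

# Hypothetical mapping from update-computation body op to a NumPy ufunc;
# the test pairs these with scatter_nd reductions "add", "max", "min", "mul".
_SCATTER_UFUNCS = {
    "STABLEHLO_ADD": np.add,
    "STABLEHLO_MAXIMUM": np.maximum,
    "STABLEHLO_MINIMUM": np.minimum,
    "STABLEHLO_MULTIPLY": np.multiply,
}


def point_scatter(operand, indices, updates, reducer_name):
    """Point-update scatter: each row of `indices` is one full operand coordinate."""
    if reducer_name not in _SCATTER_UFUNCS:
        raise NotImplementedError(f"unsupported scatter reducer {reducer_name}")
    out = np.array(operand, copy=True)
    # ufunc.at applies updates unbuffered, so repeated indices would accumulate.
    _SCATTER_UFUNCS[reducer_name].at(out, tuple(np.asarray(indices).T), updates)
    return out


operand = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
indices = np.array([[1], [3]], dtype=np.int32)
updates = np.array([10.0, 20.0], dtype=np.float32)
print(point_scatter(operand, indices, updates, "STABLEHLO_ADD").tolist())  # [1.0, 12.0, 3.0, 24.0]
```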
+def test_stablehlo_scatter_unsupported_reducer():
+ """TFLite StableHLO SCATTER rejects unsupported update computation ops."""
+ buf = _build_stablehlo_scatter_model(reducer_name="STABLEHLO_SUBTRACT")
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="reducer"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_scatter_update_window_unsupported():
+ """TFLite StableHLO SCATTER rejects slice update windows in the point subset."""
+ buf = _build_stablehlo_scatter_model(update_window_dims=[0])
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="point updates"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_composite():
+ """TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph."""
+ mod = _load_model_from_buffer(_build_stablehlo_composite_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="float32") = R.negative(x)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_composite_does_not_overwrite_main_bindings():
+ """TFLite StableHLO COMPOSITE decomposition tensor names are scoped locally."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_composite_model(use_main_input_after_composite=True)
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 1})
with R.dataflow():
- gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y)
+ lv: R.Tensor((2, 2), dtype="float32") = R.negative(x)
+ gv: R.Tensor((2, 2), dtype="float32") = R.add(x, lv)
R.output(gv)
return gv
tvm.ir.assert_structural_equal(mod, Expected)
+def test_stablehlo_composite_attributes_unsupported():
+ """TFLite StableHLO COMPOSITE rejects attributes until they are parsed."""
+ buf = _build_stablehlo_composite_model(with_attributes=True)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="composite attributes"):
+ from_tflite(tflite_model)
+
+
@pytest.mark.parametrize(
"builtin_name, relax_op, dtype, tensor_type",
[
@@ -4987,6 +5741,404 @@ def test_stablehlo_dynamic_slice_out_of_bounds_unsupported():
from_tflite(tflite_model)
+def test_stablehlo_cbrt():
+ """TFLite StableHLO CBRT uses a sign-preserving composite expression."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_model(builtin_name="STABLEHLO_CBRT", input_count=1)
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.negative(x)
+ lv1: R.Tensor((2, 2), dtype="float32") = R.power(lv, R.const(1.0 / 3.0, "float32"))
+ lv2: R.Tensor((2, 2), dtype="bool") = R.less(x, R.const(0, "float32"))
+ lv3: R.Tensor((2, 2), dtype="float32") = R.negative(lv1)
+ lv4: R.Tensor((2, 2), dtype="float32") = R.power(x, R.const(1.0 / 3.0, "float32"))
+ gv: R.Tensor((2, 2), dtype="float32") = R.where(lv2, lv3, lv4)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
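Real-valued `power(x, 1/3)` is NaN for negative `x`, which is why the CBRT lowering routes negative inputs through `-power(-x, 1/3)` and picks a branch with `where`. A NumPy sketch of the same expression, compared against `np.cbrt` as a reference; it checks the math only and is not the Relax lowering. `signed_cbrt` is a hypothetical name.

```python
import numpy as np


def signed_cbrt(x):
    """where(x < 0, -power(-x, 1/3), power(x, 1/3)), evaluated in float32."""
    x = np.asarray(x, dtype=np.float32)
    third = np.float32(1.0 / 3.0)
    # Both branches are computed eagerly; the branch that is not selected may be NaN.
    with np.errstate(invalid="ignore"):
        neg_branch = -np.power(-x, third)
        pos_branch = np.power(x, third)
    return np.where(x < 0, neg_branch, pos_branch)


x = np.array([[-8.0, -1.0], [0.0, 27.0]], dtype=np.float32)
assert np.allclose(signed_cbrt(x), np.cbrt(x), atol=1e-5)
```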
+def test_stablehlo_remainder():
+ """TFLite StableHLO REMAINDER uses truncating remainder semantics."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_model(builtin_name="STABLEHLO_REMAINDER", input_count=2)
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 2), dtype="float32"),
+ y: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="float32") = R.divide(x, y)
+ lv1: R.Tensor((2, 2), dtype="float32") = R.trunc(lv)
+ lv2: R.Tensor((2, 2), dtype="float32") = R.multiply(y, lv1)
+ gv: R.Tensor((2, 2), dtype="float32") = R.subtract(x, lv2)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
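`x - y * trunc(x / y)` gives the result the sign of the dividend, which differs from the floor-based `%` of Python and `np.mod` whenever the operands have opposite signs. A NumPy sketch comparing the expression with `np.fmod`, which follows the same C-style sign convention, and with `np.mod`. The two agree on these sample inputs; `np.fmod` is computed exactly while the expression rounds `x / y`, so bit-for-bit agreement on all inputs is not implied.

```python
import numpy as np


def truncating_remainder(x, y):
    """x - y * trunc(x / y): the sign of the result follows the dividend x."""
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    return x - y * np.trunc(x / y)


x = np.array([[7.0, -7.0], [7.0, -7.0]], dtype=np.float32)
y = np.array([[3.0, 3.0], [-3.0, -3.0]], dtype=np.float32)
print(truncating_remainder(x, y).tolist())  # [[1.0, -1.0], [1.0, -1.0]]
print(np.mod(x, y).tolist())                # [[1.0, 2.0], [-2.0, -1.0]]
```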
+def _build_stablehlo_dynamic_update_slice_model(start_vals, dynamic_starts=False):
+ """Build a minimal STABLEHLO_DYNAMIC_UPDATE_SLICE model."""
+ builder = flatbuffers.Builder(1024)
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_UPDATE_SLICE")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_operand = _build_tensor(builder, 0, [3, 4])
+ t_update = _build_tensor(builder, 1, [2, 2])
+ start_tensors = [
+ _build_tensor(builder, 2 + i, [], tensor_type=_tfl_tensor_type.INT32)
+ for i in range(len(start_vals))
+ ]
+ out_idx = 2 + len(start_vals)
+ t_out = _build_tensor(builder, out_idx, [3, 4])
+ tensors = [t_operand, t_update, *start_tensors, t_out]
+
+ op_inputs = [0, 1, *range(2, out_idx)]
+ op = _build_operator(builder, 0, op_inputs, [out_idx])
+ subgraph_inputs = op_inputs if dynamic_starts else [0, 1]
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[op],
+ inputs=subgraph_inputs,
+ outputs=[out_idx],
+ )
+ if dynamic_starts:
+ buffers = [_build_buffer(builder) for _ in range(out_idx + 1)]
+ else:
+ start_buffers = [
+ _build_buffer(builder, np.array([start], dtype=np.int32).tobytes())
+ for start in start_vals
+ ]
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder),
+ *start_buffers,
+ _build_buffer(builder),
+ ]
+
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_dynamic_update_slice():
+ """TFLite StableHLO DYNAMIC_UPDATE_SLICE with static starts."""
+ mod = _load_model_from_buffer(_build_stablehlo_dynamic_update_slice_model([1, 1]))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ operand: R.Tensor((3, 4), dtype="float32"),
+ update: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((3, 4), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd(
+ operand,
+ R.const([[[1, 1], [1, 2]], [[2, 1], [2, 2]]], dtype="int64"),
+ update,
+ reduction="update",
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
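The expected `R.scatter_nd` indices constant is the coordinate grid of the update window, offset by the static start indices. Below is a NumPy sketch of how such a grid can be built with `np.indices` and how the same bounds rule rejects the `[2, 3]` case. It is not the converter's code: `update_slice_coords` and `dynamic_update_slice_ref` are hypothetical helpers, and they raise `ValueError` where the real frontend raises `tvm.error.OpNotImplemented`.

```python
import numpy as np


def update_slice_coords(update_shape, starts, operand_shape):
    starts = np.asarray(starts, dtype=np.int64)
    # Mirror the static-only subset: the update window must lie fully inside
    # the operand (the TVM converter raises OpNotImplemented here instead).
    if np.any(starts < 0) or np.any(starts + np.asarray(update_shape) > np.asarray(operand_shape)):
        raise ValueError("out-of-bounds dynamic_update_slice")
    # np.indices returns one grid per axis, stacked on axis 0; moving that axis
    # last gives each update element its full coordinate tuple, which is the
    # layout scatter_nd expects for its indices.
    grid = np.moveaxis(np.indices(update_shape, dtype=np.int64), 0, -1)
    return grid + starts


def dynamic_update_slice_ref(operand, update, starts):
    coords = update_slice_coords(update.shape, starts, operand.shape)
    out = operand.copy()
    # "update" reduction: overwrite the targeted elements, no accumulation.
    out[tuple(np.moveaxis(coords, -1, 0))] = update
    return out
```

With `update_shape=(2, 2)`, `starts=[1, 1]`, and `operand_shape=(3, 4)`, the grid equals the `R.const` in the expected module above.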
+def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported():
+ """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported."""
+ buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported():
+ """TFLite StableHLO DYNAMIC_UPDATE_SLICE rejects out-of-bounds updates."""
+ buf = _build_stablehlo_dynamic_update_slice_model([2, 3])
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"):
+ from_tflite(tflite_model)
+
+
+def _build_stablehlo_dot_general_model(lhs_contract, rhs_contract, lhs_batch=None, rhs_batch=None):
+ """Build a minimal STABLEHLO_DOT_GENERAL model."""
+ builder = flatbuffers.Builder(1024)
+ lhs_batch = [] if lhs_batch is None else lhs_batch
+ rhs_batch = [] if rhs_batch is None else rhs_batch
+
+ lhs_batch_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector,
+ lhs_batch,
+ )
+ rhs_batch_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector,
+ rhs_batch,
+ )
+ lhs_contract_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsContractingDimensionsVector,
+ lhs_contract,
+ )
+ rhs_contract_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsContractingDimensionsVector,
+ rhs_contract,
+ )
+
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStart(builder)
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsBatchingDimensions(
+ builder, lhs_batch_vec
+ )
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsBatchingDimensions(
+ builder, rhs_batch_vec
+ )
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsContractingDimensions(
+ builder, lhs_contract_vec
+ )
+ _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsContractingDimensions(
+ builder, rhs_contract_vec
+ )
+ dot_opts = _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DOT_GENERAL")
+ op_code = _build_operator_code(builder, builtin_op)
+ t_lhs = _build_tensor(builder, 0, [2, 3])
+ t_rhs = _build_tensor(builder, 1, [3, 4])
+ t_out = _build_tensor(builder, 2, [2, 4])
+ op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloDotGeneralOptions,
+ builtin_options2=dot_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=[t_lhs, t_rhs, t_out],
+ operators=[op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+ buffers = [_build_buffer(builder) for _ in range(3)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_dot_general():
+ """TFLite StableHLO DOT_GENERAL canonical 2D matmul."""
+ mod = _load_model_from_buffer(_build_stablehlo_dot_general_model([1], [0]))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ lhs: R.Tensor((2, 3), dtype="float32"),
+ rhs: R.Tensor((3, 4), dtype="float32"),
+ ) -> R.Tensor((2, 4), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((2, 4), dtype="float32") = R.matmul(lhs, rhs, out_dtype="void")
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
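The DOT_GENERAL tests accept only the canonical 2D case (`lhs_contracting=[1]`, `rhs_contracting=[0]`, no batching dims), which is exactly a matrix product. A NumPy sketch makes that equivalence concrete using `np.tensordot`. It is illustrative only: `dot_general_2d` is a hypothetical name, and `NotImplementedError` stands in for the frontend's `OpNotImplemented`.

```python
import numpy as np


def dot_general_2d(lhs, rhs, lhs_contracting, rhs_contracting):
    # Restrict to the subset the converter handles: no batching dims, lhs
    # contracts its last axis (1) and rhs contracts its first axis (0).
    if list(lhs_contracting) != [1] or list(rhs_contracting) != [0]:
        raise NotImplementedError("non-canonical contracting dims")
    # tensordot over ([1], [0]) sums lhs[i, k] * rhs[k, j], i.e. lhs @ rhs.
    return np.tensordot(lhs, rhs, axes=(lhs_contracting, rhs_contracting))
```

With the `(2, 3)` and `(3, 4)` shapes used by the test builder, the result has shape `(2, 4)` and equals `lhs @ rhs`, which is why `R.matmul` is a faithful lowering for this subset.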
+def test_stablehlo_dot_general_noncanonical_unsupported():
+ """TFLite StableHLO DOT_GENERAL rejects non-canonical contracting dims."""
+ buf = _build_stablehlo_dot_general_model([0], [0])
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="contracting"):
+ from_tflite(tflite_model)
+
+
+def _build_stablehlo_convolution_model(feature_group_count=1, input_batch_dimension=0):
+ """Build a minimal STABLEHLO_CONVOLUTION model."""
+ builder = flatbuffers.Builder(1024)
+
+ window_strides_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowStridesVector,
+ [1, 1],
+ )
+ padding_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartPaddingVector,
+ [0, 0, 0, 0],
+ )
+ lhs_dilation_vec = _tflite_int64_vector(
+ builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartLhsDilationVector, [1, 1]
+ )
+ rhs_dilation_vec = _tflite_int64_vector(
+ builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartRhsDilationVector, [1, 1]
+ )
+ window_reversal_vec = _tflite_bool_vector(
+ builder,
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowReversalVector,
+ [False, False],
+ )
+ input_spatial_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartInputSpatialDimensionsVector,
+ [1, 2],
+ )
+ kernel_spatial_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartKernelSpatialDimensionsVector,
+ [0, 1],
+ )
+ output_spatial_vec = _tflite_int64_vector(
+ builder,
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartOutputSpatialDimensionsVector,
+ [1, 2],
+ )
+
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStart(builder)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowStrides(
+ builder, window_strides_vec
+ )
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddPadding(builder, padding_vec)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddLhsDilation(builder, lhs_dilation_vec)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddRhsDilation(builder, rhs_dilation_vec)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowReversal(
+ builder, window_reversal_vec
+ )
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputBatchDimension(
+ builder, input_batch_dimension
+ )
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputFeatureDimension(builder, 3)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputSpatialDimensions(
+ builder, input_spatial_vec
+ )
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelInputFeatureDimension(builder, 2)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelOutputFeatureDimension(builder, 3)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelSpatialDimensions(
+ builder, kernel_spatial_vec
+ )
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputBatchDimension(builder, 0)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputFeatureDimension(builder, 3)
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputSpatialDimensions(
+ builder, output_spatial_vec
+ )
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddFeatureGroupCount(
+ builder, feature_group_count
+ )
+ _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddBatchGroupCount(builder, 1)
+ conv_opts = _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONVOLUTION")
+ op_code = _build_operator_code(builder, builtin_op)
+ t_data = _build_tensor(builder, 0, [1, 5, 5, 2])
+ t_kernel = _build_tensor(builder, 1, [3, 3, 2, 4])
+ t_out = _build_tensor(builder, 2, [1, 3, 3, 4])
+ op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloConvolutionOptions,
+ builtin_options2=conv_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=[t_data, t_kernel, t_out],
+ operators=[op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+ buffers = [_build_buffer(builder) for _ in range(3)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_convolution():
+ """TFLite StableHLO CONVOLUTION canonical NHWC/HWIO 2D convolution."""
+ mod = _load_model_from_buffer(_build_stablehlo_convolution_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((1, 5, 5, 2), dtype="float32"),
+ kernel: R.Tensor((3, 3, 2, 4), dtype="float32"),
+ ) -> R.Tensor((1, 3, 3, 4), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((1, 3, 3, 4), dtype="float32") = R.nn.conv2d(
+ data,
+ kernel,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ out_layout="NHWC",
+ out_dtype="void",
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
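The canonical CONVOLUTION case maps to `R.nn.conv2d` with NHWC data and an HWIO kernel, stride 1, and no padding. The naive NumPy reference below shows how those layouts line up and why a `(1, 5, 5, 2)` input with a `(3, 3, 2, 4)` kernel yields `(1, 3, 3, 4)`. It is a sketch for illustration, not TVM's implementation, and `conv2d_nhwc_hwio` is a hypothetical name.

```python
import numpy as np


def conv2d_nhwc_hwio(data, kernel, stride=1):
    # data: (N, H, W, C_in) ("NHWC"); kernel: (KH, KW, C_in, C_out) ("HWIO").
    n, h, w, c_in = data.shape
    kh, kw, k_in, c_out = kernel.shape
    if c_in != k_in:
        raise ValueError("input channels must match kernel input features")
    # No padding ("valid"): out = (in - k) // stride + 1 per spatial axis.
    oh = (h - kh) // stride + 1
    ow = (w - kw) // stride + 1
    out = np.zeros((n, oh, ow, c_out), dtype=data.dtype)
    for i in range(oh):
        for j in range(ow):
            window = data[:, i * stride:i * stride + kh, j * stride:j * stride + kw, :]
            # Contract the window (n, kh, kw, c_in) with the kernel over
            # kh, kw and c_in, leaving (n, c_out) for output pixel (i, j).
            out[:, i, j, :] = np.einsum("nhwc,hwco->no", window, kernel)
    return out
```

With all-ones inputs of the test shapes, every output element sums 3 * 3 * 2 = 18 products.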
+def test_stablehlo_convolution_feature_group_unsupported():
+ """TFLite StableHLO CONVOLUTION rejects grouped convolution in the first subset."""
+ buf = _build_stablehlo_convolution_model(feature_group_count=2)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="feature_group_count"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_convolution_dimension_numbers_unsupported():
+ """TFLite StableHLO CONVOLUTION rejects non-canonical dimension numbers."""
+ buf = _build_stablehlo_convolution_model(input_batch_dimension=1)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="dimension numbers"):
+ from_tflite(tflite_model)
+
+
def _build_csr_sparsity(
builder,
*,