This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c0406a54bc [Relax][Frontend][TFLite] Add initial StableHLO builtin operator support (#19536)
c0406a54bc is described below
commit c0406a54bc507bd77fdbe763b2d6968cafde6466
Author: HoYi <[email protected]>
AuthorDate: Mon May 11 22:05:20 2026 +0800
[Relax][Frontend][TFLite] Add initial StableHLO builtin operator support (#19536)
## Summary
This PR adds initial Relax TFLite frontend support for 29 StableHLO builtin operators from #19519 item I.

The covered subset includes pure elementwise ops, BuiltinOptions2 / metadata-based ops, simple shape-manipulation ops, and a take-equivalent subset of `STABLEHLO_GATHER`.

StableHLO builtins carry no TFLite-specific quantization or fused-activation metadata, so the implementation uses dedicated converter helpers that bypass the existing TFLite elemwise/QNN code paths.

Relates to #19519.
## Changes
1. **Zero-attribute elementwise helpers**
   - Add `_convert_stablehlo_unary`, `_convert_stablehlo_binary`, and `_convert_stablehlo_ternary` for pure elementwise mapping.
   - Register 20 ops:
     - unary: `ABS`, `NEGATE`, `COSINE`, `EXPONENTIAL`, `FLOOR`, `LOG`, `LOGISTIC`, `RSQRT`, `TANH`
     - binary: `ADD`, `SUBTRACT`, `MULTIPLY`, `DIVIDE`, `MAXIMUM`, `MINIMUM`, `POWER`
     - ternary: `SELECT` → `R.where`
     - bool/integer ops: `AND` / `OR` are dispatched on dtype by dedicated `_convert_stablehlo_and` / `_convert_stablehlo_or` helpers (logical ops for bool, bitwise ops for integers), and `SHIFT_LEFT` maps to `R.left_shift` for integers.
2. **BuiltinOptions2 infrastructure**
   - Add a `_get_stablehlo_options` helper that parses `BuiltinOptions2` flatbuffers and validates the options type via `getattr(BuiltinOptions2, options_cls.__name__)`.
   - Register 6 ops:
     - `CONVERT` → `R.astype`
     - `CLAMP` → `R.minimum(R.maximum(...))`
     - `CONCATENATE` → `R.concat`
     - `BROADCAST_IN_DIM` → `R.reshape` + `R.broadcast_to`
     - `IOTA` → `R.arange` + `R.reshape` + `R.broadcast_to`
     - `COMPARE` → the six comparison directions (`EQ`, `NE`, `GE`, `GT`, `LE`, `LT`). The `TOTALORDER` compare type raises `OpNotImplemented`.
3. **Shape-manipulation ops**
   - `PAD` → `R.nn.pad` in constant mode. The initial PAD path supports non-negative edge padding, zero interior padding, and a constant scalar padding value. Interior padding, negative padding, and dynamic padding values raise `OpNotImplemented`.
   - `DYNAMIC_SLICE` → `R.dynamic_strided_slice`. The initial path supports only constant, in-bounds start indices. Runtime start indices and StableHLO's clamping semantics for out-of-bounds starts are deferred.
4. **Indexing op**
   - `GATHER` → `R.take` for the take-equivalent subset only.
   - Parses the `StablehloGatherOptions` attributes needed to validate this subset: `offset_dims`, `collapsed_slice_dims`, `start_index_map`, `index_vector_dim`, and `slice_sizes`.
   - Validates the gather axis, collapsed dims, offset dims, slice sizes, and output shape against the expected `R.take` layout. Multi-dimensional and non-take-equivalent gather patterns raise `OpNotImplemented`.
5. **Not included**
   - `STABLEHLO_RESHAPE`, `STABLEHLO_TRANSPOSE`, and `STABLEHLO_SLICE` are left to another contributor who expressed interest in those ops.
   - The remaining StableHLO items from issue #19519 are deferred to follow-up PRs: `CBRT`, `REMAINDER`, `SCATTER`, `CONVOLUTION`, `DOT_GENERAL`, `REDUCE`, `REDUCE_WINDOW`, `DYNAMIC_UPDATE_SLICE`, `COMPOSITE`, `CUSTOM_CALL`, `RNG_BIT_GENERATOR`, `SORT`, and `WHILE`.
   - More general or multi-dimensional `STABLEHLO_GATHER` patterns are also deferred to follow-up work.
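
The `BROADCAST_IN_DIM` and `IOTA` lowerings in item 2 use the same shape trick. The data is placed on the correct output axes with size-1 placeholders everywhere else, and `broadcast_to` then expands it. Below is a minimal NumPy reference sketch of that idea, not code from this PR. `broadcast_in_dim_ref` and `iota_ref` are hypothetical names, and NumPy's `reshape` / `broadcast_to` stand in for the Relax ops. A reshape cannot permute axes, so a reshape-only lowering preserves order only when `broadcast_dims` is strictly increasing. The sketch asserts that assumption explicitly.

```python
import numpy as np


def broadcast_in_dim_ref(operand, output_shape, broadcast_dims):
    """Reference for a reshape + broadcast_to lowering of BROADCAST_IN_DIM.

    Operand dim i lands on output dim broadcast_dims[i]; every other output
    dim gets size 1 so that broadcast_to can expand it.
    """
    # Assumption: reshape alone cannot transpose, so dims must be increasing.
    assert list(broadcast_dims) == sorted(set(broadcast_dims))
    intermediate_shape = [1] * len(output_shape)
    for i, d in enumerate(broadcast_dims):
        intermediate_shape[d] = operand.shape[i]
    return np.broadcast_to(operand.reshape(intermediate_shape), output_shape)


def iota_ref(output_shape, iota_dim, dtype="int32"):
    """Reference for an arange + reshape + broadcast_to lowering of IOTA."""
    size = output_shape[iota_dim]
    ramp = np.arange(0, size, 1, dtype=dtype)  # 0, 1, ..., size - 1
    ramp_shape = [1] * len(output_shape)
    ramp_shape[iota_dim] = size
    return np.broadcast_to(ramp.reshape(ramp_shape), output_shape)


# A length-3 vector placed on output dim 1 is repeated along output dim 0.
print(broadcast_in_dim_ref(np.array([10, 20, 30]), [2, 3], [1]).tolist())
# [[10, 20, 30], [10, 20, 30]]

# IOTA along dim 0 of a [2, 3] result numbers the rows.
print(iota_ref([2, 3], 0).tolist())
# [[0, 0, 0], [1, 1, 1]]
```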
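
For the item 3 shape ops, the converter-side work is mostly index bookkeeping. The following minimal NumPy sketch uses hypothetical helper names and is not part of the PR. It shows how per-axis StableHLO edge padding flattens into the `[lo0, hi0, lo1, hi1, ...]` list that is passed to `R.nn.pad`. It also shows how constant, in-bounds `DYNAMIC_SLICE` starts become begin/end vectors with unit strides. StableHLO itself clamps out-of-range slice starts; the sketch, like the initial converter, rejects them instead.

```python
import numpy as np


def pad_width_from_edges(edge_low, edge_high, interior):
    """Flatten per-axis StableHLO edge padding into [lo0, hi0, lo1, hi1, ...]."""
    if any(d != 0 for d in interior):
        raise NotImplementedError("interior (dilation) padding")
    if any(d < 0 for d in list(edge_low) + list(edge_high)):
        raise NotImplementedError("negative edge padding (crop)")
    pad_width = []
    for lo, hi in zip(edge_low, edge_high):
        pad_width.extend([lo, hi])
    return pad_width


def dynamic_slice_bounds(start_vals, slice_sizes, operand_shape):
    """Turn constant, in-bounds start indices into begin/end/strides."""
    for start, size, dim in zip(start_vals, slice_sizes, operand_shape):
        if start < 0 or start + size > dim:
            raise NotImplementedError("out-of-bounds start index")
    end_vals = [s + sz for s, sz in zip(start_vals, slice_sizes)]
    return list(start_vals), end_vals, [1] * len(start_vals)


x = np.arange(12).reshape(3, 4)

# Pad one row below and one column on the left with the constant 0.
flat = pad_width_from_edges([0, 1], [1, 0], [0, 0])
pairs = list(zip(flat[0::2], flat[1::2]))  # NumPy wants (lo, hi) pairs
print(flat, np.pad(x, pairs, constant_values=0).shape)
# [0, 1, 1, 0] (4, 5)

# Slice a 2x2 window starting at (1, 2).
begin, end, strides = dynamic_slice_bounds([1, 2], [2, 2], x.shape)
print(x[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))].tolist())
# [[6, 7], [10, 11]]
```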
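
The take-equivalence check in item 4 accepts a gather only when all of the following hold:
- a single `start_index_map` axis is also the only collapsed dim;
- every other axis is sliced in full;
- the index vector is the trailing indices dim and has size 1;
- `offset_dims` place the untouched operand axes around the index batch axes.

Below is a minimal NumPy sketch of that check under those assumptions. The helper `take_axis_for_gather` is hypothetical, and `np.take` stands in for `R.take`. The helper returns the take axis, or `None` when the pattern falls outside the subset.

```python
import numpy as np


def take_axis_for_gather(data_shape, indices_shape, offset_dims,
                         collapsed_slice_dims, start_index_map,
                         index_vector_dim, slice_sizes):
    """Return the take axis if the gather matches the take subset, else None."""
    if len(start_index_map) != 1:
        return None
    axis = start_index_map[0]
    if not 0 <= axis < len(data_shape) or list(collapsed_slice_dims) != [axis]:
        return None
    # Full slices everywhere except a size-1 slice along the gather axis.
    if list(slice_sizes) != [1 if i == axis else d for i, d in enumerate(data_shape)]:
        return None
    # The index vector must be the trailing indices dim and hold one index.
    if not indices_shape or index_vector_dim != len(indices_shape) - 1:
        return None
    if indices_shape[index_vector_dim] != 1:
        return None
    batch_rank = len(indices_shape) - 1
    # Untouched operand axes sit before and after the index batch axes.
    expected_offset_dims = list(range(axis)) + list(
        range(axis + batch_rank, len(data_shape) + batch_rank - 1)
    )
    if list(offset_dims) != expected_offset_dims:
        return None
    return axis


data = np.arange(12).reshape(4, 3)
indices = np.array([[2], [0]])  # shape [2, 1]: two row indices
axis = take_axis_for_gather(data.shape, indices.shape, offset_dims=[1],
                            collapsed_slice_dims=[0], start_index_map=[0],
                            index_vector_dim=1, slice_sizes=[1, 3])
# Drop the size-1 index vector dim, then take along the gather axis.
print(axis, np.take(data, indices.reshape(indices.shape[:-1]), axis=axis).tolist())
# 0 [[6, 7, 8], [0, 1, 2]]
```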
## Testing
All tests build minimal TFLite flatbuffers by hand. The conversion tests compare the imported module against the expected IR with `tvm.ir.assert_structural_equal`. BuiltinOptions2 ops construct their options via the FlatBuffers schema API, modeled after the existing DILATE test pattern.
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py -k stablehlo -q
```
## Result
- 29 StableHLO operators registered in the Relax TFLite frontend.
- 44 StableHLO test cases covering all registered ops, including structural-equal tests and unsupported/error-path checks:
  - `COMPARE` with `TOTALORDER`
  - `PAD` with interior padding, negative padding, and dynamic padding values
  - `DYNAMIC_SLICE` with runtime starts and out-of-bounds starts
  - non-take-equivalent or multi-dimensional `GATHER`
- All StableHLO TFLite frontend tests pass locally.
## References
- Issue #19519 item I: StableHLO operators in TFLite
- Related PR #19481: DILATE operator mapping, the first use of BuiltinOptions2 in the TFLite frontend tests
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 493 +++++++++
tests/python/relax/test_frontend_tflite.py | 1133 ++++++++++++++++++++
2 files changed, 1626 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 376f14138b..0b71990c90 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -240,6 +240,71 @@ class OperatorConverter:
"SQRT": functools.partial(self._convert_unary_elemwise, relax_op=_op.sqrt),
"SQUARE": self.convert_square,
"SQUARED_DIFFERENCE": self.convert_squared_difference,
+ "STABLEHLO_ABS": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.abs
+ ),
+ "STABLEHLO_ADD": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.add
+ ),
+ "STABLEHLO_AND": self._convert_stablehlo_and,
+ "STABLEHLO_BROADCAST_IN_DIM": self._convert_stablehlo_broadcast_in_dim,
+ "STABLEHLO_CLAMP": self._convert_stablehlo_clamp,
+ "STABLEHLO_COMPARE": self._convert_stablehlo_compare,
+ "STABLEHLO_CONCATENATE": self._convert_stablehlo_concatenate,
+ "STABLEHLO_CONVERT": self._convert_stablehlo_convert,
+ "STABLEHLO_COSINE": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.cos
+ ),
+ "STABLEHLO_DIVIDE": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.divide
+ ),
+ "STABLEHLO_DYNAMIC_SLICE": self._convert_stablehlo_dynamic_slice,
+ "STABLEHLO_EXPONENTIAL": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.exp
+ ),
+ "STABLEHLO_FLOOR": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.floor
+ ),
+ "STABLEHLO_GATHER": self._convert_stablehlo_gather,
+ "STABLEHLO_IOTA": self._convert_stablehlo_iota,
+ "STABLEHLO_LOG": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.log
+ ),
+ "STABLEHLO_LOGISTIC": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.sigmoid
+ ),
+ "STABLEHLO_MAXIMUM": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.maximum
+ ),
+ "STABLEHLO_MINIMUM": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.minimum
+ ),
+ "STABLEHLO_MULTIPLY": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.multiply
+ ),
+ "STABLEHLO_NEGATE": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.negative
+ ),
+ "STABLEHLO_OR": self._convert_stablehlo_or,
+ "STABLEHLO_PAD": self._convert_stablehlo_pad,
+ "STABLEHLO_POWER": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.power
+ ),
+ "STABLEHLO_RSQRT": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.rsqrt
+ ),
+ "STABLEHLO_SELECT": functools.partial(
+ self._convert_stablehlo_ternary, relax_op=_op.where
+ ),
+ "STABLEHLO_SHIFT_LEFT": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.left_shift
+ ),
+ "STABLEHLO_SUBTRACT": functools.partial(
+ self._convert_stablehlo_binary, relax_op=_op.subtract
+ ),
+ "STABLEHLO_TANH": functools.partial(
+ self._convert_stablehlo_unary, relax_op=_op.tanh
+ ),
"SQUEEZE": self.convert_squeeze,
"STRIDED_SLICE": self.convert_strided_slice,
"SUB": functools.partial(self._convert_elemwise, relax_op=_op.subtract),
@@ -1323,6 +1388,434 @@ class OperatorConverter:
out = self.quantize(out, output_tensor)
return out
+ def _convert_stablehlo_unary(self, op, relax_op):
+ """Convert a unary StableHLO TFLite builtin operator.
+
+ StableHLO builtins do not have TFLite fused activation attributes. Keep
+ this path independent from the regular TFLite elemwise/QNN helpers so
+ StableHLO semantics are mapped directly to Relax operators.
+ """
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+
+ assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1"
+
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ return relax_op(in_expr)
+
+ def _convert_stablehlo_binary(self, op, relax_op):
+ """Convert a binary StableHLO TFLite builtin operator.
+
+ StableHLO builtins do not have TFLite fused activation attributes. Keep
+ this path independent from the regular TFLite elemwise/QNN helpers so
+ StableHLO semantics are mapped directly to Relax operators.
+ """
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+
+ assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1"
+
+ lhs_expr = self.get_tensor_expr(input_tensors[0])
+ rhs_expr = self.get_tensor_expr(input_tensors[1])
+ return relax_op(lhs_expr, rhs_expr)
+
+ def _convert_stablehlo_and(self, op):
+ """Convert StableHLO AND for bool and integer tensors."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+
+ assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1"
+
+ lhs = self.get_tensor_expr(input_tensors[0])
+ rhs = self.get_tensor_expr(input_tensors[1])
+ dtype = lhs.struct_info.dtype
+ if dtype == "bool":
+ op_fn = _op.logical_and
+ elif dtype.startswith(("int", "uint")):
+ op_fn = _op.bitwise_and
+ else:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_AND with dtype {dtype} is not supported"
+ )
+ return self.bb.normalize(op_fn(lhs, rhs))
+
+ def _convert_stablehlo_or(self, op):
+ """Convert StableHLO OR for bool and integer tensors."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+
+ assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1"
+
+ lhs = self.get_tensor_expr(input_tensors[0])
+ rhs = self.get_tensor_expr(input_tensors[1])
+ dtype = lhs.struct_info.dtype
+ if dtype == "bool":
+ op_fn = _op.logical_or
+ elif dtype.startswith(("int", "uint")):
+ op_fn = _op.bitwise_or
+ else:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_OR with dtype {dtype} is not supported"
+ )
+ return self.bb.normalize(op_fn(lhs, rhs))
+
+ def _convert_stablehlo_ternary(self, op, relax_op):
+ """Convert a ternary StableHLO TFLite builtin operator.
+
+ StableHLO builtins do not have TFLite fused activation attributes. Keep
+ this path independent from the regular TFLite elemwise/QNN helpers so
+ StableHLO semantics are mapped directly to Relax operators.
+ """
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 3, "input tensors length should be 3"
+
+ assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1"
+
+ arg0 = self.get_tensor_expr(input_tensors[0])
+ arg1 = self.get_tensor_expr(input_tensors[1])
+ arg2 = self.get_tensor_expr(input_tensors[2])
+ return relax_op(arg0, arg1, arg2)
+
+ def _get_stablehlo_options(self, op, options_cls):
+ """Parse BuiltinOptions2 for a StableHLO TFLite builtin operator.
+
+ Returns an initialized options object of the given class.
+ """
+ from tflite.BuiltinOptions2 import BuiltinOptions2
+
+ op_options = op.BuiltinOptions2()
+ # Look up the expected BuiltinOptions2 enum value by matching the class
+ # name to an enum member (e.g. StablehloConcatenateOptions → 1).
+ options_type = getattr(BuiltinOptions2, options_cls.__name__, None)
+ if options_type is not None:
+ assert op.BuiltinOptions2Type() == options_type, (
+ f"Unexpected BuiltinOptions2 type: expected "
+ f"{options_cls.__name__}, got {op.BuiltinOptions2Type()}"
+ )
+ result = options_cls()
+ result.Init(op_options.Bytes, op_options.Pos)
+ return result
+
+ def _convert_stablehlo_convert(self, op):
+ """Convert STABLEHLO_CONVERT to Relax (astype).
+
+ Reads the output tensor dtype from the TFLite schema and applies
+ relax.op.astype. This path is intentionally separate from the
+ generic _convert_stablehlo_unary helper because the output dtype
+ is operator-level metadata, not a Relax op parameter.
+ """
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ output_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type())
+ return self.bb.normalize(relax.op.astype(in_expr, output_dtype))
+
+ def _convert_stablehlo_clamp(self, op):
+ """Convert STABLEHLO_CLAMP to Relax.
+
+ StableHLO clamp(min, operand, max) → R.minimum(R.maximum(operand, min), max).
+ """
+ # NOTE: R.clip is not used here because it only accepts scalar PrimValue
+ # min/max, not tensor inputs.
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 3, "input tensors length should be 3"
+
+ assert len(self.get_output_tensors(op)) == 1
+
+ min_expr = self.get_tensor_expr(input_tensors[0])
+ operand_expr = self.get_tensor_expr(input_tensors[1])
+ max_expr = self.get_tensor_expr(input_tensors[2])
+
+ clamped = self.bb.normalize(relax.op.maximum(operand_expr, min_expr))
+ return self.bb.normalize(relax.op.minimum(clamped, max_expr))
+
+ def _convert_stablehlo_concatenate(self, op):
+ """Convert STABLEHLO_CONCATENATE to Relax."""
+ from tflite.StablehloConcatenateOptions import StablehloConcatenateOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) >= 1, "input tensors length should be >= 1"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloConcatenateOptions)
+ dim = opts.Dimension()
+
+ in_exprs = [self.get_tensor_expr(t) for t in input_tensors]
+ return self.bb.normalize(relax.op.concat(in_exprs, axis=dim))
+
+ def _convert_stablehlo_broadcast_in_dim(self, op):
+ """Convert STABLEHLO_BROADCAST_IN_DIM to Relax."""
+ from tflite.StablehloBroadcastInDimOptions import StablehloBroadcastInDimOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloBroadcastInDimOptions)
+ broadcast_dims = [int(d) for d in opts.BroadcastDimensionsAsNumpy()]
+
+ in_expr = self.get_tensor_expr(input_tensors[0])
+ input_shape = [int(d) for d in self.get_tensor_shape(input_tensors[0])]
+ output_shape = [int(d) for d in self.get_tensor_shape(output_tensors[0])]
+
+ # Map input dims to output dims via broadcast_dims, filling
+ # unmapped positions with 1 so broadcast_to covers them.
+ intermediate_shape = [1] * len(output_shape)
+ for i, d in enumerate(broadcast_dims):
+ intermediate_shape[d] = input_shape[i]
+
+ reshaped = self.bb.normalize(relax.op.reshape(in_expr, intermediate_shape))
+ return self.bb.normalize(relax.op.broadcast_to(reshaped, output_shape))
+
+ def _convert_stablehlo_iota(self, op):
+ """Convert STABLEHLO_IOTA to Relax (arange + broadcast)."""
+ from tflite.StablehloIotaOptions import StablehloIotaOptions
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloIotaOptions)
+ iota_dim = opts.IotaDimension()
+
+ output_tensor = output_tensors[0]
+ output_shape = [int(d) for d in self.get_tensor_shape(output_tensor)]
+ output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+
+ # arange along the iota dimension
+ size = output_shape[iota_dim]
+ arange_1d = self.bb.normalize(relax.op.arange(0, size, 1, output_dtype))
+
+ # reshape to [1, ..., size, ..., 1]
+ broadcast_shape = [1] * len(output_shape)
+ broadcast_shape[iota_dim] = size
+ arange_reshaped = self.bb.normalize(relax.op.reshape(arange_1d, broadcast_shape))
+
+ # broadcast to full output shape
+ return self.bb.normalize(relax.op.broadcast_to(arange_reshaped, output_shape))
+
+ def _convert_stablehlo_compare(self, op):
+ """Convert STABLEHLO_COMPARE to Relax binary comparison ops."""
+ from tflite.StablehloCompareOptions import StablehloCompareOptions
+ from tflite.StablehloComparisonDirection import StablehloComparisonDirection
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2
+ assert len(self.get_output_tensors(op)) == 1
+
+ from tflite.StablehloComparisonType import StablehloComparisonType
+
+ opts = self._get_stablehlo_options(op, StablehloCompareOptions)
+ direction = opts.ComparisonDirection()
+ compare_type = opts.CompareType()
+
+ # TOTALORDER compare is not expressible via Relax comparison ops.
+ if compare_type == StablehloComparisonType.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_COMPARE with TOTALORDER comparison type is not supported"
+ )
+
+ _DIR = StablehloComparisonDirection
+ direction_map = {
+ _DIR.STABLEHLO_COMPARISON_DIRECTION_EQ: relax.op.equal,
+ _DIR.STABLEHLO_COMPARISON_DIRECTION_NE: relax.op.not_equal,
+ _DIR.STABLEHLO_COMPARISON_DIRECTION_GE: relax.op.greater_equal,
+ _DIR.STABLEHLO_COMPARISON_DIRECTION_GT: relax.op.greater,
+ _DIR.STABLEHLO_COMPARISON_DIRECTION_LE: relax.op.less_equal,
+ _DIR.STABLEHLO_COMPARISON_DIRECTION_LT: relax.op.less,
+ }
+ relax_fn = direction_map.get(direction)
+ if relax_fn is None:
+ raise tvm.error.OpNotImplemented(
+ f"Unsupported StableHLO comparison direction: {direction}"
+ )
+
+ lhs = self.get_tensor_expr(input_tensors[0])
+ rhs = self.get_tensor_expr(input_tensors[1])
+ return self.bb.normalize(relax_fn(lhs, rhs))
+
+ def _convert_stablehlo_pad(self, op):
+ """Convert STABLEHLO_PAD to Relax (nn.pad).
+
+ Maps edge padding to R.nn.pad with constant mode. Interior padding
+ (dilation) is not supported in the first version.
+ """
+ from tflite.StablehloPadOptions import StablehloPadOptions
+
+ input_tensors = self.get_input_tensors(op)
+ # operand + padding_value
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloPadOptions)
+ edge_low = [int(d) for d in opts.EdgePaddingLowAsNumpy()]
+ edge_high = [int(d) for d in opts.EdgePaddingHighAsNumpy()]
+ interior = [int(d) for d in opts.InteriorPaddingAsNumpy()]
+
+ if any(d != 0 for d in interior):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_PAD with interior (dilation) padding is not supported"
+ )
+ if any(d < 0 for d in edge_low) or any(d < 0 for d in edge_high):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_PAD with negative edge padding (crop) is not supported"
+ )
+
+ operand = self.get_tensor_expr(input_tensors[0])
+
+ # R.nn.pad only supports a static Python float pad_value.
+ pad_value_tensor = input_tensors[1]
+ if not self.has_expr(pad_value_tensor.tensor_idx):
+ pad_val = float(self.get_tensor_value(pad_value_tensor))
+ else:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_PAD with dynamic padding value is not supported"
+ )
+
+ # R.nn.pad with flat pad_width: [lo0, hi0, lo1, hi1, ...]
+ pad_width = []
+ for lo, hi in zip(edge_low, edge_high):
+ pad_width.extend([lo, hi])
+
+ return self.bb.normalize(
+ relax.op.nn.pad(operand, pad_width=pad_width, pad_value=pad_val)
+ )
+
+ def _convert_stablehlo_dynamic_slice(self, op):
+ """Convert STABLEHLO_DYNAMIC_SLICE to Relax (dynamic_strided_slice).
+
+ Start indices are assumed to be constant (non-dynamic) values stored
+ in the flatbuffer. Truly dynamic (runtime) start indices require
+ Relax arithmetic to compute begin/end from scalar inputs and are not
+ yet supported.
+ """
+ from tflite.StablehloDynamicSliceOptions import StablehloDynamicSliceOptions
+
+ input_tensors = self.get_input_tensors(op)
+ # operand + N start-index scalars
+ assert len(input_tensors) >= 2
+ ndim = len(input_tensors) - 1
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloDynamicSliceOptions)
+ slice_sizes = [int(d) for d in opts.SliceSizesAsNumpy()]
+ assert len(slice_sizes) == ndim
+
+ operand = self.get_tensor_expr(input_tensors[0])
+
+ # Build constant 1D tensors for begin, end, strides
+ # (assumes start values are constant in the flatbuffer)
+ # TODO: support dynamic start indices via Relax arithmetic
+ if any(self.has_expr(t.tensor_idx) for t in input_tensors[1:]):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DYNAMIC_SLICE with dynamic start indices is not supported"
+ )
+ start_vals = [int(self.get_tensor_value(t)) for t in input_tensors[1:]]
+ operand_shape = [int(d) for d in self.get_tensor_shape(input_tensors[0])]
+ for start, size, dim in zip(start_vals, slice_sizes, operand_shape):
+ if start < 0 or start + size > dim:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_DYNAMIC_SLICE with out-of-bounds start indices is not supported"
+ )
+ end_vals = [s + sz for s, sz in zip(start_vals, slice_sizes)]
+ stride_vals = [1] * ndim
+
+ def _const_1d(values, dtype="int64"):
+ arr = np.array(values, dtype=dtype)
+ return self.bb.normalize(relax.const(arr, dtype=dtype))
+
+ begin = _const_1d(start_vals)
+ end = _const_1d(end_vals)
+ strides = _const_1d(stride_vals)
+
+ return self.bb.normalize(
+ relax.op.dynamic_strided_slice(operand, begin, end, strides)
+ )
+
+
+ def _convert_stablehlo_gather(self, op):
+ """Convert STABLEHLO_GATHER to Relax (take-equivalent subset only).
+
+ Only handles gather patterns equivalent to R.take along a single axis.
+ Multi-dimensional gathers, index_vector_dim != rank(indices)-1, and
+ non-trivial slice_sizes raise OpNotImplemented.
+ """
+ from tflite.StablehloGatherOptions import StablehloGatherOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloGatherOptions)
+ offset_dims = [int(d) for d in opts.OffsetDimsAsNumpy()]
+ collapsed_slice_dims = [int(d) for d in opts.CollapsedSliceDimsAsNumpy()]
+ start_index_map = [int(d) for d in opts.StartIndexMapAsNumpy()]
+ slice_sizes = [int(d) for d in opts.SliceSizesAsNumpy()]
+ index_vector_dim = int(opts.IndexVectorDim())
+
+ data_tensor, indices_tensor = input_tensors
+ data_shape = [int(d) for d in self.get_tensor_shape(data_tensor)]
+ indices_shape = [int(d) for d in self.get_tensor_shape(indices_tensor)]
+ output_shape = [int(d) for d in self.get_tensor_shape(output_tensors[0])]
+
+ if len(start_index_map) != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER only supports one start_index_map entry"
+ )
+ axis = start_index_map[0]
+ if axis < 0 or axis >= len(data_shape):
+ raise tvm.error.OpNotImplemented(f"Unsupported STABLEHLO_GATHER axis: {axis}")
+ if collapsed_slice_dims != [axis]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER only supports collapsed_slice_dims matching the gather axis"
+ )
+ if len(slice_sizes) != len(data_shape):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER slice_sizes must match operand rank"
+ )
+ for i, (size, dim) in enumerate(zip(slice_sizes, data_shape)):
+ expected = 1 if i == axis else dim
+ if size != expected:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER only supports take-equivalent slice_sizes"
+ )
+ if index_vector_dim != len(indices_shape) - 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER only supports trailing index_vector_dim"
+ )
+ if not indices_shape or indices_shape[index_vector_dim] != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER only supports index vector size 1"
+ )
+
+ indices_batch_shape = indices_shape[:index_vector_dim]
+ expected_offset_dims = list(range(axis)) + list(
+ range(axis + len(indices_batch_shape), len(data_shape) + len(indices_batch_shape) - 1)
+ )
+ if offset_dims != expected_offset_dims:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER offset_dims do not match Relax take output layout"
+ )
+
+ expected_output_shape = (
+ data_shape[:axis] + indices_batch_shape + data_shape[axis + 1 :]
+ )
+ if output_shape != expected_output_shape:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_GATHER output shape does not match Relax take semantics"
+ )
+
+ data = self.get_tensor_expr(data_tensor)
+ indices = self.get_tensor_expr(indices_tensor)
+ indices = self.bb.normalize(relax.op.reshape(indices, indices_batch_shape))
+ return self.bb.normalize(relax.op.take(data, indices, axis=axis, mode="fast"))
+
+
def convert_elu(self, op):
"""Convert TFLite ELU"""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 9b9029b5a5..fc509a4d0f 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -3666,6 +3666,17 @@ _tfl_add_options = _get_tflite_schema_module("AddOptions")
_tfl_buffer = _get_tflite_schema_module("Buffer")
_tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions")
_tfl_dilate_options = _get_tflite_schema_module("DilateOptions")
+
+# ── StableHLO BuiltinOptions2 schema modules ────────────────────────────
+_tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions")
+_tfl_stablehlo_bcast_opts = _get_tflite_schema_module("StablehloBroadcastInDimOptions")
+_tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions")
+_tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions")
+_tfl_stablehlo_comp_dir = _get_tflite_schema_module("StablehloComparisonDirection")
+_tfl_stablehlo_comp_type = _get_tflite_schema_module("StablehloComparisonType")
+_tfl_stablehlo_pad_opts = _get_tflite_schema_module("StablehloPadOptions")
+_tfl_stablehlo_dyn_slice_opts = _get_tflite_schema_module("StablehloDynamicSliceOptions")
+_tfl_stablehlo_gather_opts = _get_tflite_schema_module("StablehloGatherOptions")
_tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata")
_tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions")
_tfl_int32_vector = _get_tflite_schema_module("Int32Vector")
@@ -3838,6 +3849,1128 @@ def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers):
return bytes(builder.Output())
+def _load_model_from_buffer(model_bytes):
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(model_bytes, 0)
+ mod = from_tflite(tflite_model)
+ mod["main"] = mod["main"].without_attr("params")
+ return mod
+
+
+def _get_stablehlo_builtin_operator(builtin_name):
+ if not hasattr(_tfl_builtin_operator, builtin_name):
+ pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")
+ return getattr(_tfl_builtin_operator, builtin_name)
+
+
+def _build_stablehlo_model(*, builtin_name, input_count):
+ """Build a minimal TFLite model containing one StableHLO builtin operator."""
+ builder = flatbuffers.Builder(1024)
+ shape = [2, 2]
+ output_tensor_idx = input_count
+ builtin_op = _get_stablehlo_builtin_operator(builtin_name)
+
+ tensors = [_build_tensor(builder, buffer_idx, shape) for buffer_idx in range(input_count + 1)]
+ stablehlo_op = _build_operator(
+ builder,
+ 0,
+ list(range(input_count)),
+ [output_tensor_idx],
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[stablehlo_op],
+ inputs=list(range(input_count)),
+ outputs=[output_tensor_idx],
+ )
+ operator_codes = [_build_operator_code(builder, builtin_op)]
+ buffers = [_build_buffer(builder) for _ in range(input_count + 1)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers
+ )
+
+
+def _build_stablehlo_typed_binary_model(*, builtin_name, tensor_type):
+ """Build a minimal TFLite StableHLO binary model with the requested tensor type."""
+ builder = flatbuffers.Builder(1024)
+ shape = [2, 2]
+ output_tensor_idx = 2
+ builtin_op = _get_stablehlo_builtin_operator(builtin_name)
+
+ tensors = [
+ _build_tensor(builder, buffer_idx, shape, tensor_type=tensor_type)
+ for buffer_idx in range(3)
+ ]
+ stablehlo_op = _build_operator(builder, 0, [0, 1], [output_tensor_idx])
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[stablehlo_op],
+ inputs=[0, 1],
+ outputs=[output_tensor_idx],
+ )
+ operator_codes = [_build_operator_code(builder, builtin_op)]
+ buffers = [_build_buffer(builder) for _ in range(3)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers
+ )
+
+
+@pytest.mark.parametrize(
+ "builtin_name, relax_op",
+ [
+ ("STABLEHLO_ABS", R.abs),
+ ("STABLEHLO_COSINE", R.cos),
+ ("STABLEHLO_EXPONENTIAL", R.exp),
+ ("STABLEHLO_FLOOR", R.floor),
+ ("STABLEHLO_LOG", R.log),
+ ("STABLEHLO_LOGISTIC", R.sigmoid),
+ ("STABLEHLO_NEGATE", R.negative),
+ ("STABLEHLO_RSQRT", R.rsqrt),
+ ("STABLEHLO_TANH", R.tanh),
+ ],
+)
+def test_stablehlo_unary(builtin_name, relax_op):
+ """TFLite StableHLO unary elementwise operators."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_model(builtin_name=builtin_name, input_count=1)
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="float32") = relax_op(x)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+@pytest.mark.parametrize(
+ "builtin_name, relax_op",
+ [
+ ("STABLEHLO_ADD", R.add),
+ ("STABLEHLO_DIVIDE", R.divide),
+ ("STABLEHLO_MAXIMUM", R.maximum),
+ ("STABLEHLO_MINIMUM", R.minimum),
+ ("STABLEHLO_MULTIPLY", R.multiply),
+ ("STABLEHLO_POWER", R.power),
+ ("STABLEHLO_SUBTRACT", R.subtract),
+ ],
+)
+def test_stablehlo_binary(builtin_name, relax_op):
+ """TFLite StableHLO binary elementwise operators."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_model(builtin_name=builtin_name, input_count=2)
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 2), dtype="float32"),
+ y: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+@pytest.mark.parametrize(
+ "builtin_name, relax_op, dtype, tensor_type",
+ [
+ ("STABLEHLO_AND", R.logical_and, "bool", _tfl_tensor_type.BOOL),
+ ("STABLEHLO_OR", R.logical_or, "bool", _tfl_tensor_type.BOOL),
+ ("STABLEHLO_AND", R.bitwise_and, "int32", _tfl_tensor_type.INT32),
+ ("STABLEHLO_OR", R.bitwise_or, "int32", _tfl_tensor_type.INT32),
+ ("STABLEHLO_SHIFT_LEFT", R.left_shift, "int32", _tfl_tensor_type.INT32),
+ ],
+)
+def test_stablehlo_typed_binary(builtin_name, relax_op, dtype, tensor_type):
+ """TFLite StableHLO binary elementwise operators with non-float dtype requirements."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_typed_binary_model(
+ builtin_name=builtin_name, tensor_type=tensor_type
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 2), dtype=dtype),
+ y: R.Tensor((2, 2), dtype=dtype),
+ ) -> R.Tensor((2, 2), dtype=dtype):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype=dtype) = relax_op(x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+@pytest.mark.parametrize(
+ "builtin_name, relax_op",
+ [
+ ("STABLEHLO_SELECT", R.where),
+ ],
+)
+def test_stablehlo_ternary(builtin_name, relax_op):
+ """TFLite StableHLO ternary elementwise operators."""
+ builder = flatbuffers.Builder(1024)
+ shape = [2, 2]
+ builtin_op = _get_stablehlo_builtin_operator(builtin_name)
+
+ # First input (condition) must be bool for R.where
+ tensor_0 = _build_tensor(builder, 0, shape, tensor_type=_tfl_tensor_type.BOOL)
+ tensor_1 = _build_tensor(builder, 1, shape)
+ tensor_2 = _build_tensor(builder, 2, shape)
+ tensor_out = _build_tensor(builder, 3, shape)
+ tensors = [tensor_0, tensor_1, tensor_2, tensor_out]
+
+ stablehlo_op = _build_operator(
+ builder,
+ 0,
+ [0, 1, 2],
+ [3],
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[stablehlo_op],
+ inputs=[0, 1, 2],
+ outputs=[3],
+ )
+ operator_codes = [_build_operator_code(builder, builtin_op)]
+ buffers = [_build_buffer(builder) for _ in range(4)]
+
+ mod = _load_model_from_buffer(
+ _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ c: R.Tensor((2, 2), dtype="bool"),
+ x: R.Tensor((2, 2), dtype="float32"),
+ y: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 3})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="float32") = relax_op(c, x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def _build_stablehlo_convert_model():
+ """STABLEHLO_CONVERT: float32 input -> int32 output."""
+ builder = flatbuffers.Builder(1024)
+ shape = [2, 2]
+
+ t_in = _build_tensor(builder, 0, shape, tensor_type=_tfl_tensor_type.FLOAT32)
+ t_out = _build_tensor(builder, 1, shape, tensor_type=_tfl_tensor_type.INT32)
+ tensors = [t_in, t_out]
+
+ op_code = _build_operator_code(
+ builder, _get_stablehlo_builtin_operator("STABLEHLO_CONVERT")
+ )
+ op = _build_operator(builder, 0, [0], [1])
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[op],
+ inputs=[0],
+ outputs=[1],
+ )
+ buffers = [_build_buffer(builder) for _ in range(2)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_convert():
+ """TFLite StableHLO CONVERT (astype float32 -> int32)."""
+ mod = _load_model_from_buffer(_build_stablehlo_convert_model())
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="int32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="int32") = R.astype(x, dtype="int32")
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_clamp():
+ """TFLite StableHLO CLAMP (clip with min/operand/max order)."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_model(builtin_name="STABLEHLO_CLAMP", input_count=3)
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ m: R.Tensor((2, 2), dtype="float32"),
+ x: R.Tensor((2, 2), dtype="float32"),
+ M: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="float32"):
+ R.func_attr({"num_input": 3})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="float32") = R.minimum(R.maximum(x, m), M)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
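StableHLO `clamp` takes its operands in `(min, operand, max)` order, which is why the expected IR applies `R.maximum` against the first input and `R.minimum` against the third. A NumPy sketch of that composition (the helper name `stablehlo_clamp_ref` is hypothetical and not part of this PR), checked against `np.clip`:

```python
import numpy as np


def stablehlo_clamp_ref(min_val, operand, max_val):
    """NumPy reference for StableHLO clamp(min, operand, max)."""
    # Same composition as the expected IR: raise to the lower bound, then cap at the upper.
    return np.minimum(np.maximum(operand, min_val), max_val)


lo = np.full((2, 2), -1.0, dtype=np.float32)
x = np.array([[-3.0, 0.5], [2.0, 7.0]], dtype=np.float32)
hi = np.full((2, 2), 1.0, dtype=np.float32)
result = stablehlo_clamp_ref(lo, x, hi)
```

Note that the operand is the *middle* input, so reading the TFLite inputs in order as `(x, lo, hi)` would silently produce a different result.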
+
+
+def _build_stablehlo_concat_model(dimension, num_inputs):
+ """STABLEHLO_CONCATENATE with given dimension and number of inputs."""
+ builder = flatbuffers.Builder(1024)
+ shape = [2, 2]
+
+ # Build concat options
+ _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsStart(builder)
+ _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsAddDimension(builder, dimension)
+ concat_opts = _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONCATENATE")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ if dimension == 0:
+ out_shape = [num_inputs * shape[0], shape[1]]
+ else:
+ out_shape = [shape[0], num_inputs * shape[1]]
+ tensors = [
+ _build_tensor(builder, i, shape) for i in range(num_inputs)
+ ] + [_build_tensor(builder, num_inputs, out_shape)]
+
+ op = _build_operator(
+ builder,
+ 0,
+ list(range(num_inputs)),
+ [num_inputs],
+ builtin_options2_type=_tfl_builtin_options2.StablehloConcatenateOptions,
+ builtin_options2=concat_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[op],
+ inputs=list(range(num_inputs)),
+ outputs=[num_inputs],
+ )
+ buffers = [_build_buffer(builder) for _ in range(num_inputs + 1)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+@pytest.mark.parametrize("dimension", [0, 1])
+def test_stablehlo_concatenate(dimension):
+ """TFLite StableHLO CONCATENATE with 2 inputs along given axis."""
+ num_inputs = 2
+ mod = _load_model_from_buffer(
+ _build_stablehlo_concat_model(dimension=dimension, num_inputs=num_inputs)
+ )
+
+ out_dim = (4, 2) if dimension == 0 else (2, 4)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 2), dtype="float32"),
+ y: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor(out_dim, dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor(out_dim, dtype="float32") = R.concat((x, y), axis=dimension)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def _build_stablehlo_broadcast_in_dim_model(input_shape, broadcast_dims, output_shape):
+ """STABLEHLO_BROADCAST_IN_DIM with given broadcast dimensions."""
+ builder = flatbuffers.Builder(1024)
+
+ # Build broadcast dimensions vector
+ _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsStartBroadcastDimensionsVector(
+ builder, len(broadcast_dims)
+ )
+ for d in reversed(broadcast_dims):
+ builder.PrependInt64(d)
+ dims_vec = builder.EndVector()
+
+ _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsStart(builder)
+ _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsAddBroadcastDimensions(
+ builder, dims_vec
+ )
+ bcast_opts = _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_BROADCAST_IN_DIM")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_in = _build_tensor(builder, 0, input_shape)
+ t_out = _build_tensor(builder, 1, output_shape)
+ tensors = [t_in, t_out]
+
+ op = _build_operator(
+ builder,
+ 0,
+ [0],
+ [1],
+ builtin_options2_type=_tfl_builtin_options2.StablehloBroadcastInDimOptions,
+ builtin_options2=bcast_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[op],
+ inputs=[0],
+ outputs=[1],
+ )
+ buffers = [_build_buffer(builder) for _ in range(2)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_broadcast_in_dim():
+ """TFLite StableHLO BROADCAST_IN_DIM: (3,) -> (2, 3) with dims=[1]."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_broadcast_in_dim_model(
+ input_shape=[3], broadcast_dims=[1], output_shape=[2, 3]
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(
+ R.reshape(x, (1, 3)), (2, 3)
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
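The expected IR lowers `BROADCAST_IN_DIM` to a reshape that inserts size-1 axes, followed by `broadcast_to`. A NumPy sketch of that two-step lowering (the helper `broadcast_in_dim_ref` is hypothetical). It assumes `broadcast_dims` is strictly increasing. An unsorted list would imply a transpose, and the sketch does not attempt that:

```python
import numpy as np


def broadcast_in_dim_ref(x, broadcast_dims, out_shape):
    """Sketch of StableHLO broadcast_in_dim as reshape + broadcast_to.

    Assumes broadcast_dims is strictly increasing (no implied transpose).
    """
    if list(broadcast_dims) != sorted(set(broadcast_dims)):
        raise ValueError("sketch only handles strictly increasing broadcast_dims")
    # Input dim i lands on output dim broadcast_dims[i]; every other output dim gets size 1.
    reshape_to = [1] * len(out_shape)
    for in_axis, out_axis in enumerate(broadcast_dims):
        reshape_to[out_axis] = x.shape[in_axis]
    return np.broadcast_to(x.reshape(reshape_to), out_shape)


x = np.array([10.0, 20.0, 30.0], dtype=np.float32)
y = broadcast_in_dim_ref(x, broadcast_dims=[1], out_shape=(2, 3))
```

For the test's case, `reshape_to` is `[1, 3]`, which matches the `R.reshape(x, (1, 3))` in `Expected`.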
+
+
+def _build_stablehlo_iota_model(iota_dimension, output_shape):
+ """STABLEHLO_IOTA with given iota dimension and output shape."""
+ builder = flatbuffers.Builder(1024)
+
+ _tfl_stablehlo_iota_opts.StablehloIotaOptionsStart(builder)
+ _tfl_stablehlo_iota_opts.StablehloIotaOptionsAddIotaDimension(builder, iota_dimension)
+ iota_opts = _tfl_stablehlo_iota_opts.StablehloIotaOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_IOTA")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_out = _build_tensor(builder, 0, output_shape, tensor_type=_tfl_tensor_type.INT32)
+ tensors = [t_out]
+
+ op = _build_operator(
+ builder,
+ 0,
+ [],
+ [0],
+ builtin_options2_type=_tfl_builtin_options2.StablehloIotaOptions,
+ builtin_options2=iota_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[op],
+ inputs=[],
+ outputs=[0],
+ )
+ buffers = [_build_buffer(builder)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_iota():
+ """TFLite StableHLO IOTA: iota_dim=1, shape=(2, 3), dtype=int32."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_iota_model(iota_dimension=1, output_shape=[2, 3])
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((2, 3), dtype="int32"):
+ R.func_attr({"num_input": 0})
+ with R.dataflow():
+ gv: R.Tensor((2, 3), dtype="int32") = R.broadcast_to(
+ R.reshape(R.arange(0, 3, 1, dtype="int32"), (1, 3)), (2, 3)
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
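`IOTA` fills the output with indices along `iota_dimension` and repeats them along every other axis. The expected IR builds this as `arange`, then a reshape to put the range on the right axis, then `broadcast_to`. A NumPy sketch of the same three steps (`iota_ref` is a hypothetical name):

```python
import numpy as np


def iota_ref(iota_dimension, out_shape, dtype=np.int32):
    """Sketch of StableHLO iota as arange + reshape + broadcast_to."""
    n = out_shape[iota_dimension]
    # Place the 0..n-1 range on iota_dimension and size-1 axes everywhere else.
    reshape_to = [1] * len(out_shape)
    reshape_to[iota_dimension] = n
    return np.broadcast_to(np.arange(0, n, 1, dtype=dtype).reshape(reshape_to), out_shape)


out = iota_ref(1, (2, 3))
```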
+
+
+def _build_stablehlo_compare_model(direction):
+ """STABLEHLO_COMPARE with given comparison direction."""
+ builder = flatbuffers.Builder(1024)
+
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder)
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(builder, direction)
+ cmp_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ shape = [2, 2]
+ t_lhs = _build_tensor(builder, 0, shape)
+ t_rhs = _build_tensor(builder, 1, shape)
+ t_out = _build_tensor(builder, 2, shape, tensor_type=_tfl_tensor_type.BOOL)
+ tensors = [t_lhs, t_rhs, t_out]
+
+ op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions,
+ builtin_options2=cmp_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+ buffers = [_build_buffer(builder) for _ in range(3)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+@pytest.mark.parametrize(
+ "direction_enum, relax_op",
+ [
+ (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_EQ, R.equal),
+ (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_NE, R.not_equal),
+ (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GE, R.greater_equal),
+ (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT, R.greater),
+ (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LE, R.less_equal),
+ (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, R.less),
+ ],
+)
+def test_stablehlo_compare(direction_enum, relax_op):
+ """TFLite StableHLO COMPARE with various comparison directions."""
+ mod = _load_model_from_buffer(_build_stablehlo_compare_model(direction_enum))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 2), dtype="float32"),
+ y: R.Tensor((2, 2), dtype="float32"),
+ ) -> R.Tensor((2, 2), dtype="bool"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((2, 2), dtype="bool") = relax_op(x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_stablehlo_compare_totalorder_unsupported():
+ """STABLEHLO_COMPARE with TOTALORDER type raises OpNotImplemented."""
+ builder = flatbuffers.Builder(1024)
+
+ _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection
+ _TYPE = _tfl_stablehlo_comp_type.StablehloComparisonType
+
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder)
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(
+ builder, _DIR.STABLEHLO_COMPARISON_DIRECTION_EQ
+ )
+ _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddCompareType(
+ builder, _TYPE.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER
+ )
+ cmp_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ shape = [2, 2]
+ t_lhs = _build_tensor(builder, 0, shape)
+ t_rhs = _build_tensor(builder, 1, shape)
+ t_out = _build_tensor(builder, 2, shape, tensor_type=_tfl_tensor_type.BOOL)
+ tensors = [t_lhs, t_rhs, t_out]
+
+ op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions,
+ builtin_options2=cmp_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=tensors,
+ operators=[op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+ buffers = [_build_buffer(builder) for _ in range(3)]
+ buf = _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="TOTALORDER"):
+ from_tflite(tflite_model)
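The test above expects `FLOAT_TOTAL_ORDER` to be rejected instead of being lowered to an ordinary Relax comparison. One likely reason is that IEEE-754 totalOrder differs from the usual IEEE comparison predicates on signed zeros and NaNs. For example, totalOrder places -0.0 strictly below +0.0. The NumPy sketch below shows the signed-zero case using the common sign-magnitude bit trick for float32. The helper names are hypothetical, and the sketch makes no claim about how the converter itself decides:

```python
import numpy as np


def total_order_key(x):
    """Map float32 values to int32 keys whose integer order follows IEEE-754 totalOrder.

    Non-negative floats keep their bit pattern. For floats with the sign bit set,
    the magnitude bits are flipped so larger magnitudes sort lower and -0.0 lands
    just below +0.0.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.int32)
    return np.where(bits < 0, bits ^ np.int32(0x7FFFFFFF), bits)


def less_total_order(lhs, rhs):
    return total_order_key(lhs) < total_order_key(rhs)


lhs = np.array([-0.0, -2.0, 1.0], dtype=np.float32)
rhs = np.array([0.0, -1.0, 2.0], dtype=np.float32)
ieee = np.less(lhs, rhs)              # ordinary IEEE "<", as R.less would compute
total = less_total_order(lhs, rhs)    # totalOrder-based "<"
```

The two results agree except in the first lane. That is enough to show why mapping TOTALORDER onto `R.less` and its siblings would change results.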
+
+
+def _stablehlo_gather_i64_vector(builder, start_vector_fn, values):
+ start_vector_fn(builder, len(values))
+ for value in reversed(values):
+ builder.PrependInt64(value)
+ return builder.EndVector()
+
+
+def _build_stablehlo_gather_model(
+ *,
+ data_shape,
+ indices_shape,
+ output_shape,
+ offset_dims,
+ collapsed_slice_dims,
+ start_index_map,
+ index_vector_dim,
+ slice_sizes,
+):
+ """Build a minimal STABLEHLO_GATHER TFLite model."""
+ builder = flatbuffers.Builder(1024)
+
+ offset_dims_vec = _stablehlo_gather_i64_vector(
+ builder,
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartOffsetDimsVector,
+ offset_dims,
+ )
+ collapsed_slice_dims_vec = _stablehlo_gather_i64_vector(
+ builder,
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartCollapsedSliceDimsVector,
+ collapsed_slice_dims,
+ )
+ start_index_map_vec = _stablehlo_gather_i64_vector(
+ builder,
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartStartIndexMapVector,
+ start_index_map,
+ )
+ slice_sizes_vec = _stablehlo_gather_i64_vector(
+ builder,
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartSliceSizesVector,
+ slice_sizes,
+ )
+
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsStart(builder)
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddOffsetDims(
+ builder, offset_dims_vec
+ )
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddCollapsedSliceDims(
+ builder, collapsed_slice_dims_vec
+ )
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddStartIndexMap(
+ builder, start_index_map_vec
+ )
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddIndexVectorDim(
+ builder, index_vector_dim
+ )
+ _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddSliceSizes(
+ builder, slice_sizes_vec
+ )
+ gather_opts = _tfl_stablehlo_gather_opts.StablehloGatherOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_GATHER")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_data = _build_tensor(builder, 0, data_shape)
+ t_indices = _build_tensor(builder, 1, indices_shape, tensor_type=_tfl_tensor_type.INT32)
+ t_out = _build_tensor(builder, 2, output_shape)
+ op = _build_operator(
+ builder,
+ 0,
+ [0, 1],
+ [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloGatherOptions,
+ builtin_options2=gather_opts,
+ )
+ subgraph = _build_subgraph(
+ builder,
+ tensors=[t_data, t_indices, t_out],
+ operators=[op],
+ inputs=[0, 1],
+ outputs=[2],
+ )
+ buffers = [_build_buffer(builder) for _ in range(3)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+@pytest.mark.parametrize(
+ "axis, offset_dims, slice_sizes, output_shape",
+ [
+ (0, [1], [1, 4], [2, 4]),
+ (1, [0], [3, 1], [3, 2]),
+ ],
+)
+def test_stablehlo_gather_take_equivalent(axis, offset_dims, slice_sizes, output_shape):
+ """TFLite StableHLO GATHER take-equivalent subset."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_gather_model(
+ data_shape=[3, 4],
+ indices_shape=[2, 1],
+ output_shape=output_shape,
+ offset_dims=offset_dims,
+ collapsed_slice_dims=[axis],
+ start_index_map=[axis],
+ index_vector_dim=1,
+ slice_sizes=slice_sizes,
+ )
+ )
+
+ out_shape = tuple(output_shape)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((3, 4), dtype="float32"),
+ indices: R.Tensor((2, 1), dtype="int32"),
+ ) -> R.Tensor(out_shape, dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ reshaped: R.Tensor((2,), dtype="int32") = R.reshape(indices, (2,))
+ gv: R.Tensor(out_shape, dtype="float32") = R.take(
+ data, reshaped, axis=axis, mode="fast"
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
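The take-equivalent subset works because of how the gather is constrained. `start_index_map` has a single entry, that same axis is collapsed, and the slice spans the full extent of the other axis. Under those conditions, each gathered slice is exactly one row or one column. The sketch below is a brute-force NumPy gather for that rank-2 case, checked against `np.take` in the same way the expected IR uses `R.take`. Helper names are hypothetical, and the start-index clamping follows the StableHLO spec's gather semantics. StableHLO clamps out-of-range starts, while `mode="fast"` is, as we understand it, the no-bounds-check mode of `R.take`. The two are therefore only guaranteed to agree for in-range indices, which is what the sketch uses:

```python
import numpy as np


def gather_reference_2d(data, indices, axis):
    """Brute-force gather for a rank-2 operand with start_index_map = collapsed_slice_dims = [axis].

    indices has shape (N, 1) (index_vector_dim = 1), and each slice spans the whole
    other dimension, matching the slice_sizes used in the test above.
    """
    other = 1 - axis
    n, m = indices.shape[0], data.shape[other]
    out = np.empty((n, m) if axis == 0 else (m, n), dtype=data.dtype)
    for b in range(n):
        # StableHLO clamps start indices; with slice size 1 on `axis` that is [0, dim - 1].
        start = min(max(int(indices[b, 0]), 0), data.shape[axis] - 1)
        for k in range(m):
            if axis == 0:
                out[b, k] = data[start, k]
            else:
                out[k, b] = data[k, start]
    return out


def gather_as_take(data, indices, axis):
    # Mirrors the expected IR: drop the size-1 index vector dim, then take along `axis`.
    return np.take(data, indices.reshape(-1), axis=axis)


data = np.arange(12, dtype=np.float32).reshape(3, 4)
indices = np.array([[2], [0]], dtype=np.int32)
```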
+
+
+def test_stablehlo_gather_complex_unsupported():
+ """TFLite StableHLO GATHER with a multi-dimensional start_index_map is unsupported."""
+ buf = _build_stablehlo_gather_model(
+ data_shape=[3, 4],
+ indices_shape=[2, 2],
+ output_shape=[2],
+ offset_dims=[],
+ collapsed_slice_dims=[0, 1],
+ start_index_map=[0, 1],
+ index_vector_dim=1,
+ slice_sizes=[1, 1],
+ )
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="start_index_map"):
+ from_tflite(tflite_model)
+
+
+def _pad_vector(builder, start_vector_fn, values):
+ """Build a FlatBuffers int64 vector for pad options."""
+ start_vector_fn(builder, len(values))
+ for v in reversed(values):
+ builder.PrependInt64(v)
+ return builder.EndVector()
+
+
+def _build_stablehlo_pad_model(edge_low, edge_high, interior):
+ """STABLEHLO_PAD with given padding vectors."""
+ builder = flatbuffers.Builder(1024)
+
+ lo_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector,
+ edge_low,
+ )
+ hi_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector,
+ edge_high,
+ )
+ int_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector,
+ interior,
+ )
+
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec)
+ pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_in = _build_tensor(builder, 0, [3, 3])
+ # pad_value is a scalar tensor
+ t_pad_val = _build_tensor(builder, 1, [])
+ t_out = _build_tensor(builder, 2, [4, 4])
+ tensors = [t_in, t_pad_val, t_out]
+
+ op = _build_operator(
+ builder, 0, [0, 1], [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions,
+ builtin_options2=pad_opts,
+ )
+ subgraph = _build_subgraph(
+ builder, tensors=tensors, operators=[op],
+ inputs=[0], outputs=[2],
+ )
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()),
+ _build_buffer(builder),
+ ]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_pad():
+ """TFLite StableHLO PAD: edge_low=[1,0], edge_high=[0,1], interior=[0,0]."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_pad_model(edge_low=[1, 0], edge_high=[0, 1], interior=[0, 0])
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((3, 3), dtype="float32"),
+ ) -> R.Tensor((4, 4), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((4, 4), dtype="float32") = R.nn.pad(
+ x, pad_width=[1, 0, 0, 1], pad_value=0.0
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
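The expected `pad_width=[1, 0, 0, 1]` is the per-axis `(low, high)` pairs flattened in axis order. The sketch below shows that conversion using the hypothetical helper `stablehlo_pad_to_pad_width`, with `np.pad` as a stand-in for `R.nn.pad`. It rejects the same two cases the following tests expect the converter to reject: interior padding and negative edge padding:

```python
import numpy as np


def stablehlo_pad_to_pad_width(edge_low, edge_high, interior):
    """Flatten StableHLO pad attributes into [lo0, hi0, lo1, hi1, ...]."""
    if any(p != 0 for p in interior):
        raise NotImplementedError("interior padding is not supported")
    if any(p < 0 for p in list(edge_low) + list(edge_high)):
        raise NotImplementedError("negative edge padding is not supported")
    pad_width = []
    for lo, hi in zip(edge_low, edge_high):
        pad_width.extend([lo, hi])
    return pad_width


pad_width = stablehlo_pad_to_pad_width([1, 0], [0, 1], [0, 0])
x = np.arange(9, dtype=np.float32).reshape(3, 3)
# np.pad wants [(before, after), ...] per axis, so regroup the flat list.
padded = np.pad(x, list(zip(pad_width[0::2], pad_width[1::2])), constant_values=0.0)
```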
+
+
+def test_stablehlo_pad_interior_unsupported():
+ """STABLEHLO_PAD with interior padding raises OpNotImplemented."""
+ builder = flatbuffers.Builder(1024)
+
+ lo_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector,
+ [0, 0],
+ )
+ hi_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector,
+ [0, 0],
+ )
+ int_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector,
+ [1, 0],
+ )
+
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec)
+ pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_in = _build_tensor(builder, 0, [3, 3])
+ t_pv = _build_tensor(builder, 1, [])
+ t_out = _build_tensor(builder, 2, [3, 3])
+ tensors = [t_in, t_pv, t_out]
+
+ op = _build_operator(
+ builder, 0, [0, 1], [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions,
+ builtin_options2=pad_opts,
+ )
+ subgraph = _build_subgraph(
+ builder, tensors=tensors, operators=[op],
+ inputs=[0], outputs=[2],
+ )
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()),
+ _build_buffer(builder),
+ ]
+ buf = _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+ with pytest.raises(tvm.error.OpNotImplemented, match="interior"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_pad_negative_unsupported():
+ """STABLEHLO_PAD with negative edge padding raises OpNotImplemented."""
+ builder = flatbuffers.Builder(1024)
+
+ lo_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector,
+ [-1, 0],
+ )
+ hi_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector,
+ [0, 0],
+ )
+ int_vec = _pad_vector(
+ builder,
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector,
+ [0, 0],
+ )
+
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec)
+ _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec)
+ pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_in = _build_tensor(builder, 0, [3, 3])
+ t_pv = _build_tensor(builder, 1, [])
+ t_out = _build_tensor(builder, 2, [2, 3])
+ tensors = [t_in, t_pv, t_out]
+
+ op = _build_operator(
+ builder, 0, [0, 1], [2],
+ builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions,
+ builtin_options2=pad_opts,
+ )
+ subgraph = _build_subgraph(
+ builder, tensors=tensors, operators=[op],
+ inputs=[0], outputs=[2],
+ )
+ buffers = [
+ _build_buffer(builder),
+ _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()),
+ _build_buffer(builder),
+ ]
+ buf = _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+ with pytest.raises(tvm.error.OpNotImplemented, match="negative"):
+ from_tflite(tflite_model)
+
+
+def _build_stablehlo_dynamic_slice_model(slice_sizes, start_vals):
+ """STABLEHLO_DYNAMIC_SLICE with given slice sizes and start indices."""
+ builder = flatbuffers.Builder(1024)
+ ndim = len(slice_sizes)
+
+ # Build SliceSizes vector
+ _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector(
+ builder, ndim
+ )
+ for v in reversed(slice_sizes):
+ builder.PrependInt64(v)
+ sizes_vec = builder.EndVector()
+
+ _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder)
+ _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes(
+ builder, sizes_vec
+ )
+ dyn_opts = _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ # operand + start indices + output
+ t_in = _build_tensor(builder, 0, [3, 3])
+ start_tensors = []
+ start_inputs = []
+ start_buffers = []
+ for i, sv in enumerate(start_vals):
+ bidx = 1 + i
+ start_tensors.append(
+ _build_tensor(builder, bidx, [], tensor_type=_tfl_tensor_type.INT32)
+ )
+ start_inputs.append(bidx)
+ start_buffers.append(
+ _build_buffer(builder, np.array([sv], dtype=np.int32).tobytes())
+ )
+ out_idx = 1 + ndim
+ t_out = _build_tensor(builder, out_idx, slice_sizes)
+ tensors = [t_in, *start_tensors, t_out]
+ op_inputs = [0, *start_inputs]
+
+ op = _build_operator(
+ builder, 0, op_inputs, [out_idx],
+ builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions,
+ builtin_options2=dyn_opts,
+ )
+ subgraph = _build_subgraph(
+ builder, tensors=tensors, operators=[op],
+ inputs=[0], outputs=[out_idx],
+ )
+ buffers = [_build_buffer(builder), *start_buffers, _build_buffer(builder)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes):
+ """STABLEHLO_DYNAMIC_SLICE with runtime start-index inputs."""
+ builder = flatbuffers.Builder(1024)
+ ndim = len(slice_sizes)
+
+ _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector(
+ builder, ndim
+ )
+ for v in reversed(slice_sizes):
+ builder.PrependInt64(v)
+ sizes_vec = builder.EndVector()
+
+ _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder)
+ _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes(
+ builder, sizes_vec
+ )
+ dyn_opts = _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder)
+
+ builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE")
+ op_code = _build_operator_code(builder, builtin_op)
+
+ t_in = _build_tensor(builder, 0, [3, 3])
+ start_tensors = [
+ _build_tensor(builder, 1 + i, [], tensor_type=_tfl_tensor_type.INT32)
+ for i in range(ndim)
+ ]
+ out_idx = 1 + ndim
+ t_out = _build_tensor(builder, out_idx, slice_sizes)
+ start_inputs = list(range(1, 1 + ndim))
+ tensors = [t_in, *start_tensors, t_out]
+ op_inputs = [0, *start_inputs]
+
+ op = _build_operator(
+ builder, 0, op_inputs, [out_idx],
+ builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions,
+ builtin_options2=dyn_opts,
+ )
+ subgraph = _build_subgraph(
+ builder, tensors=tensors, operators=[op],
+ inputs=op_inputs, outputs=[out_idx],
+ )
+ buffers = [_build_buffer(builder) for _ in range(out_idx + 1)]
+ return _finish_tflite_model(
+ builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers
+ )
+
+
+def test_stablehlo_dynamic_slice():
+ """TFLite StableHLO DYNAMIC_SLICE: start=[0,1], sizes=[2,2] from (3,3)."""
+ mod = _load_model_from_buffer(
+ _build_stablehlo_dynamic_slice_model(
+ slice_sizes=[2, 2], start_vals=[0, 1]
+ )
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((3, 3), dtype="float32"),
+ ) -> R.Tensor(dtype="float32", ndim=2):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor(dtype="float32", ndim=2) = R.dynamic_strided_slice(
+ x,
+ R.const([0, 1], dtype="int64"),
+ R.const([2, 3], dtype="int64"),
+ R.const([1, 1], dtype="int64"),
+ )
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
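With constant start indices, `DYNAMIC_SLICE` reduces to a strided slice where `begin = start`, `end = start + size` and all strides are 1. For this test that gives `[0, 1]`, `[2, 3]` and `[1, 1]`, matching the constants in `Expected`. The StableHLO spec clamps out-of-range starts so that the window stays in bounds. The out-of-bounds test that follows instead expects the converter to reject such models, and the sketch does the same. `dynamic_slice_bounds` is a hypothetical helper:

```python
import numpy as np


def dynamic_slice_bounds(operand_shape, start_vals, slice_sizes):
    """Turn constant starts + sizes into (begin, end, strides) for a strided slice."""
    begin, end = [], []
    for dim, start, size in zip(operand_shape, start_vals, slice_sizes):
        # Reject rather than clamp, mirroring the behavior the tests expect.
        if start < 0 or start + size > dim:
            raise NotImplementedError(f"out-of-bounds start {start} for size {size} in dim {dim}")
        begin.append(start)
        end.append(start + size)
    return begin, end, [1] * len(begin)


x = np.arange(9, dtype=np.float32).reshape(3, 3)
begin, end, strides = dynamic_slice_bounds(x.shape, [0, 1], [2, 2])
window = x[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))]
```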
+
+
+def test_stablehlo_dynamic_slice_dynamic_starts_unsupported():
+ """TFLite StableHLO DYNAMIC_SLICE with runtime starts is not supported yet."""
+ buf = _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes=[2, 2])
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"):
+ from_tflite(tflite_model)
+
+
+def test_stablehlo_dynamic_slice_out_of_bounds_unsupported():
+ """TFLite StableHLO DYNAMIC_SLICE with out-of-bounds starts is not supported."""
+ buf = _build_stablehlo_dynamic_slice_model(slice_sizes=[2, 2], start_vals=[0, 2])
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"):
+ from_tflite(tflite_model)
+
+
def _build_csr_sparsity(
builder,
*,