This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b21855758e [Relax] Implement operators to read runtime DLTensor* information (#16563)
b21855758e is described below

commit b21855758e40057e1b4d7f10410ed7bfb36aa808
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Tue Feb 20 14:59:13 2024 -0600

    [Relax] Implement operators to read runtime DLTensor* information (#16563)
    
    Relax is capable of expressing tensors whose element type is unknown.
    However, these must typically be replaced with a known dtype prior to
    compilation, as most operators require known data types prior to
    legalization.  This can be done by using a `relax::MatchCast` node,
    such as accepting a parameter `arg: R.Tensor([16,16])`, then defining
    the dtype using `R.match_cast(arg, R.Tensor([16,16],'float16'))`.
    
    However, using a `R.match_cast` node requires knowing which data type
    should be used in the new `R.Tensor`, and raises an error for an
    incorrect data type.  If an argument may be one of two distinct data
    types, `R.match_cast` cannot be used to check which data type is in
    use.
    
    This commit adds Relax operators to read the runtime values of a
    `DLTensor*` argument.  These can be used to normalize arguments
    prior to a compute step, for example to pre-process a model weight
    that may be provided in either `float16` or `bfloat16` format.
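    
    A minimal usage sketch (mirroring the new tests in
    `tests/python/relax/test_op_unpack.py`; the specific dtypes below
    are illustrative):
    
        import tvm
        from tvm import relax
        from tvm.script import ir as I, relax as R
    
        @I.ir_module
        class mod:
            @R.function
            def main(A: R.Tensor):
                # Reads DLTensor::dtype::bits at runtime.  If A's dtype
                # were known at compile time, FNormalize would instead
                # fold this call into a constant R.PrimValue.
                return A.dtype.bits
    
        vm = relax.VirtualMachine(relax.build(mod), tvm.cpu())
        assert vm["main"](tvm.nd.empty([16], "float16")) == 16
        assert vm["main"](tvm.nd.empty([16], "float32")) == 32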
---
 python/tvm/relax/expr.py             | 186 +++++++++++++++++++
 src/relax/op/tensor/inspect.cc       | 351 +++++++++++++++++++++++++++++++++++
 src/relax/op/tensor/inspect.h        |  92 +++++++++
 src/relax/transform/legalize_ops.cc  | 115 +++++++++---
 tests/python/relax/test_op_unpack.py | 127 +++++++++++++
 5 files changed, 846 insertions(+), 25 deletions(-)

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index c9780bea7e..12f08f4dbf 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -244,6 +244,192 @@ class ExprWithOp(Expr, Scriptable):
                 raise IndexError from err
             raise
 
+    def _check_for_tensor_struct_info(self):
+        """Raise an error if this is something other than a Tensor
+
+        Used for early checks in `expr.dtype` and `expr.shape`
+        accessors.  While invalid usage would cause errors to be
+        raised during shape inference, an earlier check makes it
+        easier to find the invalid usage.
+        """
+        if self.struct_info_ is None:
+            return
+
+        if not isinstance(self.struct_info_, tvm.relax.TensorStructInfo):
+            raise TypeError(
+                f"Runtime unpacking of DLDataType is only implemented for 
tensors, "
+                f"but was applied to object {self} of type {type(self)}."
+            )
+
+    @property
+    def dtype(self) -> "_DLTensorDTypeProxy":
+        """Returns a proxy object for accessing DLTensor::dtype"""
+        self._check_for_tensor_struct_info()
+        return _DLTensorDTypeProxy(self)
+
+    @property
+    def ndim(self) -> "Expr":
+        """Returns the runtime value of DLTensor::ndim"""
+        self._check_for_tensor_struct_info()
+        op = tvm.ir.Op.get("relax.inspect.tensor_ndim")
+        return tvm.relax.Call(op, [self])
+
+    @property
+    def shape(self) -> "_DLTensorShapeProxy":
+        """Returns a proxy object for accessing DLTensor::shape"""
+        self._check_for_tensor_struct_info()
+        return _DLTensorShapeProxy(self)
+
+
+class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
+    """A proxy object for unpacking DLDatatype from DLTensor
+
+    Exposes accessors for `DLDataType` fields `type_code`, `lanes`,
+    and `bits` within a `DLTensor::dtype`.  Accessing these fields
+    will produce `relax.Call` expressions, representing the field's
+    runtime value.  If the datatype of the tensor is known at
+    compile-time, the `relax.Call` will be normalized into a
+    `relax.PrimValue`, with no runtime cost.
+
+    Parameters
+    ----------
+    tensor: relax.Expr
+
+        The relax tensor (or a variable referring to a relax tensor),
+        whose runtime shape is being inspected.
+
+    """
+
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def asobject(self):
+        """Provide expected in error message
+
+        This method is called when `_DLTensorDTypeProxy` is used in a
+        context that requires a `relax.Expr`.  This usage is not
+        supported, and raising an error here can provide suggested
+        fixes that are not present in the default error message from
+        `tvm.runtime.convert_to_object`.
+        """
+
+        fields = [f"{self.tensor}.dtype.{field}" for field in ["type_code", "bits", "lanes"]]
+        raise TypeError(
+            f"{self.tensor}.dtype cannot be converted to a relax expression, "
+            f"and should be used as a proxy object to access "
+            f"fields {fields}"
+        )
+
+    @property
+    def type_code(self) -> Expr:
+        """Accessor for the DLDataType::bits field
+
+        Returns
+        -------
+        type_code: Expr
+
+            The type code of the DLTensor's data type.  See the
+            `DLDataTypeCode` enum in `dlpack.h` for more information.
+        """
+        op = tvm.ir.Op.get("relax.inspect.tensor_dtype_code")
+        return tvm.relax.Call(op, [self.tensor])
+
+    @property
+    def lanes(self) -> Expr:
+        """Accessor for the DLDataType::bits field
+
+        Returns
+        -------
+        lanes: Expr
+
+            The number of lanes in the DLDataType
+        """
+        op = tvm.ir.Op.get("relax.inspect.tensor_dtype_lanes")
+        return tvm.relax.Call(op, [self.tensor])
+
+    @property
+    def bits(self) -> Expr:
+        """Accessor for the DLDataType::bits field
+
+        Returns
+        -------
+        bits: Expr
+
+            The number of bits in the DLDataType
+        """
+        op = tvm.ir.Op.get("relax.inspect.tensor_dtype_bits")
+        return tvm.relax.Call(op, [self.tensor])
+
+
+class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric):
+    """A proxy object for unpacking the shape from DLTensor
+
+    Exposes accessors for the `DLTensor::shape` field.  Accessing
+    these fields will produce `relax.Call` expressions, representing
+    the field's runtime value.  If the shape of the tensor is known
+    at compile-time, the `relax.Call` will be normalized into a
+    `relax.PrimValue`, with no runtime cost.
+
+    Parameters
+    ----------
+    tensor: relax.Expr
+
+        The relax tensor (or a variable referring to a relax tensor),
+        whose runtime shape is being inspected.
+    """
+
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def asobject(self):
+        """Provide expected in error message
+
+        This method is called when `_DLTensorShapeProxy` is used in a
+        context that requires a `relax.Expr`.  This usage is not
+        supported, and raising an error here can provide suggested
+        fixes that are not present in the default error message from
+        `tvm.runtime.convert_to_object`.
+        """
+        raise TypeError(
+            f"{self.tensor}.shape cannot be converted to a relax expression, "
+            f"and should be used as a proxy object to access the runtime shape 
of the DLTensor. "
+            f"The DLTensor::ndim field can be accessed as len({self.tensor}), "
+            f"and the DLTensor::shape array can be accessed as 
{self.tensor}.shape[i]"
+        )
+
+    def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
+        """Returns the extent of a tensor axis
+
+        Parameters
+        ----------
+        axis: Union[int, PrimExpr, Expr]
+
+            The tensor axis whose extent should be returned.  For ease
+            of use, any python integers or TIR expressions are
+            converted to `relax.Expr`.
+
+        Returns
+        -------
+        extent: Expr
+
+            The extent of the tensor's axis.
+        """
+
+        if not isinstance(axis, tvm.relax.Expr):
+            axis = tvm.relax.PrimValue(axis)
+
+        if axis.struct_info_ is not None and not isinstance(
+            axis.struct_info_, tvm.relax.PrimStructInfo
+        ):
+            raise TypeError(
+                f"The index used to access {self.tensor}.shape "
+                f'must have struct info R.Prim("int64"), '
+                f"but index {axis} had struct info {axis.struct_info_}."
+            )
+
+        op = tvm.ir.Op.get("relax.inspect.tensor_shape_i")
+        return tvm.relax.Call(op, [self.tensor, axis])
+
 
 @tvm._ffi.register_object("relax.expr.Call")
 class Call(ExprWithOp):
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
new file mode 100644
index 0000000000..a40b2af5ef
--- /dev/null
+++ b/src/relax/op/tensor/inspect.cc
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inspect.cc
+ * \brief Operators to access runtime DLTensor parameters
+ */
+
+#include "inspect.h"
+
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+namespace inspect {
+
+TensorStructInfo GetTensorArgInfo(const Call& call) {
+  CHECK_EQ(call->args.size(), 1) << "TypeError: "
+                                 << "Operator " << call->op << " expects one 
argument, "
+                                 << "but received " << call->args.size()
+                                 << " arguments: " << call->args;
+
+  const auto& arg = call->args[0];
+  auto sinfo = GetStructInfo(arg);
+
+  auto tensor_sinfo = sinfo.as<TensorStructInfo>();
+  CHECK(tensor_sinfo) << "TypeError: "
+                      << "Operator " << call->op << " expects a tensor 
argument, "
+                      << "but argument " << arg << " has struct info " << 
sinfo;
+
+  return tensor_sinfo.value();
+}
+
+DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; }
+
+tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType field_dtype) {
+  tir::Var dlpack_handle("dlpack_handle", DataType::Handle());
+
+  tir::Var value("value", field_dtype);
+
+  tir::LetStmt body(
+      value,
+      tir::Call(field_dtype, tir::builtin::tvm_struct_get(),
+                {dlpack_handle, IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), field)}),
+      tir::Evaluate(tvm::ret(value)));
+
+  DictAttrs attrs({{"tir.is_scheduled", Bool(true)}, {"tir.is_host", 
Bool(true)}});
+
+  tir::PrimFunc func(Array<tir::Var>{dlpack_handle}, body, PrimType(field_dtype), {}, attrs);
+
+  FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)},
+                       PrimStructInfo(field_dtype));
+  UpdateStructInfo(func, sinfo);
+
+  return func;
+}
+
+Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) {
+  if (auto prim_sinfo = call->struct_info_.as<PrimStructInfoNode>()) {
+    if (prim_sinfo->value.defined()) {
+      return PrimValue(prim_sinfo->value.value());
+    }
+  }
+  return call;
+}
+
+//// relax.inspect.tensor_dtype_code
+
+Expr tensor_dtype_code(Expr expr) {
+  static const Op& op = Op::Get("relax.inspect.tensor_dtype_code");
+  return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorDtypeCode(const Call& call, const BlockBuilder&) {
+  auto dlpack_type = DataType::UInt(8);
+
+  DataType dtype = GetTensorDataType(call);
+  if (dtype.is_void()) {
+    return PrimStructInfo(dlpack_type);
+  } else {
+    return PrimStructInfo(IntImm(dlpack_type, dtype.code()));
+  }
+}
+
+Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) {
+  auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+  Expr arg = call->args[0];
+  tir::PrimFunc getter =
+      GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeCode, field_dtype);
+
+  GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_code");
+  return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_dtype_code")
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoTensorDtypeCode)
+    .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeCode)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.inspect.tensor_dtype_bits
+
+Expr tensor_dtype_bits(Expr expr) {
+  static const Op& op = Op::Get("relax.inspect.tensor_dtype_bits");
+  return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorDtypeBits(const Call& call, const BlockBuilder&) {
+  auto dlpack_type = DataType::UInt(8);
+
+  DataType dtype = GetTensorDataType(call);
+  if (dtype.is_void()) {
+    return PrimStructInfo(dlpack_type);
+  } else {
+    return PrimStructInfo(IntImm(dlpack_type, dtype.bits()));
+  }
+}
+
+Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) {
+  auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+  Expr arg = call->args[0];
+  tir::PrimFunc getter =
+      GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeBits, field_dtype);
+
+  GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_bits");
+  return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits")
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoTensorDtypeBits)
+    .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeBits)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.inspect.tensor_dtype_lanes
+
+Expr tensor_dtype_lanes(Expr expr) {
+  static const Op& op = Op::Get("relax.inspect.tensor_dtype_lanes");
+  return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorDtypeLanes(const Call& call, const BlockBuilder&) {
+  auto dlpack_type = DataType::UInt(16);
+
+  DataType dtype = GetTensorDataType(call);
+  if (dtype.is_void()) {
+    return PrimStructInfo(dlpack_type);
+  } else {
+    return PrimStructInfo(IntImm(dlpack_type, dtype.lanes()));
+  }
+}
+
+Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) {
+  auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+  Expr arg = call->args[0];
+  tir::PrimFunc getter =
+      GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeLanes, field_dtype);
+
+  GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_lanes");
+  return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes")
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoTensorDtypeLanes)
+    .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeLanes)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.inspect.tensor_ndim
+
+Expr tensor_ndim(Expr expr) {
+  static const Op& op = Op::Get("relax.inspect.tensor_ndim");
+  return Call(op, {expr});
+}
+
+StructInfo InferStructInfoTensorNDim(const Call& call, const BlockBuilder&) {
+  auto dlpack_type = DataType::Int(32);
+
+  auto sinfo = GetTensorArgInfo(call);
+  if (sinfo->IsUnknownNdim()) {
+    return PrimStructInfo(dlpack_type);
+  } else {
+    return PrimStructInfo(IntImm(dlpack_type, sinfo->ndim));
+  }
+}
+
+Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) {
+  auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+  Expr arg = call->args[0];
+  tir::PrimFunc getter =
+      GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrNDim, field_dtype);
+
+  GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_ndim");
+  return Call(gvar_getter, {arg});
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_ndim")
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorNDim)
+    .set_attr<FLegalize>("FLegalize", LegalizeTensorNDim)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+//// relax.inspect.tensor_shape_i
+
+Expr tensor_shape_i(Expr expr, Expr axis) {
+  static const Op& op = Op::Get("relax.inspect.tensor_shape_i");
+  return Call(op, {expr, axis});
+}
+
+StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) {
+  auto dlpack_type = DataType::Int(64);
+
+  CHECK_EQ(call->args.size(), 2) << "TypeError: "
+                                 << "Operator " << call->op << " expects two 
arguments, "
+                                 << "but received " << call->args.size()
+                                 << " arguments: " << call->args;
+  const auto& arg = call->args[0];
+  const auto& axis = call->args[1];
+
+  auto tensor_sinfo = arg->struct_info_.as<TensorStructInfoNode>();
+  CHECK(tensor_sinfo) << "TypeError: "
+                      << "Operator " << call->op << " expects arguments 
(tensor, axis), "
+                      << "but the first argument " << arg << " in expression " 
<< call
+                      << " has struct info " << arg->struct_info_;
+
+  auto axis_sinfo = axis->struct_info_.as<PrimStructInfoNode>();
+  CHECK(axis_sinfo) << "TypeError: "
+                    << "Operator " << call->op << " expects arguments (tensor, 
axis), "
+                    << "but the second argument " << arg << " in expression " 
<< call
+                    << " has struct info " << axis->struct_info_;
+
+  auto int_imm_axis = axis_sinfo->value.as<IntImmNode>();
+
+  if (int_imm_axis) {
+    CHECK_GE(int_imm_axis->value, 0);
+  }
+  if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) {
+    CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim)
+        << "ValueError: "
+        << "Expression " << call << " attempts to access " << arg << ".shape["
+        << int_imm_axis->value << "]"
+        << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " 
elements";
+  }
+
+  auto tensor_shape = tensor_sinfo->GetShape();
+  if (int_imm_axis && tensor_shape.defined()) {
+    return PrimStructInfo(tensor_shape.value()[int_imm_axis->value]);
+  } else {
+    return PrimStructInfo(dlpack_type);
+  }
+}
+
+Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) {
+  auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype;
+
+  tir::PrimFunc getter = [&]() -> tir::PrimFunc {
+    tir::Var dlpack_handle("dlpack_handle", DataType::Handle());
+    tir::Var axis("axis", DataType::Int(64));
+
+    tir::Var ndim("ndim", DataType::Int(32));
+
+    tir::Buffer shape_buffer = tir::decl_buffer({ndim}, field_dtype, "shape");
+
+    tir::Var extent("extent", field_dtype);
+
+    tir::Stmt body = tir::Evaluate(tvm::ret(extent));
+
+    body = tir::LetStmt(extent, tir::BufferLoad(shape_buffer, {axis}), body);
+    body = tir::DeclBuffer(shape_buffer, body);
+    body = tir::LetStmt(
+        shape_buffer->data,
+        tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(),
+                  {dlpack_handle, IntImm(DataType::Int(32), 0),
+                   IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kArrShape)}),
+        body);
+
+    body = tir::AssertStmt(
+        axis < tvm::cast(axis->dtype, ndim),
+        tir::StringImm("Specified axis may not be larger than the tensor's 
dimensionality"), body);
+
+    body = tir::LetStmt(
+        ndim,
+        tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(),
+                  {dlpack_handle, IntImm(DataType::Int(32), 0),
+                   IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kArrNDim)}),
+        body);
+
+    body = tir::AssertStmt(0 <= axis, tir::StringImm("Specified axis may not be negative"), body);
+
+    DictAttrs attrs({{"tir.is_scheduled", Bool(true)}, {"tir.is_host", 
Bool(true)}});
+
+    tir::PrimFunc func({dlpack_handle, axis}, body, PrimType(field_dtype), {}, attrs);
+
+    FuncStructInfo sinfo(
+        {TensorStructInfo(DataType::Void(), kUnknownNDim), PrimStructInfo(axis->dtype)},
+        PrimStructInfo(field_dtype));
+    UpdateStructInfo(func, sinfo);
+    return func;
+  }();
+
+  GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_shape_i");
+  return Call(gvar_getter, call->args);
+}
+
+TVM_REGISTER_OP("relax.inspect.tensor_shape_i")
+    .set_num_inputs(2)
+    .add_argument("tensor", "Tensor", "The tensor to be inspected")
+    .add_argument("axis", "Prim(int64)", "The axis whose extent should be 
returned")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorShape)
+    .set_attr<FLegalize>("FLegalize", LegalizeTensorShape)
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+}  // namespace inspect
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h
new file mode 100644
index 0000000000..0225b00fb3
--- /dev/null
+++ b/src/relax/op/tensor/inspect.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inspect.h
+ * \brief Operators to access runtime DLTensor parameters
+ */
+#ifndef TVM_RELAX_OP_TENSOR_INSPECT_H_
+#define TVM_RELAX_OP_TENSOR_INSPECT_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+namespace inspect {
+
+/* \brief Return the DLTensor::dtype::type_code field
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint8_t value of the type_code, with
+ * `PrimStructInfo(DataType::UInt(8))`
+ */
+Expr tensor_dtype_code(Expr expr);
+
+/* \brief Return the DLTensor::dtype::bits field
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint8_t value of the number of bits, with
+ * `PrimStructInfo(DataType::UInt(8))`.  For vectorized types, returns
+ * the bit width of the underlying scalar type (e.g. 32 for
+ * "float32x4", not 128).
+ */
+Expr tensor_dtype_bits(Expr expr);
+
+/* \brief Return the DLTensor::dtype::lanes field
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The uint16_t value of the number of lanes, with
+ * `PrimStructInfo(DataType::UInt(16))`
+ */
+Expr tensor_dtype_lanes(Expr expr);
+
+/* \brief Return the DLTensor::ndim field
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \returns The int32_t value of the dimensionality, with
+ * `PrimStructInfo(DataType::Int(32))`.
+ */
+Expr tensor_ndim(Expr expr);
+
+/* \brief Return the DLTensor::shape[i] field
+ *
+ * \param expr The relax expression to be inspected.  Must have
+ * `TensorStructInfo`.
+ *
+ * \param axis The axis to inspect.  Must be within the range `0 <=
+ *     axis < tensor_ndim(expr)`, or else the results are undefined.
+ *
+ * \returns The int64_t extent of the specified tensor axis, with
+ * `PrimStructInfo(DataType::Int(64))`.
+ */
+Expr tensor_shape_i(Expr expr, Expr axis);
+
+}  // namespace inspect
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_INSPECT_H_
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index c8fba59dab..343c18acd7 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -90,23 +90,32 @@ class LegalizeMutator : public ExprMutator {
   bool WrapPureCondition(const Op& op, const Expr& legalized) {
     static const auto& purity_map = Op::GetAttrMap<Bool>("FPurity");
 
-    // unlikely for this condition not to be met
-    if (const CallNode* call = legalized.as<CallNode>()) {
-      // if the original op is not pure, don't wrap
-      if (!(purity_map.count(op) && purity_map[op]->value)) {
+    const CallNode* call = legalized.as<CallNode>();
+
+    if (!call) {
+      // Unlikely for this condition to be met, but it is possible.
+      // For example, an operation could produce a Tuple output, and
+      // be legalized into separate calls for each item in the Tuple.
+      return false;
+    }
+
+    bool pure_original_op = purity_map.get(op, Bool(false))->value;
+    bool pure_legalized_op = [&]() -> bool {
+      if (auto legalized_op = call->op.as<Op>()) {
+        return purity_map.get(legalized_op.value(), Bool(false))->value;
+      } else if (auto func_sinfo = call->op->struct_info_.as<FuncStructInfoNode>()) {
+        return func_sinfo->purity;
+      } else {
         return false;
       }
-      if (const OpNode* call_op = call->op.as<OpNode>()) {
-        auto res_op = GetRef<Op>(call_op);
-        if (purity_map.count(res_op)) {
-          // if the legalized op is already pure, we *don't* need a wrapper
-          return !purity_map[res_op]->value;
-        }
-      }
-      // simplest case: wrap if the original op was pure and the result is somehow not
-      return true;
-    }
-    return false;
+    }();
+
+    // If the original op was pure, but the legalized op was not,
+    // the legalized op may occur in a context that requires pure
+    // functions, such as a `relax::DataflowBlock`.  In this case,
+    // we should wrap the legalized operation to indicate that it is
+    // still pure.
+    return pure_original_op && !pure_legalized_op;
   }
 
   Call WrapPureCall(const Call& ret) {
@@ -148,6 +157,7 @@ class LegalizeMutator : public ExprMutator {
   Expr VisitExpr_(const CallNode* call) final {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+    static const auto& requires_arg_shapes_map = Op::GetAttrMap<Bool>("RequiresArgumentShapes");
     static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
     static const Op& call_tir_op = Op::Get("relax.call_tir");
     static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
@@ -157,17 +167,72 @@ class LegalizeMutator : public ExprMutator {
     if (op_node == nullptr) {
       return visited_call;
     }
-
     auto op = GetRef<Op>(op_node);
-    std::string op_name(op->name);
-    bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos);
-    // Not all shape values are known
-    // Data-dependent ops are exception since their output shape will be identified at runtime.
-    // Legalizer will insert their shape functions, which are manually registered, and match cast
-    // to define symbolic output shape at compile time.
-    if (!std::all_of(visited_call->args.begin(), visited_call->args.end(),
-                     [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }) ||
-        (!is_data_dependent_op && !KnowAllShapeValues(GetStructInfo(visited_call)))) {
+
+    bool can_legalize = [&]() -> bool {
+      bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value;
+      if (!requires_arg_shapes) {
+        // This operator does not require its arguments to have a
+        // known shape/dtype.  For example, the "relax.tensor_ndim"
+        // operator can output the dimensionality of a tensor at
+        // runtime, and does not require the dimensionality to be
+        // known at compile-time.
+        return true;
+      }
+
+      bool arg_shapes_defined =
+          std::all_of(visited_call->args.begin(), visited_call->args.end(),
+                      [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); });
+      if (!arg_shapes_defined) {
+        // This operator cannot be legalized, because legalization
+        // requires the argument shapes to be known.
+        //
+        // TODO(Lunderberg):
+        //
+        //     Improve this fallback case, as failure to legalize can
+        //     produce unexpected errors during CodeGenVM.  This could
+        //     be done by having `R.Tensor(ndim=2)` be syntactic sugar
+        //     for `R.Tensor(shape=[m, n])`, where `m` and `n` are new
+        //     shape variables.  This would allow legalization into
+        //     dynamic TIR PrimFuncs.
+        //
+        //     This fallback would only be applicable for cases where
+        //     both the dtype and the dimensionality are known.  While
+        //     Relax can express a tensor with unknown dtype and
+        //     dimensionality as `TensorStructInfo(DataType::Void(),
+        //     kUnknownNDim)`, TIR cannot express unknown dtype or
+        //     unknown dimensionality.
+        return false;
+      }
+
+      std::string op_name(op->name);
+      bool is_data_dependent_op = (op_name.find("dynamic") != 
std::string::npos);
+      bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call));
+      if (!is_data_dependent_op && !ret_shape_defined) {
+        // This operator cannot be legalized, because legalization by
+        // default requires the output shape.  The exception is
+        // data-dependent operators (e.g. `R.dynamic_strided_slice`),
+        // where the shape of the output depends on the runtime values
+        // stored in a tensor.
+        //
+        // For data-dependent ops, the output shape will be identified
+        // at runtime.  The Legalizer will insert their shape
+        // functions, which are manually registered for each
+        // data-dependent op, and match cast to define symbolic output
+        // shapes.  These symbolic output shapes, defined at compile
+        // time, can be used by later operations to refer to the
+        // runtime shape.
+        //
+        // TODO(Lunderberg): Make a new operator attribute
+        // `.set_attr<Bool>("DataDependent")`, rather than relying on
+        // the name of the operator.
+        return false;
+      }
+
+      // All checks pass, this operator can be legalized.
+      return true;
+    }();
+
+    if (!can_legalize) {
       return visited_call;
     }
 
diff --git a/tests/python/relax/test_op_unpack.py b/tests/python/relax/test_op_unpack.py
new file mode 100644
index 0000000000..03e4e0fc85
--- /dev/null
+++ b/tests/python/relax/test_op_unpack.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+
+from tvm import relax
+from tvm.ir import Op
+from tvm.script import ir as I, relax as R
+
+# Parameterization for reading the dtype of a DLTensor.  Chosen to
+# cover multiple distinct type codes, lane counts, and bit widths.
+dtype = tvm.testing.parameter(
+    "int32",
+    "int64",
+    "float32",
+    "float32x4",
+    "bfloat",
+    "e4m3_float8",
+)
+shape = tvm.testing.parameter(
+    [],
+    [16],
+    [128, 256],
+    [1] * 64,
+)
+
+
+def test_tensor_dtype_code(dtype):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.dtype.type_code
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty([16], dtype)
+    res = vm["main"](arg)
+
+    expected_type_code = tvm.runtime.DataType(dtype).type_code
+    assert res == expected_type_code
+
+
+def test_tensor_dtype_bits(dtype):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.dtype.bits
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty([16], dtype)
+    res = vm["main"](arg)
+
+    expected_type_bits = tvm.runtime.DataType(dtype).bits
+    assert res == expected_type_bits
+
+
+def test_tensor_dtype_lanes(dtype):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.dtype.lanes
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty([16], dtype)
+    res = vm["main"](arg)
+
+    expected_type_lanes = tvm.runtime.DataType(dtype).lanes
+    assert res == expected_type_lanes
+
+
+def test_tensor_ndim(shape):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor):
+            return A.ndim
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty(shape, "int32")
+    res = vm["main"](arg)
+
+    assert res == len(shape)
+
+
+def test_tensor_shape(shape):
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(A: R.Tensor, axis: R.Prim("int64")):
+            return A.shape[axis]
+
+    built = relax.build(mod)
+    vm = relax.VirtualMachine(built, tvm.cpu())
+
+    arg = tvm.nd.empty(shape, "int32")
+
+    res = [vm["main"](arg, i) for i, _ in enumerate(shape)]
+
+    tvm.ir.assert_structural_equal(res, shape)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

