This is an automated email from the ASF dual-hosted git repository. lunderberg pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new b21855758e [Relax] Implement operators to read runtime DLTensor* information (#16563) b21855758e is described below commit b21855758e40057e1b4d7f10410ed7bfb36aa808 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Tue Feb 20 14:59:13 2024 -0600 [Relax] Implement operators to read runtime DLTensor* information (#16563) Relax is capable of expressing tensors whose element type is unknown. However, these must typically be replaced with a known dtype prior to compilation, as most operators require known data types prior to legalization. This can be done by using a `relax::MatchCast` node, for example by accepting a parameter `arg: R.Tensor([16,16])` and then defining the dtype using `R.match_cast(arg, R.Tensor([16,16],'float16'))`. However, using an `R.match_cast` node requires knowing which data type should be used in the new `R.Tensor`, and raises an error for an incorrect data type. If an argument may be one of two distinct data types, `R.match_cast` cannot be used to check which data type is in use. This commit adds Relax operators to read the runtime values of a `DLTensor*` argument. These can be used to normalize arguments prior to a compute step, for example to pre-process a model weight that may be provided in either `float16` or `bfloat16` format (a brief usage sketch follows the diff below). --- python/tvm/relax/expr.py | 186 +++++++++++++++++++ src/relax/op/tensor/inspect.cc | 351 +++++++++++++++++++++++++++++++++++ src/relax/op/tensor/inspect.h | 92 +++++++++ src/relax/transform/legalize_ops.cc | 115 +++++++++--- tests/python/relax/test_op_unpack.py | 127 +++++++++++++ 5 files changed, 846 insertions(+), 25 deletions(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index c9780bea7e..12f08f4dbf 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -244,6 +244,192 @@ class ExprWithOp(Expr, Scriptable): raise IndexError from err raise + def _check_for_tensor_struct_info(self): + """Raise an error if this is something other than a Tensor + + Used for early checks in `expr.dtype` and `expr.shape` + accessors. While invalid usage would cause errors to be + raised during shape inference, an earlier check makes it + easier to find the invalid usage. + """ + if self.struct_info_ is None: + return + + if not isinstance(self.struct_info_, tvm.relax.TensorStructInfo): + raise TypeError( + f"Runtime unpacking of DLDataType is only implemented for tensors, " + f"but was applied to object {self} of type {type(self)}." + ) + + @property + def dtype(self) -> "_DLTensorDTypeProxy": + """Returns a proxy object for accessing DLTensor::dtype""" + self._check_for_tensor_struct_info() + return _DLTensorDTypeProxy(self) + + @property + def ndim(self) -> "Expr": + """Returns the runtime value of DLTensor::ndim""" + self._check_for_tensor_struct_info() + op = tvm.ir.Op.get("relax.inspect.tensor_ndim") + return tvm.relax.Call(op, [self]) + + @property + def shape(self) -> "_DLTensorShapeProxy": + """Returns a proxy object for accessing DLTensor::shape""" + self._check_for_tensor_struct_info() + return _DLTensorShapeProxy(self) + + +class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric): + """A proxy object for unpacking DLDataType from DLTensor + + Exposes accessors for `DLDataType` fields `type_code`, `lanes`, + and `bits` within a `DLTensor::dtype`. Accessing these fields + will produce `relax.Call` expressions, representing the field's + runtime value.
If the datatype of the tensor is known at + compile-time, the `relax.Call` will be normalized into a + `relax.PrimValue`, with no runtime cost. + + Parameters + ---------- + tensor: relax.Expr + + The relax tensor (or a variable referring to a relax tensor), + whose runtime datatype is being inspected. + + """ + + def __init__(self, tensor): + self.tensor = tensor + + def asobject(self): + """Provide expected usage in error message + + This method is called when `_DLTensorDTypeProxy` is used in a + context that requires a `relax.Expr`. This usage is not + supported, and raising an error here can provide suggested + fixes that are not present in the default error message from + `tvm.runtime.convert_to_object`. + """ + + fields = [f"{self.tensor}.dtype.{field}" for field in ["type_code", "bits", "lanes"]] + raise TypeError( + f"{self.tensor}.dtype cannot be converted to a relax expression, " + f"and should be used as a proxy object to access " + f"fields {fields}" + ) + + @property + def type_code(self) -> Expr: + """Accessor for the DLDataType::code field + + Returns + ------- + type_code: Expr + + The type code of the DLTensor. See the `DLDataTypeCode` + enum in `dlpack.h` for more information. + """ + op = tvm.ir.Op.get("relax.inspect.tensor_dtype_code") + return tvm.relax.Call(op, [self.tensor]) + + @property + def lanes(self) -> Expr: + """Accessor for the DLDataType::lanes field + + Returns + ------- + lanes: Expr + + The number of lanes in the DLDataType + """ + op = tvm.ir.Op.get("relax.inspect.tensor_dtype_lanes") + return tvm.relax.Call(op, [self.tensor]) + + @property + def bits(self) -> Expr: + """Accessor for the DLDataType::bits field + + Returns + ------- + bits: Expr + + The number of bits in the DLDataType + """ + op = tvm.ir.Op.get("relax.inspect.tensor_dtype_bits") + return tvm.relax.Call(op, [self.tensor]) + + +class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric): + """A proxy object for unpacking the shape from DLTensor + + Exposes accessors for the `DLTensor::shape` field. Accessing + these fields will produce `relax.Call` expressions, representing + the field's runtime value. If the shape of the tensor is known + at compile-time, the `relax.Call` will be normalized into a + `relax.PrimValue`, with no runtime cost. + + Parameters + ---------- + tensor: relax.Expr + + The relax tensor (or a variable referring to a relax tensor), + whose runtime shape is being inspected. + """ + + def __init__(self, tensor): + self.tensor = tensor + + def asobject(self): + """Provide expected usage in error message + + This method is called when `_DLTensorShapeProxy` is used in a + context that requires a `relax.Expr`. This usage is not + supported, and raising an error here can provide suggested + fixes that are not present in the default error message from + `tvm.runtime.convert_to_object`. + """ + raise TypeError( + f"{self.tensor}.shape cannot be converted to a relax expression, " + f"and should be used as a proxy object to access the runtime shape of the DLTensor. " + f"The DLTensor::ndim field can be accessed as len({self.tensor}), " + f"and the DLTensor::shape array can be accessed as {self.tensor}.shape[i]" + ) + + def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: + """Returns the extent of a tensor axis + + Parameters + ---------- + axis: Union[int, PrimExpr, Expr] + + The tensor axis whose extent should be returned. For ease + of use, any python integers or TIR expressions are + converted to `relax.Expr`. + + Returns + ------- + extent: Expr + + The extent of the tensor's axis.
+ """ + + if not isinstance(axis, tvm.relax.Expr): + axis = tvm.relax.PrimValue(axis) + + if axis.struct_info_ is not None and not isinstance( + axis.struct_info_, tvm.relax.PrimStructInfo + ): + raise TypeError( + f"The index used to access {self.tensor}.shape " + f'must have struct info R.Prim("int64"), ' + f"but index {axis} had struct info {axis.struct_info_}." + ) + + op = tvm.ir.Op.get("relax.inspect.tensor_shape_i") + return tvm.relax.Call(op, [self.tensor, axis]) + @tvm._ffi.register_object("relax.expr.Call") class Call(ExprWithOp): diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc new file mode 100644 index 0000000000..a40b2af5ef --- /dev/null +++ b/src/relax/op/tensor/inspect.cc @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inspect.cc + * \brief Operators to access runtime DLTensor parameters + */ + +#include "inspect.h" + +#include <tvm/relax/op_attr_types.h> +#include <tvm/tir/builtin.h> +#include <tvm/tir/function.h> +#include <tvm/tir/op.h> + +namespace tvm { +namespace relax { +namespace inspect { + +TensorStructInfo GetTensorArgInfo(const Call& call) { + CHECK_EQ(call->args.size(), 1) << "TypeError: " + << "Operator " << call->op << " expects one argument, " + << "but received " << call->args.size() + << " arguments: " << call->args; + + const auto& arg = call->args[0]; + auto sinfo = GetStructInfo(arg); + + auto tensor_sinfo = sinfo.as<TensorStructInfo>(); + CHECK(tensor_sinfo) << "TypeError: " + << "Operator " << call->op << " expects a tensor argument, " + << "but argument " << arg << " has struct info " << sinfo; + + return tensor_sinfo.value(); +} + +DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } + +tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType field_dtype) { + tir::Var dlpack_handle("dlpack_handle", DataType::Handle()); + + tir::Var value("value", field_dtype); + + tir::LetStmt body( + value, + tir::Call(field_dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), field)}), + tir::Evaluate(tvm::ret(value))); + + DictAttrs attrs({{"tir.is_scheduled", Bool(true)}, {"tir.is_host", Bool(true)}}); + + tir::PrimFunc func(Array<tir::Var>{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); + + FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)}, + PrimStructInfo(field_dtype)); + UpdateStructInfo(func, sinfo); + + return func; +} + +Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) { + if (auto prim_sinfo = call->struct_info_.as<PrimStructInfoNode>()) { + if (prim_sinfo->value.defined()) { + return PrimValue(prim_sinfo->value.value()); + } + } + return call; +} +
+//// relax.tensor_dtype_code + +Expr tensor_dtype_code(Expr expr) { + static const Op& op = Op::Get("relax.inspect.tensor_dtype_code"); + return Call(op, {expr}); +} + +StructInfo InferStructInfoTensorDtypeCode(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::UInt(8); + + DataType dtype = GetTensorDataType(call); + if (dtype.is_void()) { + return PrimStructInfo(dlpack_type); + } else { + return PrimStructInfo(IntImm(dlpack_type, dtype.code())); + } +} + +Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { + auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype; + + Expr arg = call->args[0]; + tir::PrimFunc getter = + GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeCode, field_dtype); + + GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_code"); + return Call(gvar_getter, {arg}); +} + +TVM_REGISTER_OP("relax.inspect.tensor_dtype_code") + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorDtypeCode) + .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeCode) + .set_attr<Bool>("RequiresArgumentShapes", Bool(false)) + .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue) + .set_attr<Bool>("FPurity", Bool(true)); + +//// relax.tensor_dtype_bits + +Expr tensor_dtype_bits(Expr expr) { + static const Op& op = Op::Get("relax.inspect.tensor_dtype_bits"); + return Call(op, {expr}); + } + +StructInfo InferStructInfoTensorDtypeBits(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::UInt(8); + + DataType dtype = GetTensorDataType(call); + if (dtype.is_void()) { + return PrimStructInfo(dlpack_type); + } else { + return PrimStructInfo(IntImm(dlpack_type, dtype.bits())); + } +} + +Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { + auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype; + + Expr arg = call->args[0]; + tir::PrimFunc getter = + GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeBits, field_dtype); + + GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_bits"); + return Call(gvar_getter, {arg}); +} + +TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits") + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorDtypeBits) + .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeBits) + .set_attr<Bool>("RequiresArgumentShapes", Bool(false)) + .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue) + .set_attr<Bool>("FPurity", Bool(true)); + +//// relax.tensor_dtype_lanes + +Expr tensor_dtype_lanes(Expr expr) { + static const Op& op = Op::Get("relax.inspect.tensor_dtype_lanes"); + return Call(op, {expr}); +} + +StructInfo InferStructInfoTensorDtypeLanes(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::UInt(16); + + DataType dtype = GetTensorDataType(call); + if (dtype.is_void()) { + return PrimStructInfo(dlpack_type); + } else { + return PrimStructInfo(IntImm(dlpack_type, dtype.lanes())); + } +} + +Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { + auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype; + + Expr arg = call->args[0]; + tir::PrimFunc getter = + GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrTypeLanes, field_dtype); + + GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_lanes"); + return Call(gvar_getter, {arg}); +} +
+TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes") + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorDtypeLanes) + .set_attr<FLegalize>("FLegalize", LegalizeTensorDtypeLanes) + .set_attr<Bool>("RequiresArgumentShapes", Bool(false)) + .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue) + .set_attr<Bool>("FPurity", Bool(true)); + +//// relax.tensor_ndim + +Expr tensor_ndim(Expr expr) { + static const Op& op = Op::Get("relax.inspect.tensor_ndim"); + return Call(op, {expr}); +} + +StructInfo InferStructInfoTensorNDim(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::Int(32); + + auto sinfo = GetTensorArgInfo(call); + if (sinfo->IsUnknownNdim()) { + return PrimStructInfo(dlpack_type); + } else { + return PrimStructInfo(IntImm(dlpack_type, sinfo->ndim)); + } +} + +Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { + auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype; + + Expr arg = call->args[0]; + tir::PrimFunc getter = GetDLTensorField(tir::builtin::TVMStructFieldKind::kArrNDim, field_dtype); + + GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_ndim"); + return Call(gvar_getter, {arg}); +} + +TVM_REGISTER_OP("relax.inspect.tensor_ndim") + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorNDim) + .set_attr<FLegalize>("FLegalize", LegalizeTensorNDim) + .set_attr<Bool>("RequiresArgumentShapes", Bool(false)) + .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue) + .set_attr<Bool>("FPurity", Bool(true)); + +//// relax.tensor_shape_i + +Expr tensor_shape_i(Expr expr, Expr axis) { + static const Op& op = Op::Get("relax.inspect.tensor_shape_i"); + return Call(op, {expr, axis}); +} + +StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::Int(64); + + CHECK_EQ(call->args.size(), 2) << "TypeError: " + << "Operator " << call->op << " expects two arguments, " + << "but received " << call->args.size() + << " arguments: " << call->args; + const auto& arg = call->args[0]; + const auto& axis = call->args[1]; + + auto tensor_sinfo = arg->struct_info_.as<TensorStructInfoNode>(); + CHECK(tensor_sinfo) << "TypeError: " + << "Operator " << call->op << " expects arguments (tensor, axis), " + << "but the first argument " << arg << " in expression " << call + << " has struct info " << arg->struct_info_; + + auto axis_sinfo = axis->struct_info_.as<PrimStructInfoNode>(); + CHECK(axis_sinfo) << "TypeError: " + << "Operator " << call->op << " expects arguments (tensor, axis), " + << "but the second argument " << axis << " in expression " << call + << " has struct info " << axis->struct_info_; + + auto int_imm_axis = axis_sinfo->value.as<IntImmNode>(); + + if (int_imm_axis) { + CHECK_GE(int_imm_axis->value, 0); + } + if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) { + CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim) + << "ValueError: " + << "Expression " << call << " attempts to access " << arg << ".shape[" + << int_imm_axis->value << "]" + << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " elements"; + } + + auto tensor_shape = tensor_sinfo->GetShape(); + if (int_imm_axis && tensor_shape.defined()) { + return PrimStructInfo(tensor_shape.value()[int_imm_axis->value]); + } else { + return PrimStructInfo(dlpack_type); + } +}
+Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { + auto field_dtype = Downcast<PrimStructInfo>(call->struct_info_)->dtype; + + tir::PrimFunc getter = [&]() -> tir::PrimFunc { + tir::Var dlpack_handle("dlpack_handle", DataType::Handle()); + tir::Var axis("axis", DataType::Int(64)); + + tir::Var ndim("ndim", DataType::Int(32)); + + tir::Buffer shape_buffer = tir::decl_buffer({ndim}, field_dtype, "shape"); + + tir::Var extent("extent", field_dtype); + + tir::Stmt body = tir::Evaluate(tvm::ret(extent)); + + body = tir::LetStmt(extent, tir::BufferLoad(shape_buffer, {axis}), body); + body = tir::DeclBuffer(shape_buffer, body); + body = tir::LetStmt( + shape_buffer->data, + tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kArrShape)}), + body); + + body = tir::AssertStmt( + axis < tvm::cast(axis->dtype, ndim), + tir::StringImm("Specified axis may not be larger than the tensor's dimensionality"), body); + + body = tir::LetStmt( + ndim, + tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tir::builtin::TVMStructFieldKind::kArrNDim)}), + body); + + body = tir::AssertStmt(0 <= axis, tir::StringImm("Specified axis may not be negative"), body); + + DictAttrs attrs({{"tir.is_scheduled", Bool(true)}, {"tir.is_host", Bool(true)}}); + + tir::PrimFunc func({dlpack_handle, axis}, body, PrimType(field_dtype), {}, attrs); + + FuncStructInfo sinfo( + {TensorStructInfo(DataType::Void(), kUnknownNDim), PrimStructInfo(axis->dtype)}, + PrimStructInfo(field_dtype)); + UpdateStructInfo(func, sinfo); + return func; + }(); + + GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_shape_i"); + return Call(gvar_getter, call->args); +} + +TVM_REGISTER_OP("relax.inspect.tensor_shape_i") + .set_num_inputs(2) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTensorShape) + .set_attr<FLegalize>("FLegalize", LegalizeTensorShape) + .set_attr<Bool>("RequiresArgumentShapes", Bool(false)) + .set_attr<FNormalize>("FNormalize", NormalizeToKnownPrimValue) + .set_attr<Bool>("FPurity", Bool(true)); + +} // namespace inspect +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h new file mode 100644 index 0000000000..0225b00fb3 --- /dev/null +++ b/src/relax/op/tensor/inspect.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*!
+ * \file inspect.h + * \brief Operators to access runtime DLTensor parameters + */ +#ifndef TVM_RELAX_OP_TENSOR_INSPECT_H_ +#define TVM_RELAX_OP_TENSOR_INSPECT_H_ + +#include <tvm/relax/expr.h> + +namespace tvm { +namespace relax { +namespace inspect { + +/* \brief Return the DLTensor::dtype::type_code field + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \returns The uint8_t value of the type_code, with + * `PrimStructInfo(DataType::UInt(8))` + */ +Expr tensor_dtype_code(Expr expr); + +/* \brief Return the DLTensor::dtype::bits field + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \returns The uint8_t value of the number of bits, with + * `PrimStructInfo(DataType::UInt(8))`. For vectorized types, returns + * the bit width of the underlying scalar type (e.g. 32 for + * "float32x4", not 128). + */ +Expr tensor_dtype_bits(Expr expr); + +/* \brief Return the DLTensor::dtype::lanes field + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \returns The uint16_t value of the number of lanes, with + * `PrimStructInfo(DataType::UInt(16))` + */ +Expr tensor_dtype_lanes(Expr expr); + +/* \brief Return the DLTensor::ndim field + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \returns The int32_t value of the dimensionality, with + * `PrimStructInfo(DataType::Int(32))`. + */ +Expr tensor_ndim(Expr expr); + +/* \brief Return the DLTensor::shape[i] field + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \param axis The axis to inspect. Must be within the range `0 <= + * axis < tensor_ndim(expr)`, or else the results are undefined. + * + * \returns The int64_t extent of the specified tensor axis, with + * `PrimStructInfo(DataType::Int(64))`. + */ +Expr tensor_shape_i(Expr expr, Expr axis); + +} // namespace inspect +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_INSPECT_H_ diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index c8fba59dab..343c18acd7 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -90,23 +90,32 @@ class LegalizeMutator : public ExprMutator { bool WrapPureCondition(const Op& op, const Expr& legalized) { static const auto& purity_map = Op::GetAttrMap<Bool>("FPurity"); - // unlikely for this condition not to be met - if (const CallNode* call = legalized.as<CallNode>()) { - // if the original op is not pure, don't wrap - if (!(purity_map.count(op) && purity_map[op]->value)) { + const CallNode* call = legalized.as<CallNode>(); + + if (!call) { + // Unlikely for this condition to be met, but it is possible. + // For example, an operation could produce a Tuple output, and + // be legalized into separate calls for each item in the Tuple.
+ return false; + } + + bool pure_original_op = purity_map.get(op, Bool(false))->value; + bool pure_legalized_op = [&]() -> bool { + if (auto legalized_op = call->op.as<Op>()) { + return purity_map.get(legalized_op.value(), Bool(false))->value; + } else if (auto func_sinfo = call->op->struct_info_.as<FuncStructInfoNode>()) { + return func_sinfo->purity; + } else { return false; } - if (const OpNode* call_op = call->op.as<OpNode>()) { - auto res_op = GetRef<Op>(call_op); - if (purity_map.count(res_op)) { - // if the legalized op is already pure, we *don't* need a wrapper - return !purity_map[res_op]->value; - } - } - // simplest case: wrap if the original op was pure and the result is somehow not - return true; - } - return false; + }(); + + // If the original op was pure, but the legalized op was not, + // the legalized op may occur in a context that requires pure + // functions, such as a `relax::DataflowBlock`. In this case, + // we should wrap the legalized operation to indicate that it is + // still pure. + return pure_original_op && !pure_legalized_op; } Call WrapPureCall(const Call& ret) { @@ -148,6 +157,7 @@ class LegalizeMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call)); static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize"); + static const auto& requires_arg_shapes_map = Op::GetAttrMap<Bool>("RequiresArgumentShapes"); static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); @@ -157,17 +167,72 @@ if (op_node == nullptr) { return visited_call; } auto op = GetRef<Op>(op_node); - std::string op_name(op->name); - bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos); - // Not all shape values are known - // Data-dependent ops are exception since their output shape will be identified at runtime. - // Legalizer will insert their shape functions, which are manually registered, and match cast - // to define symbolic output shape at compile time. - if (!std::all_of(visited_call->args.begin(), visited_call->args.end(), - [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }) || - (!is_data_dependent_op && !KnowAllShapeValues(GetStructInfo(visited_call)))) { + + bool can_legalize = [&]() -> bool { + bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value; + if (!requires_arg_shapes) { + // This operator does not require its arguments to have a + // known shape/dtype. For example, the "relax.tensor_ndim" + // operator can output the dimensionality of a tensor at + // runtime, and does not require the dimensionality to be + // known at compile-time. + return true; + } + + bool arg_shapes_defined = + std::all_of(visited_call->args.begin(), visited_call->args.end(), + [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }); + if (!arg_shapes_defined) { + // This operator cannot be legalized, because legalization + // requires the argument shapes to be known. + // + // TODO(Lunderberg): + // + // Improve this fallback case, as failure to legalize can + // produce unexpected errors during CodeGenVM. This could + // be done by having `R.Tensor(ndim=2)` be syntactic sugar + // for `R.Tensor(shape=[m, n])`, where `m` and `n` are new + // shape variables. This would allow legalization into + // dynamic TIR PrimFuncs.
+ // + // This fallback would only be applicable for cases where + // both the dtype and the dimensionality are known. While + // Relax can express a tensor with unknown dtype and + // dimensionality as `TensorStructInfo(DataType::Void(), + // kUnknownNDim)`, TIR cannot express unknown dtype or + // unknown dimensionality. + return false; + } + + std::string op_name(op->name); + bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos); + bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call)); + if (!is_data_dependent_op && !ret_shape_defined) { + // This operator cannot be legalized, because legalization by + // default requires the output shape to be known. The exception is + // data-dependent operators (e.g. `R.dynamic_strided_slice`), + // where the shape of the output depends on the runtime values + // stored in a tensor. + // + // For data-dependent ops, the output shape will be identified + // at runtime. The Legalizer will insert their shape + // functions, which are manually registered for each + // data-dependent op, along with a match cast to define symbolic output + // shapes. These compile-time symbolic shapes can then be + // used by later operations to refer to the runtime shape. + // + // TODO(Lunderberg): Make a new operator attribute + // `.set_attr<Bool>("DataDependent")`, rather than relying on + // the name of the operator. + return false; + } + + // All checks pass, this operator can be legalized. + return true; + }(); + + if (!can_legalize) { + return visited_call; } diff --git a/tests/python/relax/test_op_unpack.py b/tests/python/relax/test_op_unpack.py new file mode 100644 index 0000000000..03e4e0fc85 --- /dev/null +++ b/tests/python/relax/test_op_unpack.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.testing + +from tvm import relax +from tvm.ir import Op +from tvm.script import ir as I, relax as R + +# Parameterization for reading dtype of DLTensor. Chosen to have +# multiple distinct type codes, lane counts, and bit widths.
+dtype = tvm.testing.parameter( + "int32", + "int64", + "float32", + "float32x4", + "bfloat", + "e4m3_float8", +) +shape = tvm.testing.parameter( + [], + [16], + [128, 256], + [1] * 64, +) + + +def test_tensor_dtype_code(dtype): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.dtype.type_code + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty([16], dtype) + res = vm["main"](arg) + + expected_type_code = tvm.runtime.DataType(dtype).type_code + assert res == expected_type_code + + +def test_tensor_dtype_bits(dtype): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.dtype.bits + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty([16], dtype) + res = vm["main"](arg) + + expected_type_bits = tvm.runtime.DataType(dtype).bits + assert res == expected_type_bits + + +def test_tensor_dtype_lanes(dtype): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.dtype.lanes + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty([16], dtype) + res = vm["main"](arg) + + expected_type_lanes = tvm.runtime.DataType(dtype).lanes + assert res == expected_type_lanes + + +def test_tensor_ndim(shape): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.ndim + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty(shape, "int32") + res = vm["main"](arg) + + assert res == len(shape) + + +def test_tensor_shape(shape): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor, axis: R.Prim("int64")): + return A.shape[axis] + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty(shape, "int32") + + res = [vm["main"](arg, i) for i, _ in enumerate(shape)] + + tvm.ir.assert_structural_equal(res, shape) + + +if __name__ == "__main__": + tvm.testing.main()
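A brief usage sketch (not part of the commit), following the build/run pattern of the tests above. It reads the runtime `DLDataType::code` field to distinguish a `float16` weight from a `bfloat16` weight, the motivating case from the commit message. The module name `sketch_mod` is illustrative, and this assumes a TVM build that includes this commit.

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class sketch_mod:
    @R.function
    def main(weight: R.Tensor([16, 16])):
        # `weight.dtype.type_code` produces a call to the new
        # "relax.inspect.tensor_dtype_code" operator.  Since the dtype
        # is unknown here, it legalizes into a PrimFunc that reads the
        # DLTensor::dtype::code field at runtime.
        return weight.dtype.type_code


built = relax.build(sketch_mod)
vm = relax.VirtualMachine(built, tvm.cpu())

fp16_code = vm["main"](tvm.nd.empty([16, 16], "float16"))
bf16_code = vm["main"](tvm.nd.empty([16, 16], "bfloat"))

# In dlpack.h, kDLFloat == 2 and kDLBfloat == 4, so the two formats are
# distinguishable even though both are 16 bits wide.
assert fp16_code != bf16_code

A real pre-processing routine would branch on this runtime value (e.g. with Relax control flow) before the compute step, rather than merely asserting on it.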