This is an automated email from the ASF dual-hosted git repository. jwfromm pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 0e4c99cdb8 [Unity][ONNX] Improved symbolic handling and reshape functionality (#15550) 0e4c99cdb8 is described below commit 0e4c99cdb84d51a8fccd2fde9bbad1fbf2e33e73 Author: Josh Fromm <jwfr...@octoml.ai> AuthorDate: Tue Aug 15 09:03:49 2023 -0700 [Unity][ONNX] Improved symbolic handling and reshape functionality (#15550) * Improved symbolic handling and reshape functionality * retrigger ci --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 47 ++++++++++++++++++++++++- src/relax/op/tensor/manipulate.cc | 12 +++++-- tests/python/relax/test_op_manipulate.py | 3 ++ 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 2ef6121002..152db73c51 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -150,7 +150,7 @@ def get_prim_expr_list( Parameters ---------- - inputs : Union[relax.Constant, relax.ShapeExpr] + inputs : Union[relax.Constant, relax.ShapeExpr, relax.PrimValue] The input value to try to convert to a list of PrimExpr. Returns @@ -165,6 +165,8 @@ def get_prim_expr_list( return np_value.tolist() elif isinstance(inputs, relax.ShapeExpr): return inputs.values + elif isinstance(inputs, relax.PrimValue): + return [inputs.value.value] else: raise ValueError("Cannot cast {} to list of PrimExpr".format(type(inputs))) @@ -233,6 +235,19 @@ class Div(OnnxOpConverter): if all([isinstance(inp, relax.Constant) for inp in inputs]): output = inputs[0].data.numpy() / inputs[1].data.numpy() return relax.const(output, inputs[0].struct_info.dtype) + if any([isinstance(inp, relax.PrimValue) for inp in inputs]): + x = ( + int(inputs[0].value) + if isinstance(inputs[0], relax.PrimValue) + else inputs[0].data.numpy() + ) + y = ( + int(inputs[1].value) + if isinstance(inputs[1], relax.PrimValue) + else inputs[1].data.numpy() + ) + return relax.PrimValue(int(x / y)) + return relax.op.divide(inputs[0], inputs[1]) @@ -359,6 +374,19 @@ class Add(OnnxOpConverter): if all([isinstance(inp, relax.Constant) for inp in inputs]): output = inputs[0].data.numpy() + inputs[1].data.numpy() return relax.const(output, output.dtype) + # If primvalues are involved, handle them directly. + if any([isinstance(inp, relax.PrimValue) for inp in inputs]): + x = ( + int(inputs[0].value) + if isinstance(inputs[0], relax.PrimValue) + else inputs[0].data.numpy() + ) + y = ( + int(inputs[1].value) + if isinstance(inputs[1], relax.PrimValue) + else inputs[1].data.numpy() + ) + return relax.PrimValue(int(x + y)) return relax.op.add(inputs[0], inputs[1]) @@ -367,9 +395,24 @@ class Mul(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): + # When all inputs are constant, directly multiply. if all([isinstance(inp, relax.Constant) for inp in inputs]): output = inputs[0].data.numpy() * inputs[1].data.numpy() return relax.const(output, output.dtype) + # If primvalues are involved, handle them directly. + if any([isinstance(inp, relax.PrimValue) for inp in inputs]): + x = ( + int(inputs[0].value) + if isinstance(inputs[0], relax.PrimValue) + else inputs[0].data.numpy() + ) + y = ( + int(inputs[1].value) + if isinstance(inputs[1], relax.PrimValue) + else inputs[1].data.numpy() + ) + return relax.PrimValue(int(x * y)) + return relax.op.multiply(inputs[0], inputs[1]) @@ -382,6 +425,8 @@ class Cast(OnnxOpConverter): if isinstance(inputs[0], relax.Constant): output = inputs[0].data.numpy().astype(to_type) return relax.const(output, to_type) + if isinstance(inputs[0], relax.PrimValue): + return relax.PrimValue(inputs[0].value.astype(to_type)) return relax.op.astype(inputs[0], to_type) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2d7e60c4f0..edf84e6887 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -604,11 +604,17 @@ TVM_REGISTER_OP("relax.permute_dims") /* relax.reshape */ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { - if (const auto* e = shape.as<ExprNode>()) { + const ArrayNode* array; + // Treat shape expressions as constant arrays to handle special values. + if (const auto* e = shape.as<ShapeExprNode>()) { + array = e->values.as<ArrayNode>(); + // Other non-shape expressions are used directly. + } else if (const auto* e = shape.as<ExprNode>()) { return GetRef<Expr>(e); + // Process special values in constants and produce an expression. + } else { + array = shape.as<ArrayNode>(); } - - const auto* array = shape.as<ArrayNode>(); CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " "Array of PrimExprs. However, the given new shape is " << shape; diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 07e21cc179..b0b4b98ab5 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -72,6 +72,9 @@ def test_reshape_infer_struct_info(): bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") ) _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32")) + _check_inference( + bb, relax.op.reshape(x0, relax.ShapeExpr([-1])), relax.TensorStructInfo((120,), "float32") + ) _check_inference( bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") )