This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 39dc299c68 [Unity][Op] Enable special dimension value 0 in reshape (#14311) 39dc299c68 is described below commit 39dc299c688f30e22f4d4d334d099c04696da148 Author: Josh Fromm <jwfr...@octoml.ai> AuthorDate: Wed Mar 15 17:10:12 2023 -0700 [Unity][Op] Enable special dimension value 0 in reshape (#14311) [Unity] Enable special dimension value 0 in reshape --- src/relax/op/tensor/manipulate.cc | 42 +++++++++++++++++++++++++------- tests/python/relax/test_op_manipulate.py | 8 ++++-- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index d90fd41e1c..c7bf051302 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -522,7 +522,8 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { "Array of PrimExprs. However, the given new shape is " << shape; int dim_to_infer = -1; - PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1); + // Keep track of which dimensions should be copied from input. + std::vector<int> zero_dims; for (int i = 0; i < static_cast<int>(array->size()); ++i) { const auto* _len = array->at(i).as<PrimExprNode>(); CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " @@ -533,7 +534,10 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { "integers. However, the give new shape is " << shape; const auto* int_len = len.as<IntImmNode>(); - if (int_len != nullptr && int_len->value == -1) { + if (int_len != nullptr && int_len->value == 0) { + // Note that this dimension should be copied from the original shape. + zero_dims.push_back(i); + } else if (int_len != nullptr && int_len->value == -1) { CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. 
However, " "there are multiple \"-1\" in the given new shape " << shape; @@ -543,15 +547,12 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { << "Reshape requires all values in the new shape to be positive except a single \"-1\". " "However, the given new shape is " << shape; - // We expect any symbolic not to signal the intent of -1, and therefore do no check for - // symbolic value here. - new_shape_prod = new_shape_prod * len; } } Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array); // When there is no dimension to infer, just return the input array as ShapeExpr. - if (dim_to_infer == -1) { + if (dim_to_infer == -1 && zero_dims.empty()) { return ShapeExpr(array_ref); } @@ -569,9 +570,32 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { "to infer. However, the given input shape is " << data_sinfo->shape << " whose shape value is unknown."; - arith::Analyzer analyzer; - PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value()); - array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod))); + // Set any 0 valued dimensions to match the corresponding input shape. + if (!zero_dims.empty()) { + for (int i : zero_dims) { + array_ref.Set(i, shape_sinfo->values.value()[i]); + } + } + + // Set any -1 dimensions to complete the number of appropriate elements. + // Start by computing the shape product of all positive indices. + PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1); + for (int i = 0; i < static_cast<int>(array_ref.size()); ++i) { + PrimExpr new_dim = array_ref[i]; + const auto* int_dim = new_dim.as<IntImmNode>(); + // We expect any symbolic not to signal the intent of -1, and therefore do no check for + // symbolic value here. + if (int_dim == nullptr || int_dim->value > 0) { + new_shape_prod = new_shape_prod * new_dim; + } + } + + // Assign appropriate value to -1 dimension. 
+ if (dim_to_infer != -1) { + arith::Analyzer analyzer; + PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value()); + array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod))); + } return ShapeExpr(array_ref); } diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index af20639a8e..16bbc04d26 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -179,6 +179,12 @@ def test_reshape_infer_struct_info_shape_var(): _check_inference( bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") ) + _check_inference( + bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorStructInfo((2, 3, 4, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorStructInfo((1, 3, 4, 10), "float32") + ) _check_inference( bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") ) @@ -281,8 +287,6 @@ def test_reshape_infer_struct_info_non_positive_new_shape(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) - with pytest.raises(TVMError): - bb.normalize(relax.op.reshape(x, (2, 0, 4, 5))) with pytest.raises(TVMError): bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5)))