This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch torchbench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bacf3946c727682e7aad82f03e34abbbd9f120a2 Author: Masahiro Masuda <masahi...@gmail.com> AuthorDate: Wed Sep 14 13:09:45 2022 +0900 support constant folding on ndarray_size --- python/tvm/relay/frontend/pytorch.py | 2 +- src/relay/transforms/fold_constant.cc | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e2badaabf7..722b2889d3 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2489,7 +2489,7 @@ class PyTorchOpConverter: ) def numel(self, inputs, input_types): - return _op.ndarray_size(inputs[0]) + return fold_constant(_op.ndarray_size(inputs[0])) def empty(self, inputs, input_types): shape = inputs[0] diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 9dec840be0..f484dfc700 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -188,8 +188,7 @@ class ConstantFolder : public MixedModeMutator { if (is_no_computational && (is_no_qnn_canonicalized || !fold_qnn_)) { return std::move(post_call); } - if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_ || - op == ndarray_size_op_) { + if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_) { // We should think about potentially constant evaluation over these ops too. return std::move(post_call); } @@ -383,6 +382,13 @@ class ConstantFolder : public MixedModeMutator { // TODO(mbs): This is not necessary since we only ever ask for the shapes for // pre-rewritten expressions which will always have a checked_type. return const_node->tensor_type()->shape; + // } else if (auto ttype = input->type_as<TensorTypeNode>()) { + } else if (const auto* var = input.as<VarNode>()) { + auto ty = var->type_annotation; + if (ty->IsInstance<TensorTypeNode>()) { + return Downcast<TensorType>(ty)->shape; + } + return {}; } else if (input->checked_type_.defined()) { return input->checked_type().as<TensorTypeNode>()->shape; } else {