This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 372d737  [RELAY] Refactor FoldConstant to skip TNonComputationalOps (#6720)
372d737 is described below

commit 372d7374d221fb98f7e7fe5d9d5c937059a35515
Author: Lily Orth-Smith <lilyorthsm...@gmail.com>
AuthorDate: Sat Oct 24 00:23:50 2020 -0700

    [RELAY] Refactor FoldConstant to skip TNonComputationalOps (#6720)

    * add the TNonComputational attribute to qnn ops and change FoldConstant

    * remove comments

    * check whether the op is in the TNonComputational attribute map

    * forgot to mark the device_copy op as TNonComputational

    * hacky fix to the FuseOps pass

    * fix typo

    * manually skip device_copy in fold_constant

    * Update src/relay/transforms/fold_constant.cc

    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
---
 src/relay/qnn/op/concatenate.cc       | 1 +
 src/relay/qnn/op/convolution.cc       | 1 +
 src/relay/qnn/op/dense.cc             | 1 +
 src/relay/qnn/op/dequantize.cc        | 1 +
 src/relay/qnn/op/op_common.h          | 1 +
 src/relay/qnn/op/quantize.cc          | 1 +
 src/relay/qnn/op/requantize.cc        | 1 +
 src/relay/transforms/fold_constant.cc | 9 ++++++---
 8 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc
index 29ecf45..88d2ecc 100644
--- a/src/relay/qnn/op/concatenate.cc
+++ b/src/relay/qnn/op/concatenate.cc
@@ -207,6 +207,7 @@ RELAY_REGISTER_OP("qnn.concatenate")
                    "The quantization zero_point of the output tensor.")
     .set_support_level(11)
     .add_type_rel("QnnConcatenate", QnnConcatenateRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConcatenateLayout);

diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index b2b6b09..73ee456 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -733,6 +733,7 @@ operator to understand how to scale back the int32 output to (u)int8.
"The quantization zero_point of the weight tensor.") .set_support_level(11) .add_type_rel("QnnConv2D", QnnConv2DRel) + .set_attr<TNonComputational>("TNonComputational", true) .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize) .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConvInferCorrectLayout); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 3cfc418..e1cbfaf 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -189,6 +189,7 @@ RELAY_REGISTER_OP("qnn.dense") "The quantization zero_point of the weight tensor.") .set_support_level(11) .add_type_rel("QDense", QnnDenseRel) + .set_attr<TNonComputational>("TNonComputational", true) .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize); TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense); diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index f0c139c..0a81f3f 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -136,6 +136,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 giv .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") .set_support_level(11) .add_type_rel("Dequantize", DequantizeRel) + .set_attr<TNonComputational>("TNonComputational", true) .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize); diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index e99c11b..3ca8f64 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -215,6 +215,7 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ + .set_attr<TNonComputational>("TNonComputational", true) \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout) } // namespace qnn diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 1b5cb5e..0784791 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -150,6 +150,7 @@ scale and zero point. 
"The quantization zero_point of the output tensor.") .set_support_level(11) .add_type_rel("Quantize", QuantizeRel) + .set_attr<TNonComputational>("TNonComputational", true) .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index ea87855..3572a39 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -324,6 +324,7 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) "The quantization zero_point of the output tensor.") .set_support_level(11) .add_type_rel("Requantize", RequantizeRel) + .set_attr<TNonComputational>("TNonComputational", true) .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) .set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 1de690d..4a739dd 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -151,9 +151,12 @@ class ConstantFolder : public MixedModeMutator { } // We should think about potentially constant evaluation over these ops too. - if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || - call->op == alloc_storage_op_ || call->op == device_copy_op_) { - return GetRef<Call>(call); + static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational"); + if (const auto* call_node = call->op.as<OpNode>()) { + Op op = GetRef<Op>(call_node); + if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) { + return GetRef<Call>(call); + } } bool all_const_args = true;