This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b7d3762f39 [Unity][Hexagon] Don't pass raw scalars in hexagon/qnn/nn.py (#14474)
b7d3762f39 is described below

commit b7d3762f390a895020c40bf2c425fa49ae632bdd
Author: Krzysztof Parzyszek <kparz...@quicinc.com>
AuthorDate: Tue Apr 4 14:10:53 2023 -0500

    [Unity][Hexagon] Don't pass raw scalars in hexagon/qnn/nn.py (#14474)
    
    Everything is expected to be either te.Tensor, or a PrimExpr, so make
    sure that function parameters conform to that. Also, remove incorrect
    type annotations from `qnn_requantize`.
---
 python/tvm/topi/hexagon/qnn/nn.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py
index 1a707cef7e..3024ec9165 100644
--- a/python/tvm/topi/hexagon/qnn/nn.py
+++ b/python/tvm/topi/hexagon/qnn/nn.py
@@ -86,7 +86,7 @@ def get_const_float_value(expr):
 
 def get_qnn_param(param, indices, axis):
     # Account scalar and 1D quantization parameters:
-    if len(param.shape) == 0:
+    if is_scalar(param):
         return param
 
     param_idx = tvm.tir.indexmod(indices[axis], topi.shape(param)[0])
@@ -213,13 +213,7 @@ def schedule_qnn_dequantize(outs):
 
 
 def qnn_requantize(
-    data: te.Tensor,
-    input_scale: te.Tensor,
-    input_zp: te.Tensor,
-    output_scale: te.Tensor,
-    output_zp: te.Tensor,
-    axis=-1,
-    out_dtype="int8",
+    data: te.Tensor, input_scale, input_zp, output_scale, output_zp, axis=-1, out_dtype="int8"
 ):
     """Compute for qnn.requantize
 
@@ -233,7 +227,6 @@ def qnn_requantize(
 
     TODO: support 'rounding' and 'compute_dtype' arguments.
     """
-
     if is_scalar(input_scale) and is_scalar(output_scale):
         iscale = get_const_float_value(input_scale)
         oscale = get_const_float_value(output_scale)
@@ -431,9 +424,10 @@ def qnn_mul(
     if is_scalar(lhs_scale) and is_scalar(rhs_scale):
         assert isinstance(lhs_scale, te.Tensor)
         assert isinstance(rhs_scale, te.Tensor)
-        iscale = get_const_float_value(lhs_scale.op.body[0]) * get_const_float_value(
+        iscale_val = get_const_float_value(lhs_scale.op.body[0]) * get_const_float_value(
             rhs_scale.op.body[0]
         )
+        iscale = tvm.tir.const(iscale_val)
     else:
         iscale = lhs_scale * rhs_scale
 

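The convention the commit message describes can be shown with a minimal sketch (assuming only a standard TVM install; the names below are illustrative and not taken from the patch): scalar quantization parameters are passed as TIR PrimExprs such as `tvm.tir.const(...)`, while per-channel parameters remain `te.Tensor`s.

    # Minimal sketch, assuming a stock TVM install; variable names are
    # illustrative and not taken from the patch.
    import tvm
    from tvm import te

    # A raw Python number is wrapped as a PrimExpr before being handed to
    # the TOPI compute functions.
    input_zp = tvm.tir.const(3, "int32")
    input_scale = tvm.tir.const(0.15, "float32")
    assert isinstance(input_zp, tvm.tir.PrimExpr)
    assert isinstance(input_scale, tvm.tir.PrimExpr)

    # A per-channel (1-D) quantization parameter stays a te.Tensor.
    per_channel_scale = te.placeholder((8,), dtype="float32", name="scale")
    assert isinstance(per_channel_scale, te.Tensor)

With parameters in this form, helpers such as `is_scalar` and `get_const_float_value` in nn.py can treat the scalar and per-channel cases uniformly, as the hunks above rely on.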