This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d39f2c974 [FQ2I] fix unary op output affine type in fq2i (#12224)
1d39f2c974 is described below

commit 1d39f2c974e09e5a767b67e127a5132f0b36c102
Author: Matthew Brookhart <mbrookh...@octoml.ai>
AuthorDate: Sat Jul 30 21:00:55 2022 -0600

    [FQ2I] fix unary op output affine type in fq2i (#12224)
    
    * fix unary op output affine type in fq2i
    
    * better names
    
    * add option to force to positive values for ops that are undefined on negative values
---
 .../transform/fake_quantization_to_integer.py      |  2 +-
 .../test_pass_fake_quantization_to_integer.py      | 35 +++++++++++++++-------
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 8308298e70..b0464439b0 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -534,7 +534,7 @@ def register_unary_qnn(op_name, op):
             out_t.scale,
             out_t.zero_point,
         )
-        return [out, x_t]
+        return [out, out_t]
 
     return register_fake_quantization_to_integer(op_name, unary)
 
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
index d0c8cca6b7..38520ff2df 100644
--- a/tests/python/relay/test_pass_fake_quantization_to_integer.py
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -318,23 +318,36 @@ def test_fake_quantize_global_avg_pool():
 
 
 class TestUnaryQNNOp:
-    def helper_test_fake_quantize_unary_op(self, fp32_op, scale=0.125):
-        x = relay.var("x", shape=[1, 3, 3, 3], dtype="int8")
-        mid_point = relay.const(-128)
+    def helper_test_fake_quantize_unary_op(self, fp32_op, pos_values=False):
+        for dtype in ["int8", "uint8"]:
+            x = relay.var("x", shape=[1, 3, 3, 3], dtype=dtype)
 
-        x = relay.qnn.op.dequantize(x, relay.const(scale), mid_point)
-        op = fp32_op(x)
-        op = relay.qnn.op.quantize(op, relay.const(scale), mid_point)
+            zero = -128 if dtype == "int8" else 0
+            if pos_values:
+                # Use a positive range for quanitzed ops that only work on positive values
+                input_mid_point = relay.const(zero)
+                output_mid_point = relay.const(zero)
+            else:
+                input_mid_point = relay.const(np.random.randint(0, 255) + zero)
+                output_mid_point = relay.const(np.random.randint(0, 255) + zero)
 
-        x_np = np.random.randint(-128, 127, size=[1, 3, 3, 3], dtype="int8")
+            input_scale = relay.const(np.random.rand())
+            output_scale = relay.const(np.random.rand())
 
-        compare_fq_to_int(op, [x_np], True)
+            x = relay.qnn.op.dequantize(x, input_scale, input_mid_point)
+            op = fp32_op(x)
+
+            op = relay.qnn.op.quantize(op, output_scale, output_mid_point, out_dtype=dtype)
+
+            x_np = np.random.randint(0 + zero, 255 + zero, size=[1, 3, 3, 3], dtype=dtype)
+
+            compare_fq_to_int(op, [x_np], True)
 
     def test_sqrt(self):
-        self.helper_test_fake_quantize_unary_op(fp32_op=relay.sqrt)
+        self.helper_test_fake_quantize_unary_op(fp32_op=relay.sqrt, pos_values=True)
 
     def test_rsqrt(self):
-        self.helper_test_fake_quantize_unary_op(fp32_op=relay.rsqrt)
+        self.helper_test_fake_quantize_unary_op(fp32_op=relay.rsqrt, pos_values=True)
 
     def test_exp(self):
         self.helper_test_fake_quantize_unary_op(fp32_op=relay.exp)
@@ -349,7 +362,7 @@ class TestUnaryQNNOp:
         self.helper_test_fake_quantize_unary_op(fp32_op=relay.tanh)
 
     def test_log(self):
-        self.helper_test_fake_quantize_unary_op(fp32_op=relay.log)
+        self.helper_test_fake_quantize_unary_op(fp32_op=relay.log, pos_values=True)
 
 
 def test_fake_quantize_reshape():

Reply via email to