This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 76155c2f3c [QNN] Support different qnn params between in/out tensor in leaky_relu (#12116)
76155c2f3c is described below

commit 76155c2f3c327ad7ada9d8fcb1c7f6f447dcc0ec
Author: zhaoyang-star <zhaoyangs...@foxmail.com>
AuthorDate: Sat Jul 23 05:33:31 2022 +0800

    [QNN] Support different qnn params between in/out tensor in leaky_relu (#12116)
    
    * [QNN] Support different qnn params between in/out tensor in leaky_relu
    
    * format code
    
    * format code
    
    * fix bug
    
    * fix format
    
    * fix format
    
    * fix
---
 python/tvm/relay/frontend/qnn_torch.py             |  6 +-
 python/tvm/relay/qnn/op/qnn.py                     | 21 ++++--
 .../transform/fake_quantization_to_integer.py      |  9 ++-
 src/relay/qnn/op/leaky_relu.cc                     | 85 +++++++++++++++-------
 tests/python/relay/test_op_qnn_leaky_relu.py       | 30 +++++---
 5 files changed, 104 insertions(+), 47 deletions(-)

diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 251f46630a..74d5e2e0f5 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -963,7 +963,11 @@ def _leaky_relu(fp32_piggy_back=False):
         alpha = inputs[1]
         output_scale = _expr.const(inputs[3])
         output_zero_point = _expr.const(inputs[4])
-        return relay.qnn.op.leaky_relu(inputs[0], alpha, output_scale, output_zero_point)
+        input_scale = _expr.const(inputs[5])
+        input_zero_point = _expr.const(inputs[6])
+        return relay.qnn.op.leaky_relu(
+            inputs[0], alpha, input_scale, input_zero_point, output_scale, output_zero_point
+        )
 
     def _impl(inputs, _):
         assert len(inputs) == 7, "Input quant params not found in op inputs"
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index edb528708c..17dba15e09 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -1179,7 +1179,7 @@ reg.register_pattern("qnn.quantize", OpPattern.OPAQUE)
 reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)
 
 
-def leaky_relu(x, alpha, scale, zero_point):
+def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zero_point):
     """Quantized leaky relu.
 
     Parameters
@@ -1188,11 +1188,14 @@ def leaky_relu(x, alpha, scale, zero_point):
         The quantized input tensor.
     alpha: double
         The alpha value.
-    scale: relay.Expr
-        The scale of the quantized expr.
-    zero_point: relay.Expr
-       The zero point of quantized expr.
-
+    input_scale: relay.Expr
+        The scale of the input quantized expr.
+    input_zero_point: relay.Expr
+       The zero point of input quantized expr.
+    output_scale: relay.Expr
+        The scale of the output quantized expr.
+    output_zero_point: relay.Expr
+       The zero point of output quantized expr.
     Returns
     -------
     result : relay.Expr
@@ -1201,6 +1204,8 @@ def leaky_relu(x, alpha, scale, zero_point):
     return _make.leaky_relu(
         x,
         alpha,
-        scale,
-        zero_point,
+        input_scale,
+        input_zero_point,
+        output_scale,
+        output_zero_point,
     )
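
For reference, a minimal sketch of calling the updated Python API with separate input and
output qnn params (the constant values below are illustrative only, not taken from the patch):

    import tvm
    from tvm import relay

    # Quantized input tensor plus separate input/output qnn params.
    x = relay.var("x", shape=(1, 4), dtype="uint8")
    y = relay.qnn.op.leaky_relu(
        x,
        alpha=0.1,
        input_scale=relay.const(0.125, "float32"),
        input_zero_point=relay.const(60, "int32"),
        output_scale=relay.const(0.25, "float32"),
        output_zero_point=relay.const(10, "int32"),
    )
    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    mod = relay.transform.InferType()(mod)  # exercises the QnnLeakyReluRel type relation
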
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 4436960a20..8308298e70 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -364,10 +364,13 @@ def relu(expr, type_map):
 def leaky_relu(expr, type_map):
     """Rewrite a leaky relu op"""
     arg = expr.args[0]
-    t = type_map[arg]
+    x_t = type_map[arg]
+    out_t = type_map[expr]
     alpha = expr.attrs.alpha
-    output = relay.qnn.op.leaky_relu(expr, alpha, t.scale, t.zero_point)
-    return [output, t]
+    output = relay.qnn.op.leaky_relu(
+        expr, alpha, x_t.scale, x_t.zero_point, out_t.scale, out_t.zero_point
+    )
+    return [output, x_t]
 
 
 @register_fake_quantization_to_integer("nn.pad")
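
With this change the fake-quantization rewrite takes the input affine type from the argument and
the output affine type from the expression itself. A sketch of the kind of graph the pass rewrites
(scales and zero points here are made up for illustration):

    import tvm
    from tvm import relay

    # Fake-quantized pattern: dequantize -> nn.leaky_relu -> quantize.
    x = relay.var("x", shape=(1, 4), dtype="int8")
    dq = relay.qnn.op.dequantize(x, relay.const(0.125, "float32"), relay.const(0, "int32"))
    lr = relay.nn.leaky_relu(dq, alpha=0.1)
    q = relay.qnn.op.quantize(lr, relay.const(0.25, "float32"), relay.const(5, "int32"), out_dtype="int8")

    mod = tvm.IRModule.from_expr(q)
    mod = relay.transform.InferType()(mod)
    # Should now emit qnn.leaky_relu with input params taken from the dequantize
    # and output params taken from the surrounding quantize.
    mod = relay.transform.FakeQuantizationToInteger()(mod)
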
diff --git a/src/relay/qnn/op/leaky_relu.cc b/src/relay/qnn/op/leaky_relu.cc
index a4881dfbbd..75bfabb7db 100644
--- a/src/relay/qnn/op/leaky_relu.cc
+++ b/src/relay/qnn/op/leaky_relu.cc
@@ -32,8 +32,8 @@ namespace qnn {
 
 bool QnnLeakyReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
-  // Expected Types: data, scale, zero_point
-  ICHECK_EQ(types.size(), 4);
+  // Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, out_type
+  ICHECK_EQ(types.size(), 6);
   const auto* x = types[0].as<TensorTypeNode>();
   if (x == nullptr) return false;
   ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
@@ -42,31 +42,37 @@ bool QnnLeakyReluRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   ICHECK(param != nullptr) << "LeakyReluAttrs cannot be nullptr.";
 
   // Check the types of scale and zero points.
-  for (size_t i = 1; i < 3; ++i) {
+  for (size_t i = 1; i < 5; ++i) {
     if (types[i].as<IncompleteTypeNode>()) {
       return false;
     }
   }
 
-  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
-  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
 
   // Assign types for scale and zero points.
-  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
-  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point
+  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // input_scale
+  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // input_zero_point
+  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
+  reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point
 
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // IdentityRel infer type function.
-  Array<Type> tensor_types = {types[0], types[3]};
+  Array<Type> tensor_types = {types[0], types[5]};
   return IdentityRel(tensor_types, 2, attrs, reporter);
 }
 
 // Positional relay function to create quantized leaky relu operator used by frontend FFI.
-Expr MakeQuantizedLeakyRelu(Expr x, double alpha, Expr scale, Expr zero_point) {
+Expr MakeQuantizedLeakyRelu(Expr x, double alpha, Expr input_scale, Expr input_zero_point,
+                            Expr output_scale, Expr output_zero_point) {
   auto attrs = make_object<LeakyReluAttrs>();
   attrs->alpha = alpha;
   static const Op& op = Op::Get("qnn.leaky_relu");
-  return Call(op, {x, scale, zero_point}, Attrs(attrs), {});
+  return Call(op, {x, input_scale, input_zero_point, output_scale, output_zero_point}, Attrs(attrs),
+              {});
 }
 
 /*
@@ -82,42 +88,69 @@ Expr QnnLeakyReluCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // by a small alpha value < 1.
   //
   // We assume the same scale and zero point for alpha and the input tensor.
-  // Let T = s(q_t - z) where q_t is the input arg[0]
-  // Then, the quantized value of alpha * T is:
-  // q(a * T, s, z) = [(a * T) / s] + z = a * s(q_t - z) / s + z = a * (q_t - z) + z
-  // = a * q_t + (1 - a) * z
+  // LeakyReLU can be written in terms of the respective quantized tensors, scales and
+  // zero points as
   //
-  // We return the quantized value of alpha * T for all values q_t < input_zero_point.
-
-  ICHECK_EQ(new_args.size(), 3);
-  Expr quantized_data = Cast(new_args[0], DataType::Int(32));
+  //    scale_o * (Q_o - zp_o) = alpha * scale_i * (Q_i - zp_i)  when Q_i < zp_i  (1)
+  //    scale_o * (Q_o - zp_o) = scale_i * (Q_i - zp_i)  when Q_i >= zp_i  (2)
+  //
+  // Since the input qnn params can differ from the output qnn params, we first requantize the
+  // input tensor to the output qnn params. After requantizing Q_i, equation (1) becomes equation
+  // (3) where Q_i' is the requantized data from Q_i.
+  //
+  //    scale_o * (Q_o - zp_o) = alpha * scale_o * (Q_i' - zp_o)  when Q_i < zp_i  (3)
+  //                       Q_o = alpha * Q_i' + (1 - alpha) * zp_o  when Q_i < zp_i  (4)
+  //
+  // In equation (2), Q_o is simply Q_i requantized with scale_o and zp_o.
+  // So equation (2) becomes
+  //
+  //                       Q_o = requantize(Q_i)  when Q_i >= zp_i  (5)
+  //
+  // Finally, Q_o can be computed from equations (4) and (5).
+  ICHECK_EQ(new_args.size(), 5);
+  Expr data = Cast(new_args[0], DataType::Int(32));
+  Expr input_scale = new_args[1];
   Expr input_zero_point = Cast(new_args[2], DataType::Int(32));
+  Expr output_scale = new_args[3];
+  Expr output_zero_point = Cast(new_args[4], DataType::Int(32));
 
   const auto* q_attrs = attrs.as<LeakyReluAttrs>();
   auto alpha = q_attrs->alpha;
 
+  const auto input_shape = get_shape(arg_types[0]);
+  const auto input_dtype = arg_types[0].as<TensorTypeNode>()->dtype;
+
+  // requantize the input to Q_i'
+  auto requantized_expr = RequantizeOrUpcast(data, input_scale, input_zero_point, output_scale,
+                                             output_zero_point, input_shape);
+
+  // alpha * Q_i'
   int32_t fixed_point_multiplier, shift;
   std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(alpha);
-  auto prod = FixedPointMultiply(quantized_data, fixed_point_multiplier, shift);
+  auto prod = FixedPointMultiply(requantized_expr, fixed_point_multiplier, shift);
 
+  // (1 - alpha) * zp_o
   int32_t fixed_point_multiplier_z, shift_z;
   std::tie(fixed_point_multiplier_z, shift_z) = GetFixedPointMultiplierShift(1 - alpha);
-  auto scaled_z = FixedPointMultiply(input_zero_point, fixed_point_multiplier_z, shift_z);
+  auto scaled_z = FixedPointMultiply(output_zero_point, fixed_point_multiplier_z, shift_z);
 
+  // alpha * Q_i' + (1 - alpha) * zp_o
   auto add = Add(prod, scaled_z);
-  auto output = Where(Less(quantized_data, input_zero_point), add, quantized_data);
+  auto output = Where(Less(data, input_zero_point), add, requantized_expr);
 
-  const auto* input_type = arg_types[0].as<TensorTypeNode>();
-  return ConvertDtype(output, input_type->dtype);
+  return ConvertDtype(output, input_dtype);
 }
 
 RELAY_REGISTER_OP("qnn.leaky_relu")
     .describe("Leaky relu for quantized tensors.")
     .set_attrs_type<LeakyReluAttrs>()
-    .set_num_inputs(3)
+    .set_num_inputs(5)
     .add_argument("data", "Quantized Tensor", "The input data.")
-    .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
-    .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+    .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+    .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+    .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+    .add_argument("output_zero_point", "Tensor",
+                  "The quantization zero_point of the output tensor.")
     .set_support_level(11)
     .add_type_rel("QLeakyRelu", QnnLeakyReluRel)
     .set_attr<TNonComputational>("TNonComputational", true)
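
To make the derivation above concrete, here is a small numpy sketch of equations (4) and (5),
reusing the qnn params from the test below. It is only an approximation: the actual lowering goes
through RequantizeOrUpcast and fixed-point multiplies, so rounding can differ slightly.

    import numpy as np

    def qnn_leaky_relu_ref(q_i, alpha, s_i, zp_i, s_o, zp_o, qmin=0, qmax=255):
        # Q_i': input requantized to the output qnn params.
        q_req = np.round(s_i * (q_i.astype(np.float64) - zp_i) / s_o) + zp_o
        # Equation (4) for Q_i < zp_i, equation (5) otherwise.
        neg = alpha * q_req + (1.0 - alpha) * zp_o
        q_o = np.where(q_i < zp_i, np.round(neg), q_req)
        return np.clip(q_o, qmin, qmax).astype(np.uint8)

    q_i = np.array([255, 133, 0, 9], dtype=np.uint8)
    print(qnn_leaky_relu_ref(q_i, alpha=0.9, s_i=0.125, zp_i=60, s_o=0.6, zp_o=17))
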
diff --git a/tests/python/relay/test_op_qnn_leaky_relu.py b/tests/python/relay/test_op_qnn_leaky_relu.py
index 76f581817c..ade897bf6e 100644
--- a/tests/python/relay/test_op_qnn_leaky_relu.py
+++ b/tests/python/relay/test_op_qnn_leaky_relu.py
@@ -24,26 +24,36 @@ def dequantize(data, scale, zp):
     return scale * (np.asarray(data) - zp)
 
 
-def generate_golden_output(x_data, dequantized_x, alpha, scale, zero_point):
+def generate_golden_output(x_data, dequantized_x, alpha, o_scale, o_zero_point, i_zero_point):
     prod = np.multiply(dequantized_x, alpha)
-    prod = np.around(prod / scale + zero_point)
+    prod = np.around(prod / o_scale + o_zero_point)
 
-    output = np.where(x_data < zero_point, prod, x_data)
+    q_min = np.iinfo(np.uint8).min
+    q_max = np.iinfo(np.uint8).max
+    prod = np.clip(prod, q_min, q_max)
+
+    requantized = np.clip(np.round(dequantized_x / o_scale + o_zero_point), q_min, q_max)
+
+    output = np.where(x_data < i_zero_point, prod, requantized)
     return output
 
 
 def test_qnn_leaky_relu():
     data_dtype = "uint8"
-    scale = 0.125
-    zero_point = 60
+    input_scale = 0.125
+    input_zero_point = 60
+    output_scale = 0.6
+    output_zero_point = 17
     alpha = 0.9
 
     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.qnn.op.leaky_relu(
         x=x,
         alpha=alpha,
-        scale=relay.const(scale, "float32"),
-        zero_point=relay.const(zero_point, "int32"),
+        input_scale=relay.const(input_scale, "float32"),
+        input_zero_point=relay.const(input_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
     )
 
     func = relay.Function([x], y)
@@ -53,8 +63,10 @@ def test_qnn_leaky_relu():
     func = mod["main"]
 
     x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
-    x_dequantized = dequantize(x_data, scale, zero_point)
-    golden_output = generate_golden_output(x_data, x_dequantized, alpha, scale, zero_point)
+    x_dequantized = dequantize(x_data, input_scale, input_zero_point)
+    golden_output = generate_golden_output(
+        x_data, x_dequantized, alpha, output_scale, output_zero_point, input_zero_point
+    )
 
     op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)
 
