This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e905aa4d7 [Relay] Stop ToMixedPrecision when constant is out of dtype 
range (#15461)
0e905aa4d7 is described below

commit 0e905aa4d755cbaeb71cd1fe979b91434177b256
Author: Egor Churaev <[email protected]>
AuthorDate: Thu Aug 3 08:27:34 2023 +0300

    [Relay] Stop ToMixedPrecision when constant is out of dtype range (#15461)
    
    * [Relay] Stop ToMixedPrecision when constant is out of dtype range
    
    In some layers, e.g. Clip, we might get a compilation error when the
    operation takes as input a constant that is outside the range of the
    target data type.
    
    To prevent this situation, a new method was introduced. It compares the
    values of constant attributes with the range of the target data type. If
    a value is out of range, float32 is used instead.
    
    * Fix lint
---
 src/relay/transforms/to_mixed_precision.cc    | 41 ++++++++++++++++++++--
 tests/python/relay/test_to_mixed_precision.py | 49 +++++++++++++++++++++++++++
 2 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc 
b/src/relay/transforms/to_mixed_precision.cc
index 820bc6e58e..4638ee5477 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -31,6 +31,7 @@
 
 #include <utility>
 
+#include "../../support/scalars.h"
 #include "pattern_utils.h"
 
 namespace tvm {
@@ -110,6 +111,39 @@ class MixedPrecisionPass : public MixedModeMutator {
   std::vector<DataType> original_dtype_;
   bool keep_orig_output_dtype_;
 
+  /*! \brief If some of the constant attributes are out of 
mixed_precision_type_ bounds, then
+   * computation cannot be performed in mixed precision. */
+  bool IsMixedPrecisionApplicableToAttrs(const Attrs& attrs) const {
+    if (attrs.get() != nullptr) {
+      double min_bound;
+      double max_bound;
+      if (mixed_precision_type_.is_float16()) {
+        min_bound = -support::kMaxFloat16;
+        max_bound = support::kMaxFloat16;
+      } else if (mixed_precision_type_.is_bfloat16()) {
+        min_bound = -support::kMaxBFloat16;
+        max_bound = support::kMaxBFloat16;
+      } else if (mixed_precision_type_.is_float8()) {
+        double bound = (mixed_precision_type_.code() == DataType::kE4M3Float) 
? support::kMaxE4M3
+                                                                              
: support::kMaxE5M2;
+        min_bound = -bound;
+        max_bound = bound;
+      } else if (mixed_precision_type_.is_float()) {
+        min_bound = std::numeric_limits<float>::lowest();
+        max_bound = std::numeric_limits<float>::max();
+      } else {
+        return true;
+      }
+
+      if (auto cur_attrs = attrs.as<ClipAttrs>()) {
+        if (cur_attrs->a_min < min_bound || cur_attrs->a_max > max_bound) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
   Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) 
const {
     /* If the accumulation dtype is in the attributes make a copy and mutate 
the field. */
     Attrs cur_attrs = call->attrs;
@@ -382,9 +416,12 @@ class MixedPrecisionPass : public MixedModeMutator {
           all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : 
MIXED_PRECISION_NEVER;
     }
 
+    bool is_mixed_precision_applicable =
+        static_cast<bool>(final_category == MIXED_PRECISION_ALWAYS &&
+                          
IsMixedPrecisionApplicableToAttrs(pre_call_node->attrs));
     // Create the new arguments to the call.
     DataType wanted_arg_dtypes =
-        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : 
DataType::Float(32);
+        is_mixed_precision_applicable ? mixed_precision_type_ : 
DataType::Float(32);
     auto call_args_and_types = CastAllArgs(post_call_node->args, 
cur_arg_types, wanted_arg_dtypes);
     Array<Expr> new_args = call_args_and_types.first;
     Array<Type> new_arg_types;
@@ -397,7 +434,7 @@ class MixedPrecisionPass : public MixedModeMutator {
     }
 
     // Finally create the new attributes.
-    if (final_category == MIXED_PRECISION_ALWAYS) {
+    if (is_mixed_precision_applicable) {
       Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
       Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, 
pre_call_node->span);
       if (accumulation_dtype != output_dtype) {
diff --git a/tests/python/relay/test_to_mixed_precision.py 
b/tests/python/relay/test_to_mixed_precision.py
index 771d366df0..a802eee6d6 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -537,5 +537,54 @@ def 
test_convert_follow_node_with_integer_arguments(target_precision):
     assert tvm.ir.structural_equal(expected_mod, output_mod)
 
 
+def test_clip(target_precision):
+    data = relay.var("data", shape=[1, 10], dtype="float32")
+    res = relay.clip(data, a_min=-128000, a_max=128000)
+
+    mod = tvm.IRModule.from_expr(res)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
+    }
+    output_mod = verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, 
rtol=0.01
+    )
+
+    # Create expected module
+    if target_precision == "bfloat16":
+        data = relay.cast(relay.var("data", shape=[1, 10]), target_precision)
+    res = relay.clip(data, a_min=-128000, a_max=128000)
+    expected_mod = tvm.IRModule.from_expr(res)
+    expected_mod = InferType()(expected_mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_clip_with_pre_op(target_precision):
+    data = relay.var("data", shape=[1, 10], dtype="float32")
+    const = relay.const(5, "float32")
+    res = relay.divide(data, const)
+    res = relay.clip(res, a_min=-128000, a_max=128000)
+
+    mod = tvm.IRModule.from_expr(res)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
+    }
+    output_mod = verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, 
rtol=0.01
+    )
+
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 10]), target_precision)
+    const = relay.cast(relay.const(5, "float32"), target_precision)
+    res = relay.divide(data, const)
+    if target_precision == "float16":
+        res = relay.cast(res, "float32")
+    res = relay.clip(res, a_min=-128000, a_max=128000)
+    expected_mod = tvm.IRModule.from_expr(res)
+    expected_mod = InferType()(expected_mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to