This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8569b9fe63 [Unity][AMP] Fix merging concrete type and "unknown" type (#14612)
8569b9fe63 is described below

commit 8569b9fe632b9488978bede30e169d457aafd8c9
Author: masahi <masahi...@gmail.com>
AuthorDate: Mon Apr 17 11:42:38 2023 +0900

    [Unity][AMP] Fix merging concrete type and "unknown" type (#14612)
    
    * fix merging unknown type with concrete type
    
    * add test
    
    * put early return at the beginning
---
 src/relax/transform/infer_amp_utils.cc             |  6 ++
 .../relax/test_transform_to_mixed_precision.py     | 67 ++++++++++++++++++++--
 2 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc
index 330fe9a72a..efe94c43c0 100644
--- a/src/relax/transform/infer_amp_utils.cc
+++ b/src/relax/transform/infer_amp_utils.cc
@@ -38,6 +38,12 @@ NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetStructIn
 
 NType NTypeMerge(const NType& a, const NType& b) {
   auto fcombine = [&](const String& a_str, const String& b_str) -> String {
+    if (a_str == "") {
+      return b_str;
+    } else if (b_str == "") {
+      return a_str;
+    }
+
     DataType a = DataType(String2DLDataType(a_str));
     DataType b = DataType(String2DLDataType(b_str));
     ICHECK_EQ(a.code(), b.code());
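
For intuition, a minimal Python sketch of the per-element merge rule after this change (hypothetical names, not TVM's API). The empty string is the "unknown" dtype that can appear during AMP inference; before this fix it was handed straight to the concrete-dtype path, which tries to parse it as a real DataType. The combination rule for two concrete dtypes is an assumption here, since that step falls outside the quoted hunk.

    import re

    def merge_dtype(a: str, b: str) -> str:
        # An empty string marks an "unknown" dtype; defer to the other side.
        if a == "":
            return b
        if b == "":
            return a
        # Both concrete: split "float16" into ("float", "16"), etc.
        code_a, bits_a = re.match(r"([a-z]+)(\d+)", a).groups()
        code_b, bits_b = re.match(r"([a-z]+)(\d+)", b).groups()
        # Mirrors ICHECK_EQ(a.code(), b.code()) in the hunk above.
        assert code_a == code_b
        # Assumption: the wider dtype wins (the step after the ICHECK is
        # not shown in this diff).
        return a if int(bits_a) >= int(bits_b) else b

    assert merge_dtype("", "float16") == "float16"    # the case this commit fixes
    assert merge_dtype("float16", "float32") == "float32"
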
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 6b699b5165..721cbd9d58 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -23,12 +23,15 @@ from tvm.relax.transform import ToMixedPrecision
 from tvm.script.parser import ir as I, relax as R
 
 
-def _assert_test(input, expected, expected2):
-    mod = ToMixedPrecision()(input)
-    tvm.ir.assert_structural_equal(mod, expected)
-    mod = ToMixedPrecision(out_dtype="float16")(input)
-    print(mod.script())
-    tvm.ir.assert_structural_equal(mod, expected2)
+def _assert_test(input, expected=None, expected2=None):
+    if expected:
+        mod = ToMixedPrecision()(input)
+        tvm.ir.assert_structural_equal(mod, expected)
+
+    if expected2:
+        mod = ToMixedPrecision(out_dtype="float16")(input)
+        print(mod.script())
+        tvm.ir.assert_structural_equal(mod, expected2)
 
 
 def test_conv2d():
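
With this change, either expectation may be omitted, so a test can check a single pass configuration; for example, the new test at the end of this diff checks only the out_dtype="float16" configuration:

    _assert_test(Module, expected2=Expected)  # skips the default-config check
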
@@ -841,5 +844,57 @@ def test_conv2d_bias_conv2d():
     _assert_test(Input, Expected, Expected2)
 
 
+def test_tuple_get():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w: R.Tensor((512, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((512, 1, 1), dtype="float32"),
+        ) -> R.Tensor((1, 256, 64, 64), dtype="float32"):
+            with R.dataflow():
+                conv = R.nn.conv2d(
+                    x,
+                    w,
+                    strides=[1, 1],
+                    padding=[0, 0, 1, 1],
+                )
+                bias_out = R.add(conv, bias)
+                split = R.split(bias_out, indices_or_sections=2, axis=1)
+                out = R.add(split[0], split[1])
+                R.output(out)
+            return out
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w: R.Tensor((512, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((512, 1, 1), dtype="float32"),
+        ) -> R.Tensor((1, 256, 64, 64), dtype="float32"):
+            with R.dataflow():
+                lv = R.astype(x, dtype="float16")
+                lv1 = R.astype(w, dtype="float16")
+                conv = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 1, 1],
+                    out_dtype="float16",
+                )
+                lv2 = R.astype(conv, dtype="float32")
+                bias_out = R.add(lv2, bias)
+                split = R.split(bias_out, indices_or_sections=2, axis=1)
+                lv3 = split[0]
+                lv4 = split[1]
+                out = R.add(lv3, lv4)
+                R.output(out)
+            return out
+
+    _assert_test(Module, expected2=Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
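
Why this test exercises the fix: R.split returns a tuple, and split[0] / split[1] go through TupleGetItem, whose elements can carry the empty "unknown" dtype annotation during ToMixedPrecision's inference; that annotation then has to merge with the concrete dtypes flowing into R.add. A hedged repro sketch, assuming the Module class defined in test_tuple_get above is in scope:

    import tvm
    from tvm.relax.transform import ToMixedPrecision

    # Before this commit, running the pass over a module that indexes into
    # a tuple (here the result of R.split) reached NTypeMerge with an empty
    # dtype string, which the pre-fix code tried to parse as a concrete
    # DataType.
    mod = ToMixedPrecision(out_dtype="float16")(Module)
    print(mod.script())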
