This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 8569b9fe63 [Unity][AMP] Fix merging concrete type and "unknown" type (#14612) 8569b9fe63 is described below commit 8569b9fe632b9488978bede30e169d457aafd8c9 Author: masahi <masahi...@gmail.com> AuthorDate: Mon Apr 17 11:42:38 2023 +0900 [Unity][AMP] Fix merging concrete type and "unknown" type (#14612) * fix merging unknown type with concrete type * add test * put early return at the beginning --- src/relax/transform/infer_amp_utils.cc | 6 ++ .../relax/test_transform_to_mixed_precision.py | 67 ++++++++++++++++++++-- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 330fe9a72a..efe94c43c0 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -38,6 +38,12 @@ NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetStructIn NType NTypeMerge(const NType& a, const NType& b) { auto fcombine = [&](const String& a_str, const String& b_str) -> String { + if (a_str == "") { + return b_str; + } else if (b_str == "") { + return a_str; + } + DataType a = DataType(String2DLDataType(a_str)); DataType b = DataType(String2DLDataType(b_str)); ICHECK_EQ(a.code(), b.code()); diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index 6b699b5165..721cbd9d58 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -23,12 +23,15 @@ from tvm.relax.transform import ToMixedPrecision from tvm.script.parser import ir as I, relax as R -def _assert_test(input, expected, expected2): - mod = ToMixedPrecision()(input) - tvm.ir.assert_structural_equal(mod, expected) - mod = ToMixedPrecision(out_dtype="float16")(input) - print(mod.script()) - tvm.ir.assert_structural_equal(mod, expected2) +def _assert_test(input, expected=None, expected2=None): + if expected: + mod = ToMixedPrecision()(input) + tvm.ir.assert_structural_equal(mod, expected) + + if expected2: + mod = ToMixedPrecision(out_dtype="float16")(input) + print(mod.script()) + tvm.ir.assert_structural_equal(mod, expected2) def test_conv2d(): @@ -841,5 +844,57 @@ def test_conv2d_bias_conv2d(): _assert_test(Input, Expected, Expected2) +def test_tuple_get(): + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((1, 4, 64, 64), dtype="float32"), + w: R.Tensor((512, 4, 3, 3), dtype="float32"), + bias: R.Tensor((512, 1, 1), dtype="float32"), + ) -> R.Tensor((1, 256, 64, 64), dtype="float32"): + with R.dataflow(): + conv = R.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 1, 1], + ) + bias_out = R.add(conv, bias) + split = R.split(bias_out, indices_or_sections=2, axis=1) + out = R.add(split[0], split[1]) + R.output(out) + return out + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 4, 64, 64), dtype="float32"), + w: R.Tensor((512, 4, 3, 3), dtype="float32"), + bias: R.Tensor((512, 1, 1), dtype="float32"), + ) -> R.Tensor((1, 256, 64, 64), dtype="float32"): + with R.dataflow(): + lv = R.astype(x, dtype="float16") + lv1 = R.astype(w, dtype="float16") + conv = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 1, 1], + out_dtype="float16", + ) + lv2 = R.astype(conv, dtype="float32") + bias_out = R.add(lv2, bias) + split = R.split(bias_out, indices_or_sections=2, axis=1) + lv3 = split[0] + lv4 = split[1] + out = R.add(lv3, lv4) + R.output(out) + return out + + _assert_test(Module, expected2=Expected) + + if __name__ == "__main__": tvm.testing.main()