This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9c36056a12 [Unity] Allow modifying function signature by AMP to accept fp16 inputs  (#14719)
9c36056a12 is described below

commit 9c36056a12c0588540671b8771ad681d4bfb6618
Author: masahi <masahi...@gmail.com>
AuthorDate: Fri Apr 28 00:43:18 2023 +0900

    [Unity] Allow modifying function signature by AMP to accept fp16 inputs  (#14719)
    
    * Modify func sig to accept fp16 inputs in AMP
    
    * add test
    
    * add doc
    
    * fix
    
    * cpplint
---
 include/tvm/relax/transform.h                      |  5 ++-
 python/tvm/relax/transform/transform.py            | 10 ++++-
 src/relax/transform/to_mixed_precision.cc          | 40 +++++++++++++++----
 .../relax/test_transform_to_mixed_precision.py     | 46 ++++++++++++++++++++++
 4 files changed, 91 insertions(+), 10 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 27bd1bd702..9c3a763d69 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -479,9 +479,12 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String> 
entry_functions);
  * \brief Automatic mixed precision pass. Currently the pass assumes the input 
module to be fp32
  * only, and will automatically cast fp32 to fp16 for certain ops.
  * \param out_dtype The output data type of gemm/conv, which is the data type 
of the accumulator.
+ * \param fp16_input_names The names of function parameters whose dtype should 
become fp16. The
+ * function signature would change accordingly.
  * \return The Pass.
  */
-TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);
+TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
+                              Optional<Array<String>> fp16_input_names = 
NullOpt);
 
 /*!
  * \brief Rewrite a Relax module for executing with CUDA graph. This pass 
identifies
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 870b731883..46f908c448 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -914,19 +914,25 @@ def DeadCodeElimination(entry_functions: 
Optional[List[str]] = None) -> tvm.ir.t
     return _ffi_api.DeadCodeElimination(entry_functions)  # type: ignore
 
 
-def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
+def ToMixedPrecision(
+    out_dtype="float32", fp16_input_names: Optional[List[str]] = None
+) -> tvm.ir.transform.Pass:
     """Automatic mixed precision pass. Currently the pass assumes the input 
module to be fp32
     only, and will automatically cast fp32 to fp16 for certain ops.
     Parameters
     ----------
     out_dtype : str
         The output data type of gemm/conv, which is the data type of the 
accumulator.
+    fp16_input_names : List[str]
+        The names of function parameters whose dtype should become fp16. The  
function signature
+        would change accordingly.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for mixed precision.
     """
-    return _ffi_api.ToMixedPrecision(out_dtype)  # type: ignore
+    return _ffi_api.ToMixedPrecision(out_dtype, fp16_input_names)  # type: 
ignore
 
 
 def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass:
diff --git a/src/relax/transform/to_mixed_precision.cc 
b/src/relax/transform/to_mixed_precision.cc
index a04d5dbd3a..64763276d0 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -27,6 +27,7 @@
 
 #include <array>
 #include <cstdint>
+#include <unordered_set>
 
 #include "../op/nn/convolution.h"
 #include "../op/tensor/datatype.h"
@@ -273,13 +274,30 @@ class DTypeDecisionCollector : public ExprVisitor {
 
 class ToMixedPrecisionRewriter : public ExprMutator {
  public:
-  explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType 
output_dtype)
-      : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype) {}
+  explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType 
output_dtype,
+                                    const std::unordered_set<std::string>& 
fp16_input_names)
+      : only_fp16_map_(only_fp16_map),
+        output_dtype_(output_dtype),
+        fp16_input_names_(fp16_input_names) {}
 
  private:
   Var GetRemapped(const Var& var) {
     auto it = var_remap_.find(var->vid);
-    return it == var_remap_.end() ? var : it->second;
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      if (fp16_input_names_.count(var->name_hint())) {
+        auto sinfo = GetStructInfo(var);
+        if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
+          TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), 
DataType::Float(16),
+                                      tensor_sinfo->span);
+          Var fp16_var(var->vid, fp16_sinfo, var->span);
+          var_remap_[var->vid] = fp16_var;
+          return fp16_var;
+        }
+      }
+      return var;
+    }
   }
 
   Array<Expr> RemapArgs(const Array<Expr>& args) {
@@ -427,6 +445,8 @@ class ToMixedPrecisionRewriter : public ExprMutator {
     return VisitVar_(GetRef<Var>(op));
   }
 
+  Var VisitVarDef(const Var& var) { return GetRemapped(var); }
+
   Expr VisitExpr_(const DataflowVarNode* op) final {
     if (!builder_->CurrentBlockIsDataFlow()) {
       return ExprMutator::VisitExpr_(op);
@@ -561,22 +581,28 @@ class ToMixedPrecisionRewriter : public ExprMutator {
   DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1);
   DataType output_dtype_;
   Array<Var> params_;
+  std::unordered_set<std::string> fp16_input_names_;
 
   const Op& wrap_param_op = Op::Get("relax.wrap_param");
 };
 
-Expr ToMixedPrecision(const Function& f, const DataType& out_dtype) {
+Expr ToMixedPrecision(const Function& f, const DataType& out_dtype,
+                      Optional<Array<String>> fp16_input_names) {
   VarDTypeMap only_fp16_map = std::move(DTypeDecisionCollector::Collect(f, 
out_dtype));
-  ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype);
+  std::unordered_set<std::string> fp16_input_names_set;
+  if (fp16_input_names) {
+    fp16_input_names_set.insert(fp16_input_names.value().begin(), 
fp16_input_names.value().end());
+  }
+  ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype, 
fp16_input_names_set);
   return mutator(f);
 }
 
 namespace transform {
 
-Pass ToMixedPrecision(const DataType& out_dtype) {
+Pass ToMixedPrecision(const DataType& out_dtype, Optional<Array<String>> 
fp16_input_names) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(ToMixedPrecision(f, out_dtype));
+        return Downcast<Function>(ToMixedPrecision(f, out_dtype, 
fp16_input_names));
       };
   return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
 }
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py 
b/tests/python/relax/test_transform_to_mixed_precision.py
index 6bae732927..cb179a8c25 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -990,5 +990,51 @@ def test_conv2d_bias_fp32():
     _assert_test(Input_bound, expected2=Expected_no_bias_cast)
 
 
+def test_convert_sig():
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w: R.Tensor((512, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((512,), dtype="float32"),
+        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+                    x,
+                    w,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    out_dtype="float32",
+                )
+                lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = 
R.reshape(bias, (1, 512, 1, 1))
+                lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = 
R.add(lv142, lv143)
+                R.output(lv144)
+            return lv144
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w: R.Tensor((512, 4, 3, 3), dtype="float16"),
+            bias: R.Tensor((512,), dtype="float16"),
+        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+            with R.dataflow():
+                lv = R.astype(x, dtype="float16")
+                lv142 = R.nn.conv2d(
+                    lv, w, strides=[1, 1], padding=[0, 0, 0, 0], 
out_dtype="float16"
+                )
+                lv143 = R.reshape(bias, R.shape([1, 512, 1, 1]))
+                lv1 = R.add(lv142, lv143)
+                lv144 = R.astype(lv1, dtype="float32")
+                R.output(lv144)
+            return lv144
+
+    mod = ToMixedPrecision(out_dtype="float16", fp16_input_names=["w", 
"bias"])(Input)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to