This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 9c36056a12  [Unity] Allow modifying function signature by AMP to accept fp16 inputs (#14719)
9c36056a12 is described below

commit 9c36056a12c0588540671b8771ad681d4bfb6618
Author: masahi <masahi...@gmail.com>
AuthorDate: Fri Apr 28 00:43:18 2023 +0900

    [Unity] Allow modifying function signature by AMP to accept fp16 inputs (#14719)

    * Modify func sig to accept fp16 inputs in AMP

    * add test

    * add doc

    * fix

    * cpplint
---
 include/tvm/relax/transform.h                  |  5 ++-
 python/tvm/relax/transform/transform.py        | 10 ++++-
 src/relax/transform/to_mixed_precision.cc      | 40 +++++++++++++++----
 .../relax/test_transform_to_mixed_precision.py | 46 ++++++++++++++++++++++
 4 files changed, 91 insertions(+), 10 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 27bd1bd702..9c3a763d69 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -479,9 +479,12 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
  * \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
  * only, and will automatically cast fp32 to fp16 for certain ops.
  * \param out_dtype The output data type of gemm/conv, which is the data type of the accumulator.
+ * \param fp16_input_names The names of function parameters whose dtype should become fp16. The
+ * function signature would change accordingly.
  * \return The Pass.
  */
-TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);
+TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
+                              Optional<Array<String>> fp16_input_names = NullOpt);
 
 /*!
  * \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 870b731883..46f908c448 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -914,19 +914,25 @@ def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.t
     return _ffi_api.DeadCodeElimination(entry_functions)  # type: ignore
 
 
-def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
+def ToMixedPrecision(
+    out_dtype="float32", fp16_input_names: Optional[List[str]] = None
+) -> tvm.ir.transform.Pass:
     """Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
     only, and will automatically cast fp32 to fp16 for certain ops.
 
     Parameters
     ----------
     out_dtype : str
         The output data type of gemm/conv, which is the data type of the accumulator.
 
+    fp16_input_names : List[str]
+        The names of function parameters whose dtype should become fp16. The function signature
+        would change accordingly.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for mixed precision.
""" - return _ffi_api.ToMixedPrecision(out_dtype) # type: ignore + return _ffi_api.ToMixedPrecision(out_dtype, fp16_input_names) # type: ignore def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass: diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index a04d5dbd3a..64763276d0 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -27,6 +27,7 @@ #include <array> #include <cstdint> +#include <unordered_set> #include "../op/nn/convolution.h" #include "../op/tensor/datatype.h" @@ -273,13 +274,30 @@ class DTypeDecisionCollector : public ExprVisitor { class ToMixedPrecisionRewriter : public ExprMutator { public: - explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype) - : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype) {} + explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype, + const std::unordered_set<std::string>& fp16_input_names) + : only_fp16_map_(only_fp16_map), + output_dtype_(output_dtype), + fp16_input_names_(fp16_input_names) {} private: Var GetRemapped(const Var& var) { auto it = var_remap_.find(var->vid); - return it == var_remap_.end() ? var : it->second; + if (it != var_remap_.end()) { + return it->second; + } else { + if (fp16_input_names_.count(var->name_hint())) { + auto sinfo = GetStructInfo(var); + if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) { + TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), DataType::Float(16), + tensor_sinfo->span); + Var fp16_var(var->vid, fp16_sinfo, var->span); + var_remap_[var->vid] = fp16_var; + return fp16_var; + } + } + return var; + } } Array<Expr> RemapArgs(const Array<Expr>& args) { @@ -427,6 +445,8 @@ class ToMixedPrecisionRewriter : public ExprMutator { return VisitVar_(GetRef<Var>(op)); } + Var VisitVarDef(const Var& var) { return GetRemapped(var); } + Expr VisitExpr_(const DataflowVarNode* op) final { if (!builder_->CurrentBlockIsDataFlow()) { return ExprMutator::VisitExpr_(op); @@ -561,22 +581,28 @@ class ToMixedPrecisionRewriter : public ExprMutator { DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); DataType output_dtype_; Array<Var> params_; + std::unordered_set<std::string> fp16_input_names_; const Op& wrap_param_op = Op::Get("relax.wrap_param"); }; -Expr ToMixedPrecision(const Function& f, const DataType& out_dtype) { +Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, + Optional<Array<String>> fp16_input_names) { VarDTypeMap only_fp16_map = std::move(DTypeDecisionCollector::Collect(f, out_dtype)); - ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype); + std::unordered_set<std::string> fp16_input_names_set; + if (fp16_input_names) { + fp16_input_names_set.insert(fp16_input_names.value().begin(), fp16_input_names.value().end()); + } + ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype, fp16_input_names_set); return mutator(f); } namespace transform { -Pass ToMixedPrecision(const DataType& out_dtype) { +Pass ToMixedPrecision(const DataType& out_dtype, Optional<Array<String>> fp16_input_names) { runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast<Function>(ToMixedPrecision(f, out_dtype)); + return Downcast<Function>(ToMixedPrecision(f, out_dtype, fp16_input_names)); }; return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); } diff --git 
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 6bae732927..cb179a8c25 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -990,5 +990,51 @@ def test_conv2d_bias_fp32():
     _assert_test(Input_bound, expected2=Expected_no_bias_cast)
 
 
+def test_convert_sig():
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w: R.Tensor((512, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((512,), dtype="float32"),
+        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+                    x,
+                    w,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    out_dtype="float32",
+                )
+                lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1))
+                lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = R.add(lv142, lv143)
+                R.output(lv144)
+            return lv144
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w: R.Tensor((512, 4, 3, 3), dtype="float16"),
+            bias: R.Tensor((512,), dtype="float16"),
+        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+            with R.dataflow():
+                lv = R.astype(x, dtype="float16")
+                lv142 = R.nn.conv2d(
+                    lv, w, strides=[1, 1], padding=[0, 0, 0, 0], out_dtype="float16"
+                )
+                lv143 = R.reshape(bias, R.shape([1, 512, 1, 1]))
+                lv1 = R.add(lv142, lv143)
+                lv144 = R.astype(lv1, dtype="float32")
+                R.output(lv144)
+            return lv144
+
+    mod = ToMixedPrecision(out_dtype="float16", fp16_input_names=["w", "bias"])(Input)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
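
For reference, a minimal usage sketch of the option added by this commit. It assumes a TVM build that includes this patch (unity branch); the single-matmul module and the names MatmulModule/main are illustrative, not taken from the commit, and matmul is one of the gemm ops the pass rewrites:

import tvm
from tvm.relax.transform import ToMixedPrecision
from tvm.script import relax as R


# Illustrative fp32 module; any fp32 Relax module with AMP-eligible ops works.
@tvm.script.ir_module
class MatmulModule:
    @R.function
    def main(
        x: R.Tensor((8, 16), dtype="float32"),
        w: R.Tensor((16, 32), dtype="float32"),
    ) -> R.Tensor((8, 32), dtype="float32"):
        with R.dataflow():
            y = R.matmul(x, w)
            R.output(y)
        return y


# "w" is accepted directly as fp16 in the rewritten signature, so fp16 weights
# can be fed without an extra cast; "x" is not listed, so it keeps its fp32
# dtype and the pass casts it to fp16 inside the body instead (compare the
# Expected module in test_convert_sig above).
mod = ToMixedPrecision(out_dtype="float16", fp16_input_names=["w"])(MatmulModule)
mod.show()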