This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 85976ea1a9 [Unity][Frontend] FX translator returning weights with 
`keep_params_as_input` (#14197)
85976ea1a9 is described below

commit 85976ea1a9b5b237215be9f26cc143038778d095
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 22:59:53 2023 -0500

    [Unity][Frontend] FX translator returning weights with 
`keep_params_as_input` (#14197)
    
    PR #14067 introduces the flag `keep_params_as_input` to the FX
    translator, for the purpose of handling the model weights outside of the
    translated Relax function.
    
    This PR takes a further step, by returning the model weights as
    NDArrays when the flag `keep_params_as_input` is true. With this PR, the
    translator can now return the weights upon request. Otherwise,
    after the import we would lose the model weights in the given PyTorch
    model.
---
 python/tvm/relax/frontend/__init__.py            |  2 +
 python/tvm/relax/frontend/common.py              | 48 ++++++++++++++++++++++++
 python/tvm/relax/frontend/torch/dynamo.py        | 28 ++++++++++----
 python/tvm/relax/frontend/torch/fx_translator.py | 22 +++++++----
 tests/python/relax/test_frontend_dynamo.py       |  4 +-
 tests/python/relax/test_frontend_from_fx.py      | 19 ++++++++--
 6 files changed, 103 insertions(+), 20 deletions(-)

diff --git a/python/tvm/relax/frontend/__init__.py 
b/python/tvm/relax/frontend/__init__.py
index 6c9c188aaa..f3c0ed23eb 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/python/tvm/relax/frontend/__init__.py
@@ -17,3 +17,5 @@
 """
 Frontends for constructing Relax programs, with the model importers
 """
+from . import torch
+from .common import ImporterOutput
diff --git a/python/tvm/relax/frontend/common.py 
b/python/tvm/relax/frontend/common.py
new file mode 100644
index 0000000000..cdb88cd12c
--- /dev/null
+++ b/python/tvm/relax/frontend/common.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Commons for Relax frontend."""
+from typing import Dict, List, Optional
+
+import tvm
+
+
+class ImporterOutput:
+    """The data structure representing the result of frontend imports.
+
+    Attributes
+    ----------
+    mod : tvm.IRModule
+        The IRModule imported from frontend.
+
+    params : Optional[Dict[str, List[tvm.nd.NDArray]]]
+        The weights of the imported model, when the weights of the model are
+        requested to be kept as parameters of functions in the IRModule. (e.g.,
+        when the `keep_params_as_input` flag of `frontend.torch.from_fx` is 
set to
+        True.)
+        - `params` is defined to be None when not requested.
+        - The keys of `params` are the names of the Relax functions in the 
IRModule.
+        - Each weight tensor is in the form of TVM NDArray on device CPU.
+        - The order of the returned weights is in accordance with the order of
+        the kept Relax function input variables.
+    """
+
+    mod: tvm.IRModule
+    params: Optional[Dict[str, List[tvm.nd.NDArray]]]
+
+    def __init__(self, mod: tvm.IRModule, params: Optional[Dict[str, 
List[tvm.nd.NDArray]]]):
+        self.mod = mod
+        self.params = params
diff --git a/python/tvm/relax/frontend/torch/dynamo.py 
b/python/tvm/relax/frontend/torch/dynamo.py
index 589c6be3b5..3f30044bb8 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -24,7 +24,9 @@ from typing import Optional
 
 import tvm
 from tvm.relax import build as relax_build
-from tvm.relax.frontend.torch.fx_translator import from_fx
+
+from .fx_translator import from_fx
+from ..common import ImporterOutput
 
 
 def device_from_inputs(example_inputs):
@@ -72,7 +74,7 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = 
None):
 
         device = device_from_inputs(example_inputs)
         input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in 
example_inputs]
-        mod = from_fx(graph_module, input_info)
+        mod = from_fx(graph_module, input_info).mod
 
         if device.type == "cuda":
             dev = tvm.cuda(device.index)
@@ -114,7 +116,7 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = 
None):
     return _relax_backend
 
 
-def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule:
+def dynamo_capture_subgraphs(model, *params, **kwargs) -> ImporterOutput:
     """Capture subgraphs of the PyTorch model using torch.compile into an 
IRModule.
 
     Parameters
@@ -125,28 +127,38 @@ def dynamo_capture_subgraphs(model, *params) -> 
tvm.ir.IRModule:
     params : List[torch.Tensor]
         The parameters of the PyTorch model.
 
+    keep_params_as_input : bool
+        Whether to keep model parameters as input variables of the captured 
Relax functions.
+
     Returns
     -------
-    mod : tvm.ir.IRModule
-        The IRModule that contains captured subgraphs.
+    output : ImporterOutput
+        The output of translation, including the translated IRModule, and
+        the weights of the input model when `keep_params_as_input` is true.
     """
     import torch  # type: ignore[import]
     from torch import fx  # type: ignore[import]
     from torch import _dynamo as dynamo  # type: ignore[import]
 
+    keep_params_as_input = "keep_params_as_input" in kwargs and 
kwargs["keep_params_as_input"]
+
     mod = tvm.IRModule()
+    params_ndarray = dict() if keep_params_as_input else None
 
     def _capture(graph_module: fx.GraphModule, example_inputs):
         assert isinstance(graph_module, torch.fx.GraphModule)
         input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in 
example_inputs]
-        subgraph = from_fx(graph_module, input_info)
-        mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"]
+        trace_output = from_fx(graph_module, input_info, keep_params_as_input)
+        func_name = f"subgraph_{len(mod.get_global_vars())}"
+        mod[func_name] = trace_output.mod["main"]
+        if keep_params_as_input:
+            params_ndarray[func_name] = trace_output.params["main"]
         return graph_module.forward
 
     dynamo.reset()
     compiled_model = torch.compile(model, backend=_capture)
     compiled_model(*params)
-    return mod
+    return ImporterOutput(mod, params_ndarray)
 
 
 @functools.lru_cache(None)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index b580e1679b..a73bc9d0db 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -24,6 +24,8 @@ from functools import reduce
 import tvm
 from tvm import relax
 
+from ..common import ImporterOutput
+
 
 class TorchFXImporter:
     """An importer from PyTorch FX to Relax."""
@@ -843,7 +845,7 @@ class TorchFXImporter:
 
     def from_fx(
         self, model, input_info: List[Tuple[Tuple[int], str]], 
keep_params_as_input: bool
-    ) -> tvm.IRModule:
+    ) -> ImporterOutput:
         """Convert a PyTorch FX GraphModule to a Relax program."""
         from torch import fx
 
@@ -860,18 +862,23 @@ class TorchFXImporter:
             )
 
         # Initialize the block builder with a function and a dataflow block.
+        func_name = "main"
         self.block_builder = relax.BlockBuilder()
         if keep_params_as_input:
+            params_ = []
             func_attrs = {"num_input": len(inputs)}
             for name, param in model.named_parameters():
                 shape = param.data.shape
                 dtype = self._convert_data_type(str(param.data.dtype))
                 inputs.append(relax.Var(name, relax.TensorStructInfo(shape, 
dtype)))
                 self.params[param] = inputs[-1]
+                params_.append(tvm.nd.array(param.data.cpu().numpy()))
+            params = {func_name: params_}
         else:
+            params = None
             func_attrs = None
 
-        with self.block_builder.function(name="main", params=inputs.copy(), 
attrs=func_attrs):
+        with self.block_builder.function(name=func_name, params=inputs.copy(), 
attrs=func_attrs):
             output = None
             with self.block_builder.dataflow():
                 # Translate model parameters.
@@ -916,12 +923,12 @@ class TorchFXImporter:
             assert output is not None
             self.block_builder.emit_func_output(output)
 
-        return self.block_builder.get()
+        return ImporterOutput(self.block_builder.get(), params)
 
 
 def from_fx(
     model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: 
bool = False
-) -> tvm.IRModule:
+) -> ImporterOutput:
     """Convert a PyTorch FX GraphModule to a Relax program
 
     Parameters
@@ -937,8 +944,9 @@ def from_fx(
 
     Returns
     -------
-    module : tvm.IRModule
-        The converted Relax program.
+    output : ImporterOutput
+        The output of translation, including the translated IRModule, and
+        the weights of the input model when `keep_params_as_input` is true.
 
     Examples
     --------
@@ -981,7 +989,7 @@ def from_fx(
             raise RuntimeError("Failed to export the PyTorch model to FX.")
 
         # Use the importer to import the PyTorch model to Relax.
-        mod: tvm.IRModule = from_fx(graph_module, input_info)
+        mod: tvm.IRModule = from_fx(graph_module, input_info).mod
 
         # Print out the imported model.
         print(mod.script())
diff --git a/tests/python/relax/test_frontend_dynamo.py 
b/tests/python/relax/test_frontend_dynamo.py
index b47e3e22bd..14d1e48fb5 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -147,7 +147,7 @@ def test_subgraph_capture():
             return gv
 
     model = Input1()
-    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
+    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)).mod
     binding = {"w0": model.lin.weight.detach().numpy(), "w1": 
model.lin.bias.detach().numpy()}
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
@@ -190,7 +190,7 @@ def test_subgraph_capture():
                 R.output(gv1)
             return gv1
 
-    mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
+    mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)).mod
     tvm.ir.assert_structural_equal(mod, Expected2)
 
 
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 9ab0b3304c..e28483dc2f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -27,7 +27,7 @@ def verify_model(torch_model, input_info, binding, expected, 
keep_params_as_inpu
     from tvm.relax.frontend.torch import from_fx
 
     graph_model = fx.symbolic_trace(torch_model)
-    mod = from_fx(graph_model, input_info, 
keep_params_as_input=keep_params_as_input)
+    mod = from_fx(graph_model, input_info, 
keep_params_as_input=keep_params_as_input).mod
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("main", binding)(expected)
     tvm.ir.assert_structural_equal(mod, expected)
@@ -2096,7 +2096,9 @@ def test_view():
 @tvm.testing.requires_gpu
 def test_keep_params():
     import torch
+    from torch import fx
     from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
 
     class Conv2D1(Module):
         def __init__(self):
@@ -2135,8 +2137,19 @@ def test_keep_params():
             return gv
 
     model = Conv2D1()
-    input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(model, input_info, {}, expected1, keep_params_as_input=True)
+    graph_model = fx.symbolic_trace(model)
+    trace_output = from_fx(graph_model, [([1, 3, 10, 10], "float32")], 
keep_params_as_input=True)
+    tvm.ir.assert_structural_equal(trace_output.mod, expected1)
+    func = trace_output.mod["main"]
+    params = trace_output.params["main"]
+
+    assert len(params) == len(func.params) - 1
+    for param_var, param_ndarray in zip(func.params[1:], params):
+        assert tuple(x.value for x in param_var.struct_info.shape.values) == 
param_ndarray.shape
+        assert param_var.struct_info.dtype == param_ndarray.dtype
+
+    tvm.testing.assert_allclose(params[0].numpy(), 
model.conv.weight.detach().numpy())
+    tvm.testing.assert_allclose(params[1].numpy(), 
model.conv.bias.detach().numpy())
 
 
 @tvm.testing.requires_gpu

Reply via email to