This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 001d5ec90c [Relax][PyTorch][Docs] Use `torch.export` instead of
`fx.symbolic_trace` for tutorial (#17436)
001d5ec90c is described below
commit 001d5ec90c2821b16f9d4edd913dfeff03c027a3
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Oct 8 09:57:27 2024 +0900
[Relax][PyTorch][Docs] Use `torch.export` instead of `fx.symbolic_trace`
for tutorial (#17436)
* use torch.export
* in order to make interface consistent, user inputs should be placed first
* chore
---
docs/get_started/tutorials/ir_module.py | 15 +++--
docs/how_to/tutorials/e2e_opt_model.py | 18 +++---
.../frontend/torch/exported_program_translator.py | 71 +++++++++++-----------
.../relax/test_frontend_from_exported_program.py | 4 +-
4 files changed, 56 insertions(+), 52 deletions(-)
diff --git a/docs/get_started/tutorials/ir_module.py
b/docs/get_started/tutorials/ir_module.py
index f813333baf..0a825c3da7 100644
--- a/docs/get_started/tutorials/ir_module.py
+++ b/docs/get_started/tutorials/ir_module.py
@@ -40,8 +40,9 @@ from tvm import relax
# below.
import torch
-from torch import fx, nn
-from tvm.relax.frontend.torch import from_fx
+from torch import nn
+from torch.export import export
+from tvm.relax.frontend.torch import from_exported_program
######################################################################
# Import from existing models
@@ -67,13 +68,15 @@ class TorchModel(nn.Module):
return x
-# Give the input shape and data type
-input_info = [((1, 784), "float32")]
+# Give an example argument to torch.export
+example_args = (torch.randn(1, 784, dtype=torch.float32),)
# Convert the model to IRModule
with torch.no_grad():
- torch_fx_model = fx.symbolic_trace(TorchModel())
- mod_from_torch = from_fx(torch_fx_model, input_info,
keep_params_as_input=True)
+ exported_program = export(TorchModel().eval(), example_args)
+ mod_from_torch = from_exported_program(
+ exported_program, keep_params_as_input=True,
unwrap_unit_return_tuple=True
+ )
mod_from_torch, params_from_torch =
relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
diff --git a/docs/how_to/tutorials/e2e_opt_model.py
b/docs/how_to/tutorials/e2e_opt_model.py
index 5c11439e16..532fb89fd3 100644
--- a/docs/how_to/tutorials/e2e_opt_model.py
+++ b/docs/how_to/tutorials/e2e_opt_model.py
@@ -34,10 +34,10 @@ Please note that default end-to-end optimization may not
suit complex models.
import os
import numpy as np
import torch
-from torch import fx
+from torch.export import export
from torchvision.models.resnet import ResNet18_Weights, resnet18
-torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
+torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
######################################################################
# Review Overall Flow
@@ -63,21 +63,19 @@ torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
# Convert the model to IRModule
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Next step, we convert the model to an IRModule using the Relax frontend for
PyTorch for further
-# optimization. Besides the model, we also need to provide the input shape and
data type.
+# optimization.
import tvm
from tvm import relax
-from tvm.relax.frontend.torch import from_fx
+from tvm.relax.frontend.torch import from_exported_program
-torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
-
-# Give the input shape and data type
-input_info = [((1, 3, 224, 224), "float32")]
+# Give an example argument to torch.export
+example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
# Convert the model to IRModule
with torch.no_grad():
- torch_fx_model = fx.symbolic_trace(torch_model)
- mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
+ exported_program = export(torch_model, example_args)
+ mod = from_exported_program(exported_program, keep_params_as_input=True)
mod, params = relax.frontend.detach_params(mod)
mod.show()
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1401a0bcef..7bcd20c462 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
from torch import fx
- def create_input_vars(
- self, exported_program: torch.export.ExportedProgram
- ) -> Tuple[List[relax.Var], List[relax.Var]]:
- """Create relax input vars."""
- parameters_buffers_constants = []
- user_inputs = []
- for spec in exported_program.graph_signature.input_specs:
- name_hint = spec.arg.name
- if spec.kind is
torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
- shape = exported_program.tensor_constants[spec.target].shape
- torch_dtype =
exported_program.tensor_constants[spec.target].dtype
- elif spec.kind is
torch.export.graph_signature.InputKind.USER_INPUT:
- for node in
exported_program.graph.find_nodes(op="placeholder", target=spec.target):
- if node.name == name_hint:
- shape = node.meta["tensor_meta"].shape
- torch_dtype = node.meta["tensor_meta"].dtype
- break
- else:
- # PARAMETER or BUFFER
- shape = exported_program.state_dict[spec.target].shape
- torch_dtype = exported_program.state_dict[spec.target].dtype
-
- dtype = self._convert_data_type(torch_dtype)
- relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape,
dtype))
- if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
- user_inputs.append(relax_var)
- else:
- parameters_buffers_constants.append(relax_var)
-
- return parameters_buffers_constants, user_inputs
-
########## Unary Ops ##########
def _hardtanh(self, node: fx.Node) -> relax.Expr:
@@ -178,6 +147,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
stride = [node.args[4] if len(node.args) > 4 else 1]
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin,
end, stride))
+ ########## Others ##########
+
def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -293,6 +264,37 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"getitem": self._getitem,
}
+ def create_input_vars(
+ self, exported_program: torch.export.ExportedProgram
+ ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+ """Create relax input vars."""
+ parameters_buffers_constants = OrderedDict()
+ user_inputs = OrderedDict()
+ for spec in exported_program.graph_signature.input_specs:
+ name_hint = spec.arg.name
+ if spec.kind is
torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
+ shape = exported_program.tensor_constants[spec.target].shape
+ torch_dtype =
exported_program.tensor_constants[spec.target].dtype
+ elif spec.kind is
torch.export.graph_signature.InputKind.USER_INPUT:
+ for node in
exported_program.graph.find_nodes(op="placeholder", target=spec.target):
+ if node.name == name_hint:
+ shape = node.meta["tensor_meta"].shape
+ torch_dtype = node.meta["tensor_meta"].dtype
+ break
+ else:
+ # PARAMETER or BUFFER
+ shape = exported_program.state_dict[spec.target].shape
+ torch_dtype = exported_program.state_dict[spec.target].dtype
+
+ dtype = self._convert_data_type(torch_dtype)
+ relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape,
dtype))
+ if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+ user_inputs[name_hint] = relax_var
+ else:
+ parameters_buffers_constants[name_hint] = relax_var
+
+ return parameters_buffers_constants, user_inputs
+
def from_exported_program(
self,
exported_program: torch.export.ExportedProgram,
@@ -305,7 +307,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# Create input variables.
parameter_buffer_constant_vars, user_input_vars =
self.create_input_vars(exported_program)
- inputs_vars = parameter_buffer_constant_vars + user_input_vars
+ inputs_vars = user_input_vars.copy()
+ inputs_vars.update(parameter_buffer_constant_vars)
# Initialize the block builder with a function and a dataflow block.
self.block_builder = relax.BlockBuilder()
@@ -314,7 +317,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
nodes: List[fx.Node] = exported_program.graph.nodes
with self.block_builder.function(
- name=func_name, params=inputs_vars.copy(), attrs=func_attrs
+ name=func_name, params=list(inputs_vars.values()).copy(),
attrs=func_attrs
):
output = None
with self.block_builder.dataflow():
@@ -325,7 +328,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# Ignore sym input
continue
- self.env[node] = inputs_vars.pop(0)
+ self.env[node] = inputs_vars[node.name]
elif node.op == "output":
args = self.retrieve_args(node)
assert len(args) == 1
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 65890ff697..0d8425fc7f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3550,9 +3550,9 @@ def test_keep_params():
class expected1:
@R.function
def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
conv_bias: R.Tensor((6,), dtype="float32"),
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
R.func_attr({"num_input": 1})
# block 0
@@ -3586,7 +3586,7 @@ def test_keep_params():
params = params["main"]
assert len(params) == len(func.params) - 1
- for param_var, param_ndarray in zip(func.params[:-1], params):
+ for param_var, param_ndarray in zip(func.params[1:], params):
assert tuple(x.value for x in param_var.struct_info.shape.values) ==
param_ndarray.shape
assert param_var.struct_info.dtype == param_ndarray.dtype