This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 37fe645f94 [Relax] Ingest Tensor.clamp from torch export (#17725)
37fe645f94 is described below

commit 37fe645f945d56e0d4bfac1d9f3bcf355f950a1b
Author: Hugo Latendresse <[email protected]>
AuthorDate: Tue Mar 11 20:27:53 2025 -0400

    [Relax] Ingest Tensor.clamp from torch export (#17725)
    
    Allow handling of torch.clamp when only min is passed, only max is
    passed, or when tensors are passed as the min/max arguments.
---
 .../frontend/torch/base_fx_graph_translator.py     | 96 +++++++++++++++++++---
 .../frontend/torch/exported_program_translator.py  |  3 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  3 +-
 tests/python/relax/test_from_exported_to_cuda.py   | 81 ++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   | 62 +++++++++++++-
 tests/python/relax/test_frontend_from_fx.py        | 64 ++++++++++-----
 6 files changed, 276 insertions(+), 33 deletions(-)
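
For context, here is a minimal usage sketch of what this change is meant to
support (module and tensor names are illustrative, not taken from the committed
tests; from_exported_program is the importer entry point exercised by the new
tests below):

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class ClampMinOnly(torch.nn.Module):
        def forward(self, x):
            # only min is given; the importer now treats the missing max as +inf
            return torch.clamp(x, min=0.5, max=None)

    ep = torch.export.export(ClampMinOnly().eval(), (torch.randn(1, 3, 10, 10),))
    mod = from_exported_program(ep)  # lowers the clamp to R.clip(x, 0.5, inf)

When min or max is itself a tensor, the importer broadcasts it to the input's
shape and applies relax.op.maximum / relax.op.minimum before emitting the final
relax.op.clip, as shown in the diff below.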

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a0f00e1f4b..6bbc9d5de6 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """Base class for PyTorch FX Graph importer."""
 import abc
+import math
 from typing import Callable, Dict, Optional, Tuple, Union
 
 from tvm import relax
@@ -141,19 +142,94 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     def _clamp(self, node: fx.Node) -> relax.Expr:
         args = self.retrieve_args(node)
-        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
-        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+        x = args[0]
+        a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+        a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)
+
+        a_min = -math.inf if a_min is None else a_min
+        a_max = math.inf if a_max is None else a_max
+
+        # Handle the case where a_min is a tensor
         if not isinstance(a_min, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant min value for torch.clamp/clip, "
-                f"but got {a_min} with type {type(a_min)}"
+            from torch import fx
+
+            if isinstance(a_min, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_min = self.env[a_min]
+            assert isinstance(a_min, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
             )
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.maximum(x, a_min))
+            a_min = -math.inf
+
+        # Handle the case where a_max is a tensor
         if not isinstance(a_max, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant max value for torch.clamp/clip, "
-                f"but got {a_max} with type {type(a_max)}"
+            from torch import fx
+
+            if isinstance(a_max, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_max = self.env[a_max]
+            assert isinstance(a_max, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+            )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.minimum(x, a_max))
+            a_max = math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+    def _clamp_min(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+        a_max = math.inf
+
+        a_min = -math.inf if a_min is None else a_min
+
+        # Handle the case where a_min is a tensor
+        if not isinstance(a_min, (int, float)):
+            from torch import fx
+
+            if isinstance(a_min, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_min = self.env[a_min]
+            assert isinstance(a_min, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
             )
-        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.maximum(x, a_min))
+            a_min = -math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+    def _clamp_max(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        a_min = -math.inf
+        a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)
+
+        a_max = math.inf if a_max is None else a_max
+
+        # Handle the case where a_max is a tensor
+        if not isinstance(a_max, (int, float)):
+            from torch import fx
+
+            if isinstance(a_max, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_max = self.env[a_max]
+            assert isinstance(a_max, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+            )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.minimum(x, a_max))
+            a_max = math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
 
     def _elu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -696,8 +772,8 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
 
     def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
-        from torch.fx.immutable_collections import immutable_list
         import numpy as np  # type: ignore
+        from torch.fx.immutable_collections import immutable_list
 
         if isinstance(normalized_shape, (immutable_list, tuple)):
             normalized_shape = tuple(normalized_shape)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2103365c6c..71a3d13aa1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -193,6 +193,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "bitwise_not.default": self._unary_op(relax.op.bitwise_not),
             "ceil.default": self._unary_op(relax.op.ceil),
             "clamp.default": self._clamp,
+            "clamp_min.default": self._clamp_min,
+            "clamp_max.default": self._clamp_max,
             "cos.default": self._unary_op(relax.op.cos),
             "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
@@ -294,6 +296,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "cat.default": self._cat,
+            "clamp.Tensor": self._clamp,
             "concat.default": self._cat,
             "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index abda5088db..952fb6f971 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,8 +18,8 @@
 # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Tuple, Union
 from functools import partial, reduce
+from typing import Callable, Dict, List, Tuple, Union
 
 import tvm
 from tvm import relax
@@ -598,6 +598,7 @@ class TorchFXImporter(BaseFXGraphImporter):
         self,
     ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
         import operator
+
         from torch import nn
 
         return {
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index e8b5da0dc2..6cc12370d6 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -56,6 +56,87 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor_clamp(target, dev):
+    class ClampBothTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("min_val", torch.tensor(-1.0))
+            self.register_buffer("max_val", torch.tensor(1.0))
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    class ClampBothInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = -1
+            self.max_val = 1
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    class ClampMinOnlyTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("min_val", torch.tensor(0.0))
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val)
+
+    class ClampMinOnlyInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = 0
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val)
+
+    class ClampMaxOnlyTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("max_val", torch.tensor(0.5))
+
+        def forward(self, x):
+            return x.clamp(max=self.max_val)
+
+    class ClampMaxOnlyInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.max_val = 0.5
+
+        def forward(self, x):
+            return x.clamp(max=self.max_val)
+
+    class ClampDifferentValues(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = -2
+            self.max_val = 2
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    # Create random data with values outside our clamp ranges
+    raw_data = np.random.uniform(-3.0, 3.0, (2, 3, 4, 5)).astype(np.float32)
+
+    torch_module0 = ClampBothTensor().eval()
+    torch_module1 = ClampBothInt().eval()
+    torch_module2 = ClampMinOnlyTensor().eval()
+    torch_module3 = ClampMinOnlyInt().eval()
+    torch_module4 = ClampMaxOnlyTensor().eval()
+    torch_module5 = ClampMaxOnlyInt().eval()
+    torch_module6 = ClampDifferentValues().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module4, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module5, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module6, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_expand_as(target, dev):
     class ExpandAs0(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6406610bf5..8b0a711a52 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -135,18 +135,70 @@ def test_extended_unary_ops():
     class expected_clamp:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5)
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    input,
+                    R.prim_value(T.float64(0.10000000000000001)),
+                    R.prim_value(T.float64(0.5)),
+                )
                 gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     verify_model(Clamp(), example_args, {}, expected_clamp)
 
+    class ClampMinOnly(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=0.5, max=None)
+
+    @tvm.script.ir_module
+    class expected_clamp_min_only:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    input, R.prim_value(T.float64(0.5)), R.prim_value(T.float64("inf"))
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)
+
+    class ClampTensors(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=input, max=input)
+
+    @tvm.script.ir_module
+    class expected_clamp_tensors:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    input, R.shape([1, 3, 10, 10])
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(input, lv)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    input, R.shape([1, 3, 10, 10])
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2)
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    lv3, R.prim_value(T.float64("-inf")), 
R.prim_value(T.float64("inf"))
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)
+
     # dropout
+
     class Dropout1(Module):
         def __init__(self):
             super().__init__()
@@ -3248,3 +3300,7 @@ def test_no_bind_return_tuple():
     exported_program = export(Identity(), args=example_args)
     mod = from_exported_program(exported_program, no_bind_return_tuple=True)
     tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 020fc8f5b3..fbea8b7388 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import operator
 import pytest
 import torch
@@ -21,6 +22,7 @@ import torch.nn.functional as F
 from torch import fx
 from torch.nn import Module
 import torchvision
+import math
 
 import tvm
 from tvm import relax
@@ -1970,7 +1972,7 @@ def test_extended_unary_ops():
     class expected_clamp:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
             # block 0
             with R.dataflow():
@@ -1981,29 +1983,53 @@ def test_extended_unary_ops():
 
     verify_model(Clamp(), input_info, {}, expected_clamp)
 
-    from tvm.relax.frontend.torch import from_fx
-
-    with pytest.raises(
-        ValueError, match="TVM only supports constant max value for 
torch.clamp/clip"
-    ):
+    class ClampMinOnly(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=0.5, max=None)
 
-        class Clamp_Error(Module):
-            def forward(self, input):
-                return torch.clamp(input, min=0.5, max=None)
+    @tvm.script.ir_module
+    class expected_clamp_min_only:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.5, math.inf)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
 
-        gm = fx.symbolic_trace(Clamp_Error())
-        from_fx(gm, input_info)
+    verify_model(ClampMinOnly(), input_info, {}, expected_clamp_min_only)
 
-    with pytest.raises(
-        ValueError, match="TVM only supports constant min value for 
torch.clamp/clip"
-    ):
+    class ClampTensors(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=input, max=input)
 
-        class Clamp_Error(Module):
-            def forward(self, input):
-                return torch.clamp(input, min=input, max=input)
+    @tvm.script.ir_module
+    class expected_clamp_tensors:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    inp_0, R.shape([1, 3, 10, 10])
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(inp_0, lv)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    inp_0, R.shape([1, 3, 10, 10])
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2)
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    lv3, R.prim_value(T.float64("-inf")), 
R.prim_value(T.float64("inf"))
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv4
+                R.output(gv)
+            return gv
 
-        gm = fx.symbolic_trace(Clamp_Error())
-        from_fx(gm, input_info)
+    verify_model(ClampTensors(), input_info, {}, expected_clamp_tensors)
 
     # dropout
     class Dropout1(Module):
