slyubomirsky commented on code in PR #16602:
URL: https://github.com/apache/tvm/pull/16602#discussion_r1496856295


##########
python/tvm/relax/transform/lazy_transform_params.py:
##########
@@ -157,24 +159,60 @@ def transform(self, func: relax.Function) -> 
relax.Function:
         self.memory_free_insertion = liveness.var_liveness_end
 
         # Step 3. rewrite get item and set item
-        new_body = func.body
         if self.fget_item is not None:
-            new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+            new_func = LazyInputMutator(self, self.mod).visit_expr(func)
 
+        new_body = new_func.body
         if self.fset_item is not None:
+            leaf_outputs = {
+                expr: indices
+                for expr, indices in self.out_tuple_map.items()
+                if not isinstance(expr, relax.Var)
+            }
+            if leaf_outputs:
+                new_bindings = [
+                    relax.VarBinding(
+                        relax.Var("_", relax.ObjectStructInfo()),
+                        relax.Call(
+                            relax.ExternFunc(self.fset_item),
+                            [*self.extra_set_item_params, index, expr],
+                            None,
+                            [relax.ObjectStructInfo()],
+                        ),
+                    )
+                    for expr, indices in leaf_outputs.items()
+                    for index in indices

Review Comment:
   You've gotta love the syntax for nested list comprehensions 
:slightly_smiling_face:



##########
python/tvm/relax/transform/lazy_transform_params.py:
##########
@@ -157,24 +159,60 @@ def transform(self, func: relax.Function) -> 
relax.Function:
         self.memory_free_insertion = liveness.var_liveness_end
 
         # Step 3. rewrite get item and set item
-        new_body = func.body
         if self.fget_item is not None:
-            new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+            new_func = LazyInputMutator(self, self.mod).visit_expr(func)
 
+        new_body = new_func.body
         if self.fset_item is not None:
+            leaf_outputs = {
+                expr: indices
+                for expr, indices in self.out_tuple_map.items()
+                if not isinstance(expr, relax.Var)
+            }
+            if leaf_outputs:
+                new_bindings = [
+                    relax.VarBinding(
+                        relax.Var("_", relax.ObjectStructInfo()),
+                        relax.Call(
+                            relax.ExternFunc(self.fset_item),
+                            [*self.extra_set_item_params, index, expr],
+                            None,
+                            [relax.ObjectStructInfo()],
+                        ),
+                    )
+                    for expr, indices in leaf_outputs.items()
+                    for index in indices
+                ]
+                new_body = relax.SeqExpr(
+                    [*new_body.blocks, relax.BindingBlock(new_bindings)], 
new_body.body
+                )
+

Review Comment:
   I presume these additions are for handling the non-var case mentioned in the 
description?



##########
tests/python/relax/test_transform_lazy_transform_params.py:
##########
@@ -602,5 +602,77 @@ def main_transform_params() -> R.Tuple:
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_params_without_tuple():
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: 
R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def transform_params():
+            A = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=[R.Object])
+            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+            C = R.multiply(A, R.const(2, "float32"))
+
+            B = R.call_packed("get_item", R.prim_value(1), 
sinfo_args=[R.Object])
+            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    After = LazyTransformParams(fset_item=None)(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_retain_before_num_input():
+    """Only lazily load parameters after num_input"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(
+            relax_rank: R.Prim(value="rank"),
+            A: R.Tensor([16, 16], "float32"),
+            B: R.Tensor([16, 16], "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            rank = T.int64()
+            A_sharded = R.strided_slice(
+                A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+            B_sharded = R.strided_slice(
+                B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+            return (A_sharded, B_sharded)
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def transform_params(relax_rank: R.Prim(value="rank")):
+            R.func_attr({"num_input": 1})
+            rank = T.int64()
+
+            A = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=[R.Object])
+            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+            A_sharded = R.strided_slice(
+                A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+
+            B = R.call_packed("get_item", R.prim_value(1), 
sinfo_args=[R.Object])
+            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+            B_sharded = R.strided_slice(
+                B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+
+            return (A_sharded, B_sharded)
+
+    After = LazyTransformParams(fset_item=None)(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+

Review Comment:
   Are there any test cases that make use of extra parameters for get_item and 
set_item? If that isn't tested, it should be. If there also isn't a test case 
with a non-var output (I'm not sure exactly what that should look like, as I 
haven't used this pass), that would be good to add too.



##########
python/tvm/relax/transform/lazy_transform_params.py:
##########
@@ -157,24 +159,60 @@ def transform(self, func: relax.Function) -> 
relax.Function:
         self.memory_free_insertion = liveness.var_liveness_end
 
         # Step 3. rewrite get item and set item
-        new_body = func.body
         if self.fget_item is not None:
-            new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+            new_func = LazyInputMutator(self, self.mod).visit_expr(func)
 
+        new_body = new_func.body
         if self.fset_item is not None:
+            leaf_outputs = {
+                expr: indices
+                for expr, indices in self.out_tuple_map.items()
+                if not isinstance(expr, relax.Var)
+            }
+            if leaf_outputs:
+                new_bindings = [
+                    relax.VarBinding(
+                        relax.Var("_", relax.ObjectStructInfo()),
+                        relax.Call(
+                            relax.ExternFunc(self.fset_item),
+                            [*self.extra_set_item_params, index, expr],
+                            None,
+                            [relax.ObjectStructInfo()],
+                        ),
+                    )
+                    for expr, indices in leaf_outputs.items()
+                    for index in indices
+                ]
+                new_body = relax.SeqExpr(
+                    [*new_body.blocks, relax.BindingBlock(new_bindings)], 
new_body.body
+                )
+
             new_body = LazyOutputMutator(self, self.mod).visit_expr(new_body)
 
         # Step 4. Add parameters of get_item and set_item (except index) to 
the function.
-        params = [*self.extra_get_item_params, *self.extra_set_item_params]
+        params = [
+            *func.params[:num_input],
+            *self.extra_get_item_params,
+            *self.extra_set_item_params,
+        ]
 
         # Step 5. Find all shape parameters that should be retained as
         # parameters.
         symbolic_vars = relax.analysis.defined_symbolic_vars(func)
         if symbolic_vars:
+
+            def unpack_sinfo(sinfo):
+                if isinstance(sinfo, relax.TupleStructInfo):
+                    for field in sinfo.fields:
+                        yield from unpack_sinfo(field)

Review Comment:
   This is the first time I've seen `yield from`, and this seems like a good use for it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to