Lunderberg commented on code in PR #16049:
URL: https://github.com/apache/tvm/pull/16049#discussion_r1385606465
##########
tests/python/relax/test_transform_lift_transform_params.py:
##########
@@ -642,5 +642,95 @@ def slice(
tvm.ir.assert_structural_equal(Expected, after)
+def test_symbolic_var_in_param_shape():
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, "n"), "float32"),
+ w1: R.Tensor((16, "m", 3, 3), "float32"),
+ w2: R.Tensor((16, "m", 3, 3), "float32"),
+ ) -> R.Tensor((1, 16, 224, 224), "float32"):
+ m = T.int64()
+ n = T.int64()
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ zeros = R.zeros((n, n), "float32")
+ w1 = R.add(w1, R.const(1, "float32"))
+ conv1 = R.nn.conv2d(x, w1, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW")
+ conv2 = R.nn.conv2d(
+ conv1, w2, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW"
+ )
+ R.output(conv2)
+ return conv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(
+ R.Tensor((16, "m", 3, 3), dtype="float32"),
+ R.Tensor((16, "m", 3, 3), dtype="float32"),
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3,
3), dtype="float32")
+ ):
+ m = T.int64()
+ with R.dataflow():
+ lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
+ lv1: R.Tensor((16, m, 3, 3), dtype="float32") = params[0]
+ lv2: R.Tensor((16, m, 3, 3), dtype="float32") = R.add(lv1,
R.const(1, "float32"))
+ gv: R.Tuple(
+ R.Tensor((16, m, 3, 3), dtype="float32"),
+ R.Tensor((16, m, 3, 3), dtype="float32"),
+ ) = (lv, lv2)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, "n"), dtype="float32"),
+ transformed_param_0: R.Tensor((16, "m", 3, 3), dtype="float32"),
+ transformed_param_1: R.Tensor((16, "m", 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+ n = T.int64()
+ m = T.int64()
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n,
n]), dtype="float32")
+ lv: R.Tensor((16, m, 3, 3), dtype="float32") =
transformed_param_1
+ conv1: R.Tensor((1, 16, 224, n), dtype="float32") =
R.nn.conv2d(
+ x,
+ lv,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ lv1: R.Tensor((16, m, 3, 3), dtype="float32") =
transformed_param_0
+ conv2: R.Tensor((1, 16, 224, n), dtype="float32") =
R.nn.conv2d(
+ conv1,
+ lv1,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ R.output(conv2)
+ return conv2
+
+ mod = Before
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
Review Comment:
Oof, it looks like this is a pretty tricky one to find a long-term solution for.
1. `FuncStructInfo` is used both for top-level functions (must define all
symbolic vars) and for local bindings of functions (may use symbolic vars from
parent scope).
2. `FuncStructInfo` does not distinguish between usage of symbolic variables
and definitions of symbolic variables.
3. Because it cannot distinguish between valid usage of symbolic vars from
the parent scope and invalid usage of undefined symbolic vars,
`EraseToWellDefined` doesn't modify `FuncStructInfo`.
4. For the same reason, the `relax::Function` constructor does not throw an
error if a parameter contains context-defined symbolic variables.
5. `LiftTransformParams` generates a `relax::Function` that would be valid
in a local binding, but isn't valid as a top-level function. Because of (3)
and (4), nothing erases or rejects the undefined symbolic variables, so they
are retained in the lifted function's signature.
I think the long-term resolution will be to either extend `FuncStructInfo`
to distinguish between usage/definition of symbolic variables, or to
preferentially use local bindings instead of new functions, in order to
centralize the special handling into `relax.transform.LambdaLift`.
For now, I think marking this test as a known failure with `@pytest.mark.xfail` is
sufficient.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]