ANSHUMAN87 commented on a change in pull request #7807:
URL: https://github.com/apache/tvm/pull/7807#discussion_r617695429



##########
File path: tests/python/relay/test_pass_simplify_expr.py
##########
@@ -106,10 +106,112 @@ def expected3():
         y = relay.transpose(y, axes=[0, 2, 3, 1])
         return relay.Function([x], y)
 
+    # Test a series of transpose and rank changing layout_transform
+    def before4():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.transpose(x, axes=[0, 3, 1, 2])  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        y = relay.transpose(y, axes=[0, 2, 3, 1])  # To NHWC
+        return relay.Function([x], y)
+
+    def expected4():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def before5():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def expected5():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def before6():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NCHW", "NHWC")
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected6():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before7():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.layout_transform(y, "NCHW8c", "NCHW4c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected7():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before8():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.layout_transform(x, "NCHW4c", "NCHW")
+        y = relay.layout_transform(y, "NCHW", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected8():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")  # NCHW4c
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def before9():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NCHW", "NCHW4c")  # To NCHW4c

Review comment:
       The comment on x says its layout is NHWC, but the layout_transform below
converts from "NCHW" to "NCHW4c".
   These don't match. Is this intended?
   The same mismatch occurs in test case 6 (before6), where x is marked NHWC but is transformed from "NCHW" to "NHWC".




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to