Cookiee235 opened a new issue, #17270:
URL: https://github.com/apache/tvm/issues/17270

   ### Actual behavior
   ```
   Traceback (most recent call last):
     File "/share_container/optfuzz/res/bugs/inconsis222.py", line 258, in 
<module>
       np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)
     File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
       assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
     File "/root/miniconda3/lib/python3.12/contextlib.py", line 81, in inner
       return func(*args, **kwds)
              ^^^^^^^^^^^^^^^^^^^
     File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 718, in assert_array_compare
       flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 688, in func_assert_same_pos
       raise AssertionError(msg)
   AssertionError: 
   Not equal to tolerance rtol=0.001, atol=0.001
   
   x and y nan location mismatch:
    x: array([[ 7.936000e+04,  8.032000e+04,  8.128000e+04,  8.224000e+04,
            8.320000e+04,  8.416000e+04,  8.512000e+04,  8.608000e+04,
            7.168000e+04,  7.252000e+04,  1.898367e+16,  7.420000e+04,...
    y: array([[ 7.936000e+04,  8.032000e+04,  8.128000e+04,  8.224000e+04,
            8.320000e+04,  8.416000e+04,  8.512000e+04,  8.608000e+04,
                     nan,  7.252000e+04,  7.336000e+04,  7.420000e+04,...
   
   ```
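
   Note: `x` (before the transforms) and `y` (after) disagree on *where* the NaN appears, so `assert_allclose` fails on the NaN-location check before any tolerance comparison is applied. A minimal illustration of that failure mode, separate from the bug itself:

   ```python
   import numpy as np

   # assert_allclose requires NaNs to occupy the same positions in both
   # arrays before any rtol/atol comparison; a mismatch raises immediately.
   x = np.array([1.0, 2.0])
   y = np.array([np.nan, 2.0])
   np.testing.assert_allclose(x, y, rtol=1e-3, atol=1e-3)  # AssertionError
   ```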
   
   
   ### Steps to reproduce
   <details>
   <summary>This is a complex test case, I cannot further reduce this case due 
to unknown root case</summary>
   
   
   ```python
   import tvm
   from tvm import relax
   import numpy as np
   metadata = tvm.ir.load_json("""{
     \"root\": 1, 
     \"nodes\": [
       {
         \"type_key\": \"\"
       }, 
       {
         \"type_key\": \"Map\", 
         \"keys\": [
           \"relax.expr.Constant\"
         ], 
         \"data\": [2]
       }, 
       {
         \"type_key\": \"Array\", 
         \"data\": [3]
       }, 
       {
         \"type_key\": \"relax.expr.Constant\", 
         \"attrs\": {
           \"_checked_type_\": \"11\", 
           \"data\": \"0\", 
           \"span\": \"0\", 
           \"struct_info_\": \"4\"
         }
       }, 
       {
         \"type_key\": \"relax.TensorStructInfo\", 
         \"attrs\": {
           \"dtype\": \"float32\", 
           \"ndim\": \"2\", 
           \"shape\": \"5\", 
           \"span\": \"0\", 
           \"vdevice\": \"0\"
         }
       }, 
       {
         \"type_key\": \"relax.expr.ShapeExpr\", 
         \"attrs\": {
           \"_checked_type_\": \"10\", 
           \"span\": \"0\", 
           \"struct_info_\": \"9\", 
           \"values\": \"6\"
         }
       }, 
       {
         \"type_key\": \"Array\", 
         \"data\": [7, 8]
       }, 
       {
         \"type_key\": \"IntImm\", 
         \"attrs\": {
           \"dtype\": \"int64\", 
           \"span\": \"0\", 
           \"value\": \"16\"
         }
       }, 
       {
         \"type_key\": \"IntImm\", 
         \"attrs\": {
           \"dtype\": \"int64\", 
           \"span\": \"0\", 
           \"value\": \"16\"
         }
       }, 
       {
         \"type_key\": \"relax.ShapeStructInfo\", 
         \"attrs\": {
           \"ndim\": \"2\", 
           \"span\": \"0\", 
           \"values\": \"6\"
         }
       }, 
       {
         \"type_key\": \"relax.ShapeType\", 
         \"attrs\": {
           \"ndim\": \"2\", 
           \"span\": \"0\"
         }
       }, 
       {
         \"type_key\": \"relax.DynTensorType\", 
         \"attrs\": {
           \"dtype\": \"float32\", 
           \"ndim\": \"2\", 
           \"span\": \"0\"
         }
       }
     ], 
     \"b64ndarrays\": [
       
\"P6G0lvBAXt0AAAAAAAAAAAEAAAAAAAAAAgAAAAIgAQAQAAAAAAAAABAAAAAAAAAAAAQAAAAAAAAAAAAAAACAPwAAAEAAAEBAAACAQAAAoEAAAMBAAADgQAAAAEEAABBBAAAgQQAAMEEAAEBBAABQQQAAYEEAAHBBAACAQQAAiEEAAJBBAACYQQAAoEEAAKhBAACwQQAAuEEAAMBBAADIQQAA0EEAANhBAADgQQAA6EEAAPBBAAD4QQAAAEIAAARCAAAIQgAADEIAABBCAAAUQgAAGEIAABxCAAAgQgAAJEIAAChCAAAsQgAAMEIAADRCAAA4QgAAPEIAAEBCAABEQgAASEIAAExCAABQQgAAVEIAAFhCAABcQgAAYEIAAGRCAABoQgAAbEIAAHBCAAB0QgAAeEIAAHxCAACAQgAAgkIAAIRCAACGQgAAiEIAAIpCAACMQgAAjkIAAJBCAACSQgAAlEIAAJZCAACYQgAAmkIAAJxCAACeQgAAoEIAAKJCAACkQgAApkIAAKhCAACqQgAArEIAAK5CAACwQgAAskIAALRCAAC2QgAAuEIAALpCAAC8QgAAvkIAAMBCAADCQgAAxEIAAMZCAADIQgAAykIAAMxCAADOQgAA0EIAANJCAADUQgAA1kIAANhCAADaQgAA3EIAAN5CAADgQgAA4kIAAORCAADmQgAA6EIAAOpCAADsQgAA7kIAAPBCAADyQgAA9EIAAPZCAAD4QgAA+kIAAPxCAAD+QgAAAEMAAAFDAAACQwAAA0MAAARDAAAFQwAABkMAAAdDAAAIQwAACUMAAApDAAALQwAADEMAAA1DAAAOQwAAD0MAABBDAAARQwAAEkMAABNDAAAUQwAAFUMAABZDAAAXQwAAGEMAABlDAAAaQwAAG0MAABxDAAAdQwAAHkMAAB9DAAAgQwAAIUMAACJDAAAjQwAAJEMAACVDAAAmQwAAJ0MAAChDAAApQwAAKkMAA
 
CtDAAAsQwAALUMAAC5DAAAvQwAAMEMAADFDAAAyQwAAM0MAADRDAAA1QwAANkMAADdDAAA4QwAAOUMAADpDAAA7QwAAPEMAAD1DAAA+QwAAP0MAAEBDAABBQwAAQkMAAENDAABEQwAARUMAAEZDAABHQwAASEMAAElDAABKQwAAS0MAAExDAABNQwAATkMAAE9DAABQQwAAUUMAAFJDAABTQwAAVEMAAFVDAABWQwAAV0MAAFhDAABZQwAAWkMAAFtDAABcQwAAXUMAAF5DAABfQwAAYEMAAGFDAABiQwAAY0MAAGRDAABlQwAAZkMAAGdDAABoQwAAaUMAAGpDAABrQwAAbEMAAG1DAABuQwAAb0MAAHBDAABxQwAAckMAAHNDAAB0QwAAdUMAAHZDAAB3QwAAeEMAAHlDAAB6QwAAe0MAAHxDAAB9QwAAfkMAAH9D\"
     ], 
     \"attrs\": {\"tvm_version\": \"0.17.dev0\"}
   }""")
   from tvm.script import ir as I
   from tvm.script import tir as T
   from tvm.script import relax as R
   
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), B: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16), T.int64(16)), "float32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
               with T.block("T_add"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                   T.writes(T_add[v_ax0, v_ax1])
                   T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
   
       @T.prim_func(private=True)
       def cast(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), compute: T.Buffer((T.int64(16), T.int64(16)), "int64")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for i0, i1 in T.grid(T.int64(16), T.int64(16)):
               with T.block("compute"):
                   v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                   T.reads(gv[v_i0, v_i1])
                   T.writes(compute[v_i0, v_i1])
                   compute[v_i0, v_i1] = T.Cast("int64", gv[v_i0, v_i1])
   
       @T.prim_func(private=True)
       def matmul(x: T.Buffer((T.int64(1), T.int64(16)), "float32"), weight: T.Buffer((T.int64(16), T.int64(32)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(32)), "float32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for i0, i1, k in T.grid(T.int64(1), T.int64(32), T.int64(16)):
               with T.block("matmul"):
                   v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                   T.reads(x[v_i0, v_k], weight[v_k, v_i1])
                   T.writes(matmul[v_i0, v_i1])
                   with T.init():
                       matmul[v_i0, v_i1] = T.float32(0)
                   matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * weight[v_k, v_i1]
   
       @T.prim_func(private=True)
       def reshape(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_reshape: T.Buffer((T.int64(256),), "float32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0 in range(T.int64(256)):
               with T.block("T_reshape"):
                   v_ax0 = T.axis.spatial(T.int64(256), ax0)
                   T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)])
                   T.writes(T_reshape[v_ax0])
                   T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)]
   
       @T.prim_func(private=True)
       def reshape1(temp: T.Buffer((T.int64(16),), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(16)), "float32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(1), T.int64(16)):
               with T.block("T_reshape"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(temp[v_ax1 % T.int64(16)])
                   T.writes(T_reshape[v_ax0, v_ax1])
                   T_reshape[v_ax0, v_ax1] = temp[v_ax1 % T.int64(16)]
   
       @T.prim_func(private=True)
       def reshape2(gv: T.Buffer((T.int64(16), T.int64(16)), "int64"), T_reshape: T.Buffer((T.int64(256),), "int64")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0 in range(T.int64(256)):
               with T.block("T_reshape"):
                   v_ax0 = T.axis.spatial(T.int64(256), ax0)
                   T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)])
                   T.writes(T_reshape[v_ax0])
                   T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 % T.int64(16)]
   
       @T.prim_func(private=True)
       def reshape3(temp: T.Buffer((T.int64(32),), "int64"), T_reshape: T.Buffer((T.int64(32),), "int64")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0 in range(T.int64(32)):
               with T.block("T_reshape"):
                   v_ax0 = T.axis.spatial(T.int64(32), ax0)
                   T.reads(temp[v_ax0 % T.int64(32)])
                   T.writes(T_reshape[v_ax0])
                   T_reshape[v_ax0] = temp[v_ax0 % T.int64(32)]
   
       @T.prim_func(private=True)
       def strided_slice(tensor_1dim: T.Buffer((T.int64(256),), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(16),), "float32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0 in range(T.int64(16)):
               with T.block("T_strided_slice_with_axes"):
                   v_ax0 = T.axis.spatial(T.int64(16), ax0)
                   T.reads(tensor_1dim[v_ax0])
                   T.writes(T_strided_slice_with_axes[v_ax0])
                   T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]
   
       @T.prim_func(private=True)
       def strided_slice1(tensor_1dim: T.Buffer((T.int64(256),), "int64"), T_strided_slice_with_axes: T.Buffer((T.int64(32),), "int64")):
           T.func_attr({"tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0 in range(T.int64(32)):
               with T.block("T_strided_slice_with_axes"):
                   v_ax0 = T.axis.spatial(T.int64(32), ax0)
                   T.reads(tensor_1dim[v_ax0])
                   T.writes(T_strided_slice_with_axes[v_ax0])
                   T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]
   
       @T.prim_func(private=True)
       def take(var_weight_table: T.handle, routing_table: T.Buffer((T.int64(32),), "int64"), T_take: T.Buffer((T.int64(16), T.int64(32)), "float32")):
           T.func_attr({"tir.noalias": T.bool(True)})
           weight_table_size = T.int64()
           weight_table = T.match_buffer(var_weight_table, (T.int64(16), weight_table_size))
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(16), T.int64(32)):
               with T.block("T_take"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(weight_table[v_ax0, routing_table[v_ax1]], routing_table[v_ax1])
                   T.writes(T_take[v_ax0, v_ax1])
                   T_take[v_ax0, v_ax1] = weight_table[v_ax0, routing_table[v_ax1]]
   
       @R.function
       def main_7(x: R.Tensor((1, 16), dtype="float32"), weight_table: R.Tensor((16, "weight_table_size"), dtype="float32"), routing_table: R.Tensor((32,), dtype="int64")) -> R.Tensor((1, 32), dtype="float32"):
           weight_table_size = T.int64()
           cls = Module
           with R.dataflow():
               weight = R.call_tir(cls.take, (weight_table, routing_table), out_sinfo=R.Tensor((16, 32), dtype="float32"))
               out = R.call_tir(cls.matmul, (x, weight), out_sinfo=R.Tensor((1, 32), dtype="float32"))
               R.output(out)
           return out
   
       @R.function
       def main() -> R.Tensor((1, 32), dtype="float32"):
           cls = Module
           gv = R.call_tir(cls.add, (metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((16, 16), dtype="float32"))
           tensor_1dim = R.call_tir(cls.reshape, (gv,), out_sinfo=R.Tensor((256,), dtype="float32"))
           temp = R.call_tir(cls.strided_slice, (tensor_1dim,), out_sinfo=R.Tensor((16,), dtype="float32"))
           para0 = R.call_tir(cls.reshape1, (temp,), out_sinfo=R.Tensor((1, 16), dtype="float32"))
           para1: R.Tensor((16, 16), dtype="float32") = gv
           gv_1 = R.call_tir(cls.cast, (gv,), out_sinfo=R.Tensor((16, 16), dtype="int64"))
           tensor_1dim_1 = R.call_tir(cls.reshape2, (gv_1,), out_sinfo=R.Tensor((256,), dtype="int64"))
           temp_1 = R.call_tir(cls.strided_slice1, (tensor_1dim_1,), out_sinfo=R.Tensor((32,), dtype="int64"))
           para2 = R.call_tir(cls.reshape3, (temp_1,), out_sinfo=R.Tensor((32,), dtype="int64"))
           res: R.Tensor((1, 32), dtype="float32") = cls.main_7(para0, para1, para2)
           return res
   
   
   def compile_mod(mod, func_name, target, *inputs):
       # Build the module for the given target and run `func_name` on the Relax VM.
       ex = relax.build(mod, target=target)
       vm = relax.VirtualMachine(ex, tvm.cpu())
       mod_outputs = vm[func_name](*inputs)
       return mod_outputs.numpy()

   mod = Module
   before_outputs = compile_mod(mod, 'main', 'llvm')
   mod = relax.transform.FoldConstant()(mod)
   mod = relax.transform.ReorderTakeAfterMatmul()(mod)
   after_outputs = compile_mod(mod, 'main', 'llvm')
   np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)
   
   ```
   </details>
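
   For context on why the outputs can diverge (my reading of the pass; the variable names below are illustrative, not taken from TVM): `ReorderTakeAfterMatmul` rewrites `matmul(x, take(W, idx))` into `take(matmul(x, W), idx)`, relying on the fact that column selection commutes with left matrix multiplication. The identity is exact over the reals, but the reordered form computes the matmul over every column of the constant-folded table, so non-finite intermediate values can appear at positions the original program never computed. A minimal NumPy sketch of the identity itself:

   ```python
   import numpy as np

   # Column selection (take along axis=1) commutes with left matmul:
   #   x @ W[:, idx]  ==  (x @ W)[:, idx]
   # Illustrative sketch only; names are not from the TVM pass.
   rng = np.random.default_rng(0)
   x = rng.standard_normal((1, 16)).astype("float32")
   W = rng.standard_normal((16, 64)).astype("float32")
   idx = rng.integers(0, 64, size=32)   # routing-table analogue

   before = x @ W[:, idx]     # take first, then matmul (original order)
   after = (x @ W)[:, idx]    # matmul first, then take (reordered)
   np.testing.assert_allclose(before, after, rtol=1e-5, atol=1e-5)
   ```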
   
   CC @Lunderberg @junrushao 

