gemini-code-assist[bot] commented on code in PR #18646:
URL: https://github.com/apache/tvm/pull/18646#discussion_r2671332349


##########
tests/python/codegen/test_target_codegen_llvm.py:
##########
@@ -378,6 +378,31 @@ def check_llvm(n):
     check_llvm(64)
 
 
+@tvm.testing.requires_llvm
+def test_llvm_cast_float_to_bool():
+    a_np = np.array([0.0, 1.0, np.nan, np.inf], dtype="float32")
+    n = a_np.shape[0]
+
+    A = te.placeholder((n,), name="A", dtype="float32")
+    C = te.compute((n,), lambda i: A[i].astype("bool"), name="C")
+
+    # Convert to TIR and create schedule
+    mod = te.create_prim_func([A, C])
+    sch = tir.Schedule(mod)
+
+    # build and invoke the kernel.
+    f = tvm.compile(sch.mod, target="llvm")
+    dev = tvm.cpu(0)
+
+    # launch the kernel.
+    a = tvm.runtime.tensor(a_np, dev)
+    c = tvm.runtime.empty((n,), dtype="bool", device=dev)
+    f(a, c)
+    c_np = np.array([False, True, True, True], dtype="bool")
+
+    tvm.testing.assert_allclose(c.numpy(), c_np)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This is a great test case that covers the essential scenarios for casting 
floats to booleans (zero, nonzero, NaN, and infinity). To make it even more 
comprehensive, I suggest parameterizing it to run against multiple float 
dtypes (`float16`, `float32`, and `float64`), so the fix is verified across 
precisions and test coverage improves.
   
   ```python
   @tvm.testing.requires_llvm
   @pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
   def test_llvm_cast_float_to_bool(dtype):
       if dtype == "float16" and tvm.target.codegen.llvm_version_major() < 8:
           pytest.skip("float16 support requires LLVM 8 or greater")
   
       a_np = np.array([0.0, 1.0, np.nan, np.inf], dtype=dtype)
       n = a_np.shape[0]
   
       A = te.placeholder((n,), name="A", dtype=dtype)
       C = te.compute((n,), lambda i: A[i].astype("bool"), name="C")
   
       # Convert to TIR and create schedule
       mod = te.create_prim_func([A, C])
       sch = tir.Schedule(mod)
   
        # Build the kernel for the LLVM target.
        f = tvm.compile(sch.mod, target="llvm")
        dev = tvm.cpu(0)
    
        # Launch the kernel.
       a = tvm.runtime.tensor(a_np, dev)
       c = tvm.runtime.empty((n,), dtype="bool", device=dev)
       f(a, c)
       c_np = np.array([False, True, True, True], dtype="bool")
   
       tvm.testing.assert_allclose(c.numpy(), c_np)
   ```
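
   As a sanity check, the expected output in both versions of the test matches standard float-to-bool conversion semantics, which NumPy also follows: any nonzero value, including NaN and infinity, maps to `True`. A minimal sketch using only NumPy:

   ```python
   import numpy as np

   # Float-to-bool casts treat any nonzero value as True; NaN and inf
   # are nonzero, so only 0.0 maps to False.
   a = np.array([0.0, 1.0, np.nan, np.inf], dtype="float32")
   print(a.astype("bool").tolist())  # [False, True, True, True]
   ```

   The same result holds for `float16` and `float64` inputs, which is why the parameterized test can reuse a single expected array.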



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
