This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9899f9cd28 [AOT][Testing] Improve output mismatch information on test failure (#16765)
9899f9cd28 is described below

commit 9899f9cd2801b3234437df5cd8ab10504b9608bc
Author: Andrei Hutu <andrei.h...@arm.com>
AuthorDate: Mon Mar 25 09:06:49 2024 +0000

    [AOT][Testing] Improve output mismatch information on test failure (#16765)
    
    Enhanced the AOT test harness to include the overall mismatch percentage
    and the individual mismatch positions in the output tensor, to help debug
    test failures. Both of these are still gated behind
    `print_output_on_mismatch == True`.
    I also added tests to check for the presence and correctness of this new
    debug information.
    Sample output:
    
    ```
    Element [Position]: Actual, Reference
    -------------------------------------
    Element [0, 8, 8, 7]: 521.846313, 521.847412
    Element [0, 9, 8, 51]: 478.874359, 478.875549
    Element [0, 9, 9, 6]: 462.901001, 462.899658
    
    Mismatched elements: 3 / 16384 (0.02%)
    ...
    ```
---
 python/tvm/testing/aot.py                       | 48 ++++++++++++++++-------
 tests/python/relay/aot/test_aot_test_harness.py | 52 ++++++++++++++++++++++++-
 2 files changed, 85 insertions(+), 15 deletions(-)

diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py
index 3a117624df..959d1cf58e 100644
--- a/python/tvm/testing/aot.py
+++ b/python/tvm/testing/aot.py
@@ -476,20 +476,40 @@ def _emit_main_compare(
 
         if print_output_on_mismatch:
             main_file.write(
-                f"int mismatch = 0;"
-                f'printf("Actual, Reference\\n");\n'
-                f"for (int i = 0; i<{data_length_var_name}; i++) {{\n"
-                f"\tif ({comparison_function}({actual_data_name}[i]-"
-                f"{expected_data_name}[i]) > {tolerance}) {{\n"
-                f'\t\tprintf("{value_format_specifier}, 
{value_format_specifier}\\n"'
-                f", {actual_data_name}[i], {expected_data_name}[i]);\n"
-                f"\t\tmismatch = 1;\n"
-                f"\t}}\n"
-                f"}}"
-                f"if (mismatch == 1) {{\n"
-                f'\tprintf("{AOT_FAILURE_TOKEN}\\n");\n'
-                f"\treturn -1;\n"
-                f"}}"
+                f"""
+                {{
+                int mismatch = 0;
+                int out_ndim = {outputs[key].ndim};
+                int out_shape[] = {{{','.join(map(str, outputs[key].shape))}}};
+                int out_indices[out_ndim];
+                printf("Element [Position]: Actual, Reference\\n");
+                printf("-------------------------------------\\n");
+                for (int i = 0; i<{data_length_var_name}; i++) {{
+                  if ({comparison_function}({actual_data_name}[i] -
+                      {expected_data_name}[i]) > {tolerance}) {{
+                    int flat_index = i;
+                    for (int j = out_ndim - 1; j >= 0; j--){{
+                      out_indices[j] = flat_index % out_shape[j];
+                      flat_index /= out_shape[j];
+                    }}
+                    printf("Element [%d", out_indices[0]);
+                    for (int j = 1; j < out_ndim; j++)
+                      printf(", %d", out_indices[j]);
+                    printf("]: {value_format_specifier}, 
{value_format_specifier}\\n",
+                           {actual_data_name}[i], {expected_data_name}[i]);
+                    mismatch += 1;
+                  }}
+                }}
+                if (mismatch >= 1) {{
+                  float percent_mismatched =
+                      ((float) mismatch) / ((float) {data_length_var_name}) * 
100;
+                  printf("\\nMismatched elements: %d / %zu (%.2f%%)\\n",
+                         mismatch, {data_length_var_name}, percent_mismatched);
+                  printf("{AOT_FAILURE_TOKEN}\\n");
+                  return -1;
+                }}
+                }}
+                """
             )
         else:
             main_file.write(
diff --git a/tests/python/relay/aot/test_aot_test_harness.py 
b/tests/python/relay/aot/test_aot_test_harness.py
index 8ec9506f9f..3d10f15d4a 100644
--- a/tests/python/relay/aot/test_aot_test_harness.py
+++ b/tests/python/relay/aot/test_aot_test_harness.py
@@ -46,7 +46,57 @@ def test_output_on_mismatch_option():
         ).astype(dtype)
     }
 
-    msg = ".*Actual, Reference\n2.000000, 0.000000\nAOT_TEST_FAILURE.*"
+    msg = ".*Actual, Reference(\n|.)*2.000000, 
0.000000(\n|.)*AOT_TEST_FAILURE.*"
+    with pytest.raises(RuntimeError, match=msg):
+        compile_and_run(
+            AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={}, 
outputs=outputs),
+            test_runner,
+            interface_api,
+            use_unpacked_api,
+            print_output_on_mismatch=True,
+        )
+
+
+def test_output_position_on_mismatch():
+    """
+    Test the mismatch position output for the print_output_on_mismatch option.
+    """
+    interface_api = "packed"
+    use_unpacked_api = True
+    test_runner = AOTTestRunner()
+    dtype = "float32"
+
+    x = np.zeros(shape=(2, 2), dtype=dtype)
+    x[-1, -1] = 1
+    func = relay.Function([], relay.const(x, dtype=dtype))
+    outputs = {"output": np.zeros(shape=(2, 2), dtype=dtype)}
+
+    msg = ".*Element \\[1, 1\\]:.*"
+    with pytest.raises(RuntimeError, match=msg):
+        compile_and_run(
+            AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={}, 
outputs=outputs),
+            test_runner,
+            interface_api,
+            use_unpacked_api,
+            print_output_on_mismatch=True,
+        )
+
+
+def test_mismatch_percentage():
+    """
+    Test the mismatch percentage for the print_output_on_mismatch option.
+    """
+    interface_api = "packed"
+    use_unpacked_api = True
+    test_runner = AOTTestRunner()
+    dtype = "float32"
+
+    x = np.zeros(shape=(8,), dtype=dtype)
+    x[0] = 1
+    func = relay.Function([], relay.const(x, dtype=dtype))
+    outputs = {"output": np.zeros(shape=(8,), dtype=dtype)}
+
+    msg = ".*Mismatched elements: 1 / 8 \\(12.50%\\).*"
     with pytest.raises(RuntimeError, match=msg):
         compile_and_run(
             AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={}, 
outputs=outputs),

Reply via email to