[Mlir-commits] [mlir] [MLIR][test] Fixup for checking for ml_dtypes (PR #123240)

llvmlistbot at llvm.org
Thu Jan 16 12:46:22 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Konrad Kleine (kwk)

<details>
<summary>Changes</summary>

Some tests depend on the `ml_dtypes` Python module and should only run when it is installed. To make them optional we have to remove their `CHECK` lines; otherwise FileCheck still requires the matching output and fails when the tests are skipped and that output is missing.

I've changed these tests to use asserts instead, as recommended in [1].

[1]: https://github.com/llvm/llvm-project/pull/123061#issuecomment-2596116023
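
A minimal sketch of the pattern described above, not the code from this PR: a test that needs an optional module verifies its result with `assert` and is only invoked behind a guard, so no FileCheck `CHECK` line depends on its output. Only the `HAS_ML_DTYPES` name and the two assertions come from the diff; the way the flag is computed here and the helpers `check_single_value` and `run_if` are hypothetical.

```python
import importlib.util

# Assumption: the flag is derived from whether ml_dtypes can be imported.
# The diff only shows that a HAS_ML_DTYPES guard exists, not how it is set.
HAS_ML_DTYPES = importlib.util.find_spec("ml_dtypes") is not None


def check_single_value(values, expected=0.5):
    """Hypothetical helper: replaces '# CHECK: [0.5]' plus log(npout)."""
    assert len(values) == 1
    assert values[0] == expected


def run_if(condition, test_fn):
    """Hypothetical helper: run test_fn only when condition holds.

    Returns True if the test ran, False if it was skipped. Nothing is
    printed, so a skipped test leaves no expected output behind.
    """
    if not condition:
        return False
    test_fn()
    return True


def test_optional_dtype():
    # Stand-in for the value produced by ranked_memref_to_numpy(...).
    result = [0.5]
    check_single_value(result)


# Runs only when ml_dtypes is available; otherwise it is skipped silently.
run_if(HAS_ML_DTYPES, test_optional_dtype)
```

Because the check lives inside the test body, a failure raises `AssertionError` when the test runs, and a skipped test cannot cause a missing-output failure.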

---
Full diff: https://github.com/llvm/llvm-project/pull/123240.diff


1 File Affected:

- (modified) mlir/test/python/execution_engine.py (+6-8) 


``````````diff
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e3f41815800d58..cab6b69a01f4cc 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -536,7 +536,6 @@ def testComplexUnrankedMemrefAdd():
 
 
 # Test bf16 memrefs
-# CHECK-LABEL: TEST: testBF16Memref
 def testBF16Memref():
     with Context():
         module = Module.parse(
@@ -566,9 +565,9 @@ def testBF16Memref():
         execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
 
         # test to-numpy utility
-        # CHECK: [0.5]
-        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        log(npout)
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
 
 
 if HAS_ML_DTYPES:
@@ -576,7 +575,6 @@ def testBF16Memref():
 
 
 # Test f8E5M2 memrefs
-# CHECK-LABEL: TEST: testF8E5M2Memref
 def testF8E5M2Memref():
     with Context():
         module = Module.parse(
@@ -606,9 +604,9 @@ def testF8E5M2Memref():
         execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
 
         # test to-numpy utility
-        # CHECK: [0.5]
-        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        log(npout)
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
 
 
 if HAS_ML_DTYPES:

``````````

</details>


https://github.com/llvm/llvm-project/pull/123240

