[Mlir-commits] [mlir] ba44d7b - [MLIR][test] Fixup for checking for ml_dtypes (#123240)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 17 07:25:12 PST 2025


Author: Konrad Kleine
Date: 2025-01-17T16:25:08+01:00
New Revision: ba44d7ba1fb3e27f51d65ea1af280e00382e09e0

URL: https://github.com/llvm/llvm-project/commit/ba44d7ba1fb3e27f51d65ea1af280e00382e09e0
DIFF: https://github.com/llvm/llvm-project/commit/ba44d7ba1fb3e27f51d65ea1af280e00382e09e0.diff

LOG: [MLIR][test] Fixup for checking for ml_dtypes (#123240)

To make the checks that depend on the `ml_dtypes` Python module
optional, we have to remove the `CHECK` lines for those tests. Otherwise
FileCheck still requires them, and the test fails when the module is
missing and the expected output is never printed.

I've switched these tests to asserts, as recommended in [1].

[1]:
https://github.com/llvm/llvm-project/pull/123061#issuecomment-2596116023
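
The pattern described above can be sketched in isolation as follows. This is
a minimal illustration, not the code from `execution_engine.py`: the `log` and
`run` helpers are simplified stand-ins, `testOptionalDtype` is a hypothetical
test, and detecting the module via `importlib.util.find_spec` is an assumption
(the diff does not show how `HAS_ML_DTYPES` is defined). The `else` branch
prints the same `TEST:` line that `run` would, presumably so that a label
check on the test name still matches when the test body is skipped.

```python
import importlib.util

# Assumption: detect the optional module without importing it.
HAS_ML_DTYPES = importlib.util.find_spec("ml_dtypes") is not None


def log(*args):
    # Stand-in for the test file's logging helper.
    print(*args, flush=True)


def run(f):
    # Stand-in runner: print the test name, then execute the test.
    log("TEST:", f.__name__)
    f()
    return f


def testOptionalDtype():
    # Hypothetical test body. The imports live inside the function so
    # they only execute when the optional module is available.
    import ml_dtypes
    import numpy as np

    x = np.array([0.5], dtype=ml_dtypes.bfloat16)
    # Asserts replace the former CHECK line on printed output, so
    # nothing is required of the output when the test is skipped.
    assert len(x) == 1
    assert x[0] == 0.5


if HAS_ML_DTYPES:
    run(testOptionalDtype)
else:
    # Still emit the test name so a check on the label can match.
    log("TEST: testOptionalDtype")
```

With this shape, a failing value check raises `AssertionError` inside the test
itself, while the printed output is identical whether or not `ml_dtypes` is
installed.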

Added: 
    

Modified: 
    mlir/test/python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e3f41815800d58..d569fcef32bfd2 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -566,13 +566,15 @@ def testBF16Memref():
         execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
 
         # test to-numpy utility
-        # CHECK: [0.5]
-        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        log(npout)
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
 
 
 if HAS_ML_DTYPES:
     run(testBF16Memref)
+else:
+    log("TEST: testBF16Memref")
 
 
 # Test f8E5M2 memrefs
@@ -606,13 +608,15 @@ def testF8E5M2Memref():
         execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
 
         # test to-numpy utility
-        # CHECK: [0.5]
-        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        log(npout)
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
 
 
 if HAS_ML_DTYPES:
     run(testF8E5M2Memref)
+else:
+    log("TEST: testF8E5M2Memref")
 
 
 #  Test addition of two 2d_memref


        

