[Mlir-commits] [mlir] [MLIR][test] Fixup for checking for ml_dtypes (PR #123240)
Konrad Kleine
llvmlistbot at llvm.org
Fri Jan 17 00:54:00 PST 2025
https://github.com/kwk updated https://github.com/llvm/llvm-project/pull/123240
From e9e99f7fa81754bc66101d9ceaaceec27a22a22d Mon Sep 17 00:00:00 2001
From: Konrad Kleine <kkleine at redhat.com>
Date: Thu, 16 Jan 2025 21:41:52 +0100
Subject: [PATCH] [MLIR][test] Fixup for checking for ml_dtypes
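To run the checks that depend on the `ml_dtypes` Python module only when
that module is available, we have to remove the `CHECK` lines from those
tests. Otherwise FileCheck requires them and fails when they are missing
from the test output.
I've switched those checks to Python asserts, as recommended in [1].
When `ml_dtypes` is not available, the `TEST: <name>` line is still
logged, so the output contains the same test labels either way.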
[1]: https://github.com/llvm/llvm-project/pull/123061#issuecomment-2596116023
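The pattern this patch applies can be sketched in plain Python. This is a minimal, hypothetical sketch: `run_if`, `captured_output` and `test_bf16_roundtrip` are illustrative names, not part of the MLIR test suite, and a plain list stands in for the NumPy array returned by `ranked_memref_to_numpy`. It folds the patch's `if HAS_ML_DTYPES: ... else: log(...)` into one helper that always prints the `TEST:` label and runs the body, with its asserts, only when the optional dependency is present.

```python
import contextlib
import importlib.util
import io

# Assumption: the real test file defines its own HAS_ML_DTYPES flag; this
# sketch derives one by checking whether ml_dtypes is importable.
HAS_ML_DTYPES = importlib.util.find_spec("ml_dtypes") is not None


def log(*args):
    # Print a line into the output that FileCheck would scan.
    print(*args, flush=True)


def run_if(available, f):
    """Hypothetical helper: always print the test's label line, but only run
    the test body when its optional dependency is available."""
    log("TEST:", f.__name__)
    if available:
        f()


def test_bf16_roundtrip():
    # Hypothetical stand-in for ranked_memref_to_numpy(...): a 1-element result.
    x = [0.5]
    # Values are verified with assert, so no CHECK line depends on them.
    assert len(x) == 1
    assert x[0] == 0.5


def captured_output(available, f):
    # Return exactly what FileCheck would see for one test.
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        run_if(available, f)
    return buf.getvalue()


if __name__ == "__main__":
    # The same label line is printed whether or not ml_dtypes is installed.
    print(captured_output(HAS_ML_DTYPES, test_bf16_roundtrip), end="")
```

Because the label is printed in both branches, a label check still matches without `ml_dtypes`, while a failed value check surfaces as an `AssertionError` rather than a missing `CHECK` line.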
---
mlir/test/python/execution_engine.py | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e3f41815800d58..d569fcef32bfd2 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -566,13 +566,15 @@ def testBF16Memref():
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
# test to-numpy utility
- # CHECK: [0.5]
- npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
- log(npout)
+ x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+ assert len(x) == 1
+ assert x[0] == 0.5
if HAS_ML_DTYPES:
run(testBF16Memref)
+else:
+ log("TEST: testBF16Memref")
# Test f8E5M2 memrefs
@@ -606,13 +608,15 @@ def testF8E5M2Memref():
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
# test to-numpy utility
- # CHECK: [0.5]
- npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
- log(npout)
+ x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+ assert len(x) == 1
+ assert x[0] == 0.5
if HAS_ML_DTYPES:
run(testF8E5M2Memref)
+else:
+ log("TEST: testF8E5M2Memref")
# Test addition of two 2d_memref