[Mlir-commits] [mlir] [MLIR][test] Check for ml_dtypes before running tests (PR #123061)
Konrad Kleine
llvmlistbot at llvm.org
Thu Jan 16 07:44:50 PST 2025
kwk wrote:
@makslevental, unfortunately the "solution" doesn't work as I expected. The problem is the `CHECK` lines, which of course are not conditionalized: FileCheck reads them from the test file itself, regardless of which Python code actually runs. Take this for example:
```python
# Test bf16 memrefs
# (`run`, `log` and `lowerToLLVM` are helpers defined earlier in this test file.)
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    with Context():
        module = Module.parse(
            """
        module {
          func.func @main(%arg0: memref<1xbf16>,
                          %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
            %0 = arith.constant 0 : index
            %1 = memref.load %arg0[%0] : memref<1xbf16>
            memref.store %1, %arg1[%0] : memref<1xbf16>
            return
          }
        } """
        )

        arg1 = np.array([0.5]).astype(bfloat16)
        arg2 = np.array([0.0]).astype(bfloat16)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)

        # test to-numpy utility
        # CHECK: [0.5]
        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
        log(npout)


# Only the run() call is guarded; the CHECK comments above still apply.
if HAS_ML_DTYPES:
    run(testBF16Memref)
```
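When `ml_dtypes` is not installed, `testBF16Memref` never runs, so FileCheck cannot find the output expected by `# CHECK-LABEL: TEST: testBF16Memref` and `# CHECK: [0.5]`, and the test fails.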
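To make the failure mode concrete, here is a toy model of it. The helpers `extract_checks`, `simulated_run` and `naive_filecheck` are hypothetical and are not how FileCheck is implemented (for instance, real `CHECK-LABEL` semantics are ignored); the sketch only shows that the directives come from the source text, while the `if` only changes what gets printed at runtime.

```python
import re

# Hypothetical miniature of the test file's directives. They are plain comments,
# so they exist in the source whether or not the guarded test is executed.
SOURCE = """
# CHECK-LABEL: TEST: testBF16Memref
# CHECK: [0.5]
"""


def extract_checks(source):
    """Collect (directive, pattern) pairs from comment lines, roughly like FileCheck."""
    return re.findall(r"#\s*(CHECK(?:-LABEL)?):\s*(.*)", source)


def simulated_run(has_ml_dtypes):
    """Mimic the test's output: nothing is printed when the guard is False."""
    out = []
    if has_ml_dtypes:
        out.append("\nTEST: testBF16Memref")
        out.append("[0.5]")
    return "\n".join(out)


def naive_filecheck(checks, output):
    """Return the patterns that could not be matched, scanning the output in order."""
    pos = 0
    missing = []
    for _, pattern in checks:
        idx = output.find(pattern, pos)
        if idx < 0:
            missing.append(pattern)
        else:
            pos = idx + len(pattern)
    return missing


checks = extract_checks(SOURCE)
print(naive_filecheck(checks, simulated_run(True)))
print(naive_filecheck(checks, simulated_run(False)))
```

With the guard enabled nothing is missing; with it disabled both patterns are reported missing, which is exactly the situation above.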
What options do we have here?
https://github.com/llvm/llvm-project/pull/123061