[Mlir-commits] [mlir] [MLIR][test] Check for ml_dtypes before running tests (PR #123061)

Konrad Kleine llvmlistbot at llvm.org
Thu Jan 16 07:44:50 PST 2025


kwk wrote:

@makslevental, unfortunately the "solution" doesn't work as I expected. The problem is the `CHECK` lines, which are of course not conditionalized. Take this for example:

```python
# Test bf16 memrefs
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    with Context():
        module = Module.parse(
            """
    module  {
      func.func @main(%arg0: memref<1xbf16>,
                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xbf16>
        memref.store %1, %arg1[%0] : memref<1xbf16>
        return
      }
    } """
        )

        arg1 = np.array([0.5]).astype(bfloat16)
        arg2 = np.array([0.0]).astype(bfloat16)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)

        # test to-numpy utility
        # CHECK: [0.5]
        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
        log(npout)


if HAS_ML_DTYPES:
    run(testBF16Memref)
```

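When `ml_dtypes` is not available, `HAS_ML_DTYPES` is false and `testBF16Memref` never runs. The expected `# CHECK-LABEL: TEST: testBF16Memref` and `# CHECK: [0.5]` lines are therefore never matched, and FileCheck fails the test.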

What options do we have here?

https://github.com/llvm/llvm-project/pull/123061

