[Mlir-commits] [mlir] Support float8_e3m4 and float8_e4m3 in np_to_memref (PR #186453)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 15 06:40:13 PDT 2026


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/186453

>From d75b9dc0e6f16bc0bbe47c1294ecc50588f782c9 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 13 Mar 2026 11:59:03 -0500
Subject: [PATCH 1/5] Support float8_e3m4 and float8_e4m3 in np_to_memref

---
 mlir/python/mlir/runtime/np_to_memref.py | 23 +++++++
 mlir/test/python/execution_engine.py     | 83 +++++++++++++++++++++++-
 2 files changed, 104 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 8cca1e7ad4a9e..8455e5b8b7b37 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -42,6 +42,15 @@ class F8E5M2(ctypes.Structure):
 
     _fields_ = [("f8E5M2", ctypes.c_int8)]
 
+class F8E3M4(ctypes.Structure):
+    """A ctype representation for MLIR's Float8E3M4."""
+
+    _fields_ = [("f8E3M4", ctypes.c_int8)]
+
+class F8E4M3(ctypes.Structure):
+    """A ctype representation for MLIR's Float8E4M3."""
+
+    _fields_ = [("f8E4M3", ctypes.c_int8)]
 
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
@@ -56,6 +65,10 @@ def as_ctype(dtp):
         return BF16
     if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
         return F8E5M2
+    if ml_dtypes is not None and dtp == ml_dtypes.float8_e3m4:
+        return F8E3M4
+    if ml_dtypes is not None and dtp == ml_dtypes.float8_e4m3:
+        return F8E4M3
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -77,6 +90,16 @@ def to_numpy(array):
     ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == F8E5M2:
         return array.view("float8_e5m2")
+    assert not (
+        array.dtype == F8E3M4 and ml_dtypes is None
+    ), f"float8_e3m4 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == F8E3M4:
+        return array.view("float8_e3m4")
+    assert not (
+        array.dtype == F8E4M3 and ml_dtypes is None
+    ), f"float8_e4m3 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == F8E4M3:
+        return array.view("float8_e4m3")
     return array
 
 
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index b11340f2c19ce..87424f73d086a 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -8,7 +8,7 @@
 from mlir.runtime import *
 
 try:
-    from ml_dtypes import bfloat16, float8_e5m2
+    from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3
 
     HAS_ML_DTYPES = True
 except ModuleNotFoundError:
@@ -616,12 +616,91 @@ def testF8E5M2Memref():
         assert len(x) == 1
         assert x[0] == 0.5
 
-
 if HAS_ML_DTYPES:
     run(testF8E5M2Memref)
 else:
     log("TEST: testF8E5M2Memref")
 
+# Test f8E3M4 memrefs
+# CHECK-LABEL: TEST: testF8E3M4Memref
+def testF8E3M4Memref():
+    with Context():
+        module = Module.parse(
+            """
+    module  {
+      func.func @main(%arg0: memref<1xf8E3M4>,
+                      %arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xf8E3M4>
+        memref.store %1, %arg1[%0] : memref<1xf8E3M4>
+        return
+      }
+    } """
+        )
+
+        arg1 = np.array([0.5]).astype(float8_e3m4)
+        arg2 = np.array([0.0]).astype(float8_e3m4)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
+
+        # test to-numpy utility
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
+
+if HAS_ML_DTYPES:
+    run(testF8E3M4Memref)
+else:
+    log("TEST: testF8E3M4Memref")
+
+
+# Test f8E4M3 memrefs
+# CHECK-LABEL: TEST: testF8E4M3Memref
+def testF8E4M3Memref():
+    with Context():
+        module = Module.parse(
+            """
+    module  {
+      func.func @main(%arg0: memref<1xf8E4M3>,
+                      %arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xf8E4M3>
+        memref.store %1, %arg1[%0] : memref<1xf8E4M3>
+        return
+      }
+    } """
+        )
+
+        arg1 = np.array([0.5]).astype(float8_e4m3)
+        arg2 = np.array([0.0]).astype(float8_e4m3)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
+
+        # test to-numpy utility
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
+
+if HAS_ML_DTYPES:
+    run(testF8E4M3Memref)
+else:
+    log("TEST: testF8E4M3Memref")
 
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D

>From b0949ea2c8928d690b2c730ce2f6e536fc8a6887 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 13 Mar 2026 12:07:23 -0500
Subject: [PATCH 2/5] formatting

---
 mlir/python/mlir/runtime/np_to_memref.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 8455e5b8b7b37..00c8609f2a27f 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -37,21 +37,25 @@ class BF16(ctypes.Structure):
 
     _fields_ = [("bf16", ctypes.c_int16)]
 
+
 class F8E5M2(ctypes.Structure):
     """A ctype representation for MLIR's Float8E5M2."""
 
     _fields_ = [("f8E5M2", ctypes.c_int8)]
 
+
 class F8E3M4(ctypes.Structure):
     """A ctype representation for MLIR's Float8E3M4."""
 
     _fields_ = [("f8E3M4", ctypes.c_int8)]
 
+
 class F8E4M3(ctypes.Structure):
     """A ctype representation for MLIR's Float8E4M3."""
 
     _fields_ = [("f8E4M3", ctypes.c_int8)]
 
+
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
     """Converts dtype to ctype."""

>From c93960da918a206de713dd7860942d5c6d207f7f Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 13 Mar 2026 12:15:03 -0500
Subject: [PATCH 3/5] formatting again

---
 mlir/test/python/execution_engine.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 87424f73d086a..858ee089042ad 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -616,11 +616,13 @@ def testF8E5M2Memref():
         assert len(x) == 1
         assert x[0] == 0.5
 
+
 if HAS_ML_DTYPES:
     run(testF8E5M2Memref)
 else:
     log("TEST: testF8E5M2Memref")
 
+
 # Test f8E3M4 memrefs
 # CHECK-LABEL: TEST: testF8E3M4Memref
 def testF8E3M4Memref():
@@ -656,6 +658,7 @@ def testF8E3M4Memref():
         assert len(x) == 1
         assert x[0] == 0.5
 
+
 if HAS_ML_DTYPES:
     run(testF8E3M4Memref)
 else:
@@ -697,11 +700,13 @@ def testF8E4M3Memref():
         assert len(x) == 1
         assert x[0] == 0.5
 
+
 if HAS_ML_DTYPES:
     run(testF8E4M3Memref)
 else:
     log("TEST: testF8E4M3Memref")
 
+
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():

>From 1d8d2d61b324437212ec150d632f3a24510aa1a2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 13 Mar 2026 21:04:43 -0500
Subject: [PATCH 4/5] refactor to_numpy

---
 mlir/python/mlir/runtime/np_to_memref.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 00c8609f2a27f..5b88e5760fa4b 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -76,6 +76,9 @@ def as_ctype(dtp):
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
+ML_DTYPES_REQUIRED = [BF16, F8E5M2, F8E3M4, F8E4M3]
+
+
 def to_numpy(array):
     """Converts ctypes array back to numpy dtype array."""
     if array.dtype == C128:
@@ -85,25 +88,17 @@ def to_numpy(array):
     if array.dtype == F16:
         return array.view("float16")
     assert not (
-        array.dtype == BF16 and ml_dtypes is None
-    ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+        array.dtype in ML_DTYPES_REQUIRED and ml_dtypes is None
+    ), f"{array.dtype.__name__} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == BF16:
         return array.view("bfloat16")
-    assert not (
-        array.dtype == F8E5M2 and ml_dtypes is None
-    ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == F8E5M2:
         return array.view("float8_e5m2")
-    assert not (
-        array.dtype == F8E3M4 and ml_dtypes is None
-    ), f"float8_e3m4 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == F8E3M4:
         return array.view("float8_e3m4")
-    assert not (
-        array.dtype == F8E4M3 and ml_dtypes is None
-    ), f"float8_e4m3 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == F8E4M3:
         return array.view("float8_e4m3")
+
     return array
 
 

>From b5ff726b6c703cb28373f0f551346e150a03c01f Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 15 Mar 2026 08:39:55 -0500
Subject: [PATCH 5/5] address comments

---
 mlir/python/mlir/runtime/np_to_memref.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 5b88e5760fa4b..d65ba51afdb90 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -76,9 +76,6 @@ def as_ctype(dtp):
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
-ML_DTYPES_REQUIRED = [BF16, F8E5M2, F8E3M4, F8E4M3]
-
-
 def to_numpy(array):
     """Converts ctypes array back to numpy dtype array."""
     if array.dtype == C128:
@@ -88,8 +85,8 @@ def to_numpy(array):
     if array.dtype == F16:
         return array.view("float16")
     assert not (
-        array.dtype in ML_DTYPES_REQUIRED and ml_dtypes is None
-    ), f"{array.dtype.__name__} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+        array.dtype in (BF16, F8E5M2, F8E3M4, F8E4M3) and ml_dtypes is None
+    ), f"{array.dtype=} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == BF16:
         return array.view("bfloat16")
     if array.dtype == F8E5M2:



More information about the Mlir-commits mailing list