[Mlir-commits] [mlir] Support float8_e3m4 and float8_e4m3 in np_to_memref (PR #186453)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 13 10:01:28 PDT 2026
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/186453
>From d75b9dc0e6f16bc0bbe47c1294ecc50588f782c9 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 13 Mar 2026 11:59:03 -0500
Subject: [PATCH] Support float8_e3m4 and float8_e4m3 in np_to_memref
---
mlir/python/mlir/runtime/np_to_memref.py | 23 +++++++
mlir/test/python/execution_engine.py | 83 +++++++++++++++++++++++-
2 files changed, 104 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 8cca1e7ad4a9e..8455e5b8b7b37 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -42,6 +42,15 @@ class F8E5M2(ctypes.Structure):
_fields_ = [("f8E5M2", ctypes.c_int8)]
+class F8E3M4(ctypes.Structure):
+ """A ctype representation for MLIR's Float8E3M4."""
+
+ _fields_ = [("f8E3M4", ctypes.c_int8)]
+
+class F8E4M3(ctypes.Structure):
+ """A ctype representation for MLIR's Float8E4M3."""
+
+ _fields_ = [("f8E4M3", ctypes.c_int8)]
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
@@ -56,6 +65,10 @@ def as_ctype(dtp):
return BF16
if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
return F8E5M2
+ if ml_dtypes is not None and dtp == ml_dtypes.float8_e3m4:
+ return F8E3M4
+ if ml_dtypes is not None and dtp == ml_dtypes.float8_e4m3:
+ return F8E4M3
return np.ctypeslib.as_ctypes_type(dtp)
@@ -77,6 +90,16 @@ def to_numpy(array):
), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
if array.dtype == F8E5M2:
return array.view("float8_e5m2")
+ assert not (
+ array.dtype == F8E3M4 and ml_dtypes is None
+ ), f"float8_e3m4 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+ if array.dtype == F8E3M4:
+ return array.view("float8_e3m4")
+ assert not (
+ array.dtype == F8E4M3 and ml_dtypes is None
+ ), f"float8_e4m3 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+ if array.dtype == F8E4M3:
+ return array.view("float8_e4m3")
return array
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index b11340f2c19ce..87424f73d086a 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -8,7 +8,7 @@
from mlir.runtime import *
try:
- from ml_dtypes import bfloat16, float8_e5m2
+ from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3
HAS_ML_DTYPES = True
except ModuleNotFoundError:
@@ -616,12 +616,91 @@ def testF8E5M2Memref():
assert len(x) == 1
assert x[0] == 0.5
-
if HAS_ML_DTYPES:
run(testF8E5M2Memref)
else:
log("TEST: testF8E5M2Memref")
+# Test f8E3M4 memrefs
+# CHECK-LABEL: TEST: testF8E3M4Memref
+def testF8E3M4Memref():
+ with Context():
+ module = Module.parse(
+ """
+ module {
+ func.func @main(%arg0: memref<1xf8E3M4>,
+ %arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } {
+ %0 = arith.constant 0 : index
+ %1 = memref.load %arg0[%0] : memref<1xf8E3M4>
+ memref.store %1, %arg1[%0] : memref<1xf8E3M4>
+ return
+ }
+ } """
+ )
+
+ arg1 = np.array([0.5]).astype(float8_e3m4)
+ arg2 = np.array([0.0]).astype(float8_e3m4)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
+
+ # test to-numpy utility
+ x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+ assert len(x) == 1
+ assert x[0] == 0.5
+
+if HAS_ML_DTYPES:
+ run(testF8E3M4Memref)
+else:
+ log("TEST: testF8E3M4Memref")
+
+
+# Test f8E4M3 memrefs
+# CHECK-LABEL: TEST: testF8E4M3Memref
+def testF8E4M3Memref():
+ with Context():
+ module = Module.parse(
+ """
+ module {
+ func.func @main(%arg0: memref<1xf8E4M3>,
+ %arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } {
+ %0 = arith.constant 0 : index
+ %1 = memref.load %arg0[%0] : memref<1xf8E4M3>
+ memref.store %1, %arg1[%0] : memref<1xf8E4M3>
+ return
+ }
+ } """
+ )
+
+ arg1 = np.array([0.5]).astype(float8_e4m3)
+ arg2 = np.array([0.0]).astype(float8_e4m3)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
+
+ # test to-numpy utility
+ x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+ assert len(x) == 1
+ assert x[0] == 0.5
+
+if HAS_ML_DTYPES:
+ run(testF8E4M3Memref)
+else:
+ log("TEST: testF8E4M3Memref")
# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
More information about the Mlir-commits mailing list