[Mlir-commits] [mlir] 9e22690 - Revert "Support float8_e3m4 and float8_e4m3 in np_to_memref (#186453)" (#186677)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 15 09:53:04 PDT 2026


Author: srcarroll
Date: 2026-03-15T11:52:59-05:00
New Revision: 9e22690671e946c8a6a2a92497d308c359ba31e4

URL: https://github.com/llvm/llvm-project/commit/9e22690671e946c8a6a2a92497d308c359ba31e4
DIFF: https://github.com/llvm/llvm-project/commit/9e22690671e946c8a6a2a92497d308c359ba31e4.diff

LOG: Revert "Support float8_e3m4 and float8_e4m3 in np_to_memref (#186453)" (#186677)

This reverts commit 57427f84fe5fdda71aef4be257ed28d7b4f55d05.

The mlir-nvidia CI builder fails to import `float8_e3m4` from
`ml_dtypes`; the root cause has not been identified yet. See
https://lab.llvm.org/buildbot/#/builders/138/builds/27095.

Added: 
    

Modified: 
    mlir/python/mlir/runtime/np_to_memref.py
    mlir/test/python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index d65ba51afdb90..8cca1e7ad4a9e 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -37,25 +37,12 @@ class BF16(ctypes.Structure):
 
     _fields_ = [("bf16", ctypes.c_int16)]
 
-
 class F8E5M2(ctypes.Structure):
     """A ctype representation for MLIR's Float8E5M2."""
 
     _fields_ = [("f8E5M2", ctypes.c_int8)]
 
 
-class F8E3M4(ctypes.Structure):
-    """A ctype representation for MLIR's Float8E3M4."""
-
-    _fields_ = [("f8E3M4", ctypes.c_int8)]
-
-
-class F8E4M3(ctypes.Structure):
-    """A ctype representation for MLIR's Float8E4M3."""
-
-    _fields_ = [("f8E4M3", ctypes.c_int8)]
-
-
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
     """Converts dtype to ctype."""
@@ -69,10 +56,6 @@ def as_ctype(dtp):
         return BF16
     if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
         return F8E5M2
-    if ml_dtypes is not None and dtp == ml_dtypes.float8_e3m4:
-        return F8E3M4
-    if ml_dtypes is not None and dtp == ml_dtypes.float8_e4m3:
-        return F8E4M3
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -85,17 +68,15 @@ def to_numpy(array):
     if array.dtype == F16:
         return array.view("float16")
     assert not (
-        array.dtype in (BF16, F8E5M2, F8E3M4, F8E4M3) and ml_dtypes is None
-    ), f"{array.dtype=} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+        array.dtype == BF16 and ml_dtypes is None
+    ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == BF16:
         return array.view("bfloat16")
+    assert not (
+        array.dtype == F8E5M2 and ml_dtypes is None
+    ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == F8E5M2:
         return array.view("float8_e5m2")
-    if array.dtype == F8E3M4:
-        return array.view("float8_e3m4")
-    if array.dtype == F8E4M3:
-        return array.view("float8_e4m3")
-
     return array
 
 

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 858ee089042ad..b11340f2c19ce 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -8,7 +8,7 @@
 from mlir.runtime import *
 
 try:
-    from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3
+    from ml_dtypes import bfloat16, float8_e5m2
 
     HAS_ML_DTYPES = True
 except ModuleNotFoundError:
@@ -623,90 +623,6 @@ def testF8E5M2Memref():
     log("TEST: testF8E5M2Memref")
 
 
-# Test f8E3M4 memrefs
-# CHECK-LABEL: TEST: testF8E3M4Memref
-def testF8E3M4Memref():
-    with Context():
-        module = Module.parse(
-            """
-    module  {
-      func.func @main(%arg0: memref<1xf8E3M4>,
-                      %arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } {
-        %0 = arith.constant 0 : index
-        %1 = memref.load %arg0[%0] : memref<1xf8E3M4>
-        memref.store %1, %arg1[%0] : memref<1xf8E3M4>
-        return
-      }
-    } """
-        )
-
-        arg1 = np.array([0.5]).astype(float8_e3m4)
-        arg2 = np.array([0.0]).astype(float8_e3m4)
-
-        arg1_memref_ptr = ctypes.pointer(
-            ctypes.pointer(get_ranked_memref_descriptor(arg1))
-        )
-        arg2_memref_ptr = ctypes.pointer(
-            ctypes.pointer(get_ranked_memref_descriptor(arg2))
-        )
-
-        execution_engine = ExecutionEngine(lowerToLLVM(module))
-        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
-
-        # test to-numpy utility
-        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        assert len(x) == 1
-        assert x[0] == 0.5
-
-
-if HAS_ML_DTYPES:
-    run(testF8E3M4Memref)
-else:
-    log("TEST: testF8E3M4Memref")
-
-
-# Test f8E4M3 memrefs
-# CHECK-LABEL: TEST: testF8E4M3Memref
-def testF8E4M3Memref():
-    with Context():
-        module = Module.parse(
-            """
-    module  {
-      func.func @main(%arg0: memref<1xf8E4M3>,
-                      %arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } {
-        %0 = arith.constant 0 : index
-        %1 = memref.load %arg0[%0] : memref<1xf8E4M3>
-        memref.store %1, %arg1[%0] : memref<1xf8E4M3>
-        return
-      }
-    } """
-        )
-
-        arg1 = np.array([0.5]).astype(float8_e4m3)
-        arg2 = np.array([0.0]).astype(float8_e4m3)
-
-        arg1_memref_ptr = ctypes.pointer(
-            ctypes.pointer(get_ranked_memref_descriptor(arg1))
-        )
-        arg2_memref_ptr = ctypes.pointer(
-            ctypes.pointer(get_ranked_memref_descriptor(arg2))
-        )
-
-        execution_engine = ExecutionEngine(lowerToLLVM(module))
-        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
-
-        # test to-numpy utility
-        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        assert len(x) == 1
-        assert x[0] == 0.5
-
-
-if HAS_ML_DTYPES:
-    run(testF8E4M3Memref)
-else:
-    log("TEST: testF8E4M3Memref")
-
-
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():


        


More information about the Mlir-commits mailing list