[Mlir-commits] [mlir] c8cac33 - [MLIR][Python] add f8E5M2 and tests for np_to_memref (#106028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 26 19:40:41 PDT 2024


Author: PhrygianGates
Date: 2024-08-26T21:40:38-05:00
New Revision: c8cac33ad23acc671a0a7390a5254b9f6e848138

URL: https://github.com/llvm/llvm-project/commit/c8cac33ad23acc671a0a7390a5254b9f6e848138
DIFF: https://github.com/llvm/llvm-project/commit/c8cac33ad23acc671a0a7390a5254b9f6e848138.diff

LOG: [MLIR][Python] add f8E5M2 and tests for np_to_memref (#106028)

add f8E5M2 and tests for np_to_memref

---------

Co-authored-by: Zhicheng Xiong <zhichengx at dc2-sim-c01-215.nvidia.com>

Added: 
    

Modified: 
    mlir/python/mlir/runtime/np_to_memref.py
    mlir/test/python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 882b2751921bfd..8cca1e7ad4a9eb 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -37,6 +37,11 @@ class BF16(ctypes.Structure):
 
     _fields_ = [("bf16", ctypes.c_int16)]
 
+class F8E5M2(ctypes.Structure):
+    """A ctype representation for MLIR's Float8E5M2."""
+
+    _fields_ = [("f8E5M2", ctypes.c_int8)]
+
 
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
@@ -49,6 +54,8 @@ def as_ctype(dtp):
         return F16
     if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
         return BF16
+    if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
+        return F8E5M2
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -65,6 +72,11 @@ def to_numpy(array):
     ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == BF16:
         return array.view("bfloat16")
+    assert not (
+        array.dtype == F8E5M2 and ml_dtypes is None
+    ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == F8E5M2:
+        return array.view("float8_e5m2")
     return array
 
 

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 8125bf3fb8fc92..1cdda63eefe300 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,7 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
-from ml_dtypes import bfloat16
+from ml_dtypes import bfloat16, float8_e5m2
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -561,6 +561,45 @@ def testBF16Memref():
 run(testBF16Memref)
 
 
+# Test f8E5M2 memrefs
+# CHECK-LABEL: TEST: testF8E5M2Memref
+def testF8E5M2Memref():
+    with Context():
+        module = Module.parse(
+            """
+    module  {
+      func.func @main(%arg0: memref<1xf8E5M2>,
+                      %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
+        memref.store %1, %arg1[%0] : memref<1xf8E5M2>
+        return
+      }
+    } """
+        )
+
+        arg1 = np.array([0.5]).astype(float8_e5m2)
+        arg2 = np.array([0.0]).astype(float8_e5m2)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
+
+        # test to-numpy utility
+        # CHECK: [0.5]
+        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        log(npout)
+
+
+run(testF8E5M2Memref)
+
+
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():


        


More information about the Mlir-commits mailing list