[Mlir-commits] [mlir] [MLIR][Python] add f8E5M2 and tests for np_to_memref (PR #106028)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 26 02:22:19 PDT 2024
https://github.com/PhrygianGates updated https://github.com/llvm/llvm-project/pull/106028
>From bdbcbaa8372105e38370b1dae842d45a658b9213 Mon Sep 17 00:00:00 2001
From: Zhicheng Xiong <zhichengx at dc2-sim-c01-215.nvidia.com>
Date: Sun, 25 Aug 2024 19:31:36 -0700
Subject: [PATCH 1/2] add f8E5M2 and tests
---
mlir/python/mlir/runtime/np_to_memref.py | 12 +++++++
mlir/test/python/execution_engine.py | 40 +++++++++++++++++++++++-
2 files changed, 51 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 882b2751921bfd..8cca1e7ad4a9eb 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -37,6 +37,11 @@ class BF16(ctypes.Structure):
     _fields_ = [("bf16", ctypes.c_int16)]
 
 
+class F8E5M2(ctypes.Structure):
+    """A ctype representation for MLIR's Float8E5M2."""
+
+    _fields_ = [("f8E5M2", ctypes.c_int8)]
+
 
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
@@ -49,6 +54,8 @@ def as_ctype(dtp):
         return F16
     if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
         return BF16
+    if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
+        return F8E5M2
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -65,6 +72,11 @@ def to_numpy(array):
     ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == BF16:
         return array.view("bfloat16")
+    assert not (
+        array.dtype == F8E5M2 and ml_dtypes is None
+    ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == F8E5M2:
+        return array.view("float8_e5m2")
     return array
 
 
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 8125bf3fb8fc92..ae0c691e1cbcd9 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,7 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
-from ml_dtypes import bfloat16
+from ml_dtypes import bfloat16, float8_e5m2
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -560,6 +560,44 @@ def testBF16Memref():
 
 run(testBF16Memref)
 
+# Test f8E5M2 memrefs
+# CHECK-LABEL: TEST: testF8E5M2Memref
+def testF8E5M2Memref():
+    with Context():
+        module = Module.parse(
+            """
+    module {
+      func.func @main(%arg0: memref<1xf8E5M2>,
+                      %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
+        memref.store %1, %arg1[%0] : memref<1xf8E5M2>
+        return
+      }
+    } """
+        )
+
+        arg1 = np.array([0.5]).astype(float8_e5m2)
+        arg2 = np.array([0.0]).astype(float8_e5m2)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
+
+        # test to-numpy utility
+        # CHECK: [0.5]
+        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        log(npout)
+
+
+run(testF8E5M2Memref)
+
 
 # Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
>From b9cd8b63308ed41fb05c1007082efb5270e0e4ab Mon Sep 17 00:00:00 2001
From: Zhicheng Xiong <zhichengx at dc2-sim-c01-215.nvidia.com>
Date: Mon, 26 Aug 2024 02:21:56 -0700
Subject: [PATCH 2/2] format
---
mlir/test/python/execution_engine.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index ae0c691e1cbcd9..1cdda63eefe300 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -560,6 +560,7 @@ def testBF16Memref():
 
 run(testBF16Memref)
 
+
 # Test f8E5M2 memrefs
 # CHECK-LABEL: TEST: testF8E5M2Memref
 def testF8E5M2Memref():
More information about the Mlir-commits mailing list