[Mlir-commits] [mlir] [MLIR][Python] add ctype python binding support for bf16 (PR #92489)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 16 21:48:30 PDT 2024
https://github.com/xurui1995 updated https://github.com/llvm/llvm-project/pull/92489
From 8a36de1a5b080af97bc7af1ae431436ff965956c Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Thu, 16 May 2024 08:49:50 -0700
Subject: [PATCH 1/3] [MLIR][Python] add ctype python binding support for bf16
---
mlir/python/mlir/runtime/np_to_memref.py | 10 +++++
mlir/python/requirements.txt | 3 +-
mlir/test/python/execution_engine.py | 51 ++++++++++++++++++++++++
3 files changed, 63 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index f6b706f9bc8ae..55e6a6cc5ab3e 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -6,6 +6,7 @@
import numpy as np
import ctypes
+import ml_dtypes
class C128(ctypes.Structure):
@@ -25,6 +26,11 @@ class F16(ctypes.Structure):
_fields_ = [("f16", ctypes.c_int16)]
+class BF16(ctypes.Structure):
+ """A ctype representation for MLIR's BFloat16."""
+
+ _fields_ = [("bf16", ctypes.c_int16)]
+
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
@@ -35,6 +41,8 @@ def as_ctype(dtp):
return C64
if dtp == np.dtype(np.float16):
return F16
+ if dtp == ml_dtypes.bfloat16:
+ return BF16
return np.ctypeslib.as_ctypes_type(dtp)
@@ -46,6 +54,8 @@ def to_numpy(array):
return array.view("complex64")
if array.dtype == F16:
return array.view("float16")
+ if array.dtype == BF16:
+ return array.view("bfloat16")
return array
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index acd6dbb25edaf..90acba8d65f09 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,3 +1,4 @@
numpy>=1.19.5, <=1.26
pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
+ml_dtypes
\ No newline at end of file
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e8b47007a8907..61d145ef24d95 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,6 +5,7 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
+from ml_dtypes import bfloat16
# Log everything to stderr and flush so that we have a unified stream to match
@@ -521,6 +522,56 @@ def testComplexUnrankedMemrefAdd():
run(testComplexUnrankedMemrefAdd)
+# Test addition of two bf16 memrefs
+# CHECK-LABEL: TEST: testBF16MemrefAdd
+def testBF16MemrefAdd():
+ with Context():
+ module = Module.parse(
+ """
+ module {
+ func.func @main(%arg0: memref<1xcomplex<bf16>>,
+ %arg1: memref<1xcomplex<bf16>>,
+ %arg2: memref<1xcomplex<bf16>>) attributes { llvm.emit_c_interface } {
+ %0 = arith.constant 0 : index
+ %1 = memref.load %arg0[%0] : memref<1xcomplex<bf16>>
+ %2 = memref.load %arg1[%0] : memref<1xcomplex<bf16>>
+ %3 = complex.add %1, %2 : complex<bf16>
+ memref.store %3, %arg2[%0] : memref<1xcomplex<bf16>>
+ return
+ }
+ } """
+ )
+
+ arg1 = np.array([11.0]).astype(bfloat16)
+ arg2 = np.array([12.0]).astype(bfloat16)
+ arg3 = np.array([0.0]).astype(bfloat16)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1))
+ )
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2))
+ )
+ arg3_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg3))
+ )
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke(
+ "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+ )
+ # CHECK: [11.] + [22.] = [33.]
+ log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+ # test to-numpy utility
+ # CHECK: [33.]
+ npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+ log(npout)
+
+
+run(testBF16MemrefAdd)
+
+
# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
From aac79a38cd9ff31c29825adec99ed76376116b51 Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Thu, 16 May 2024 21:21:16 -0700
Subject: [PATCH 2/3] [MLIR][Python] make ml_dtypes optional
---
mlir/python/mlir/runtime/np_to_memref.py | 13 +++++++++++--
mlir/python/requirements.txt | 3 +--
mlir/test/python/execution_engine.py | 15 ++++++++++-----
3 files changed, 22 insertions(+), 9 deletions(-)
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 55e6a6cc5ab3e..882b2751921bf 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -6,7 +6,12 @@
import numpy as np
import ctypes
-import ml_dtypes
+
+try:
+ import ml_dtypes
+except ModuleNotFoundError:
+ # ml_dtypes is an optional third-party package that provides low-precision NumPy data types such as bfloat16.
+ ml_dtypes = None
class C128(ctypes.Structure):
@@ -26,6 +31,7 @@ class F16(ctypes.Structure):
_fields_ = [("f16", ctypes.c_int16)]
+
class BF16(ctypes.Structure):
"""A ctype representation for MLIR's BFloat16."""
@@ -41,7 +47,7 @@ def as_ctype(dtp):
return C64
if dtp == np.dtype(np.float16):
return F16
- if dtp == ml_dtypes.bfloat16:
+ if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
return BF16
return np.ctypeslib.as_ctypes_type(dtp)
@@ -54,6 +60,9 @@ def to_numpy(array):
return array.view("complex64")
if array.dtype == F16:
return array.view("float16")
+ assert not (
+ array.dtype == BF16 and ml_dtypes is None
+ ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
if array.dtype == BF16:
return array.view("bfloat16")
return array
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 90acba8d65f09..acd6dbb25edaf 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,3 @@
numpy>=1.19.5, <=1.26
pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
-ml_dtypes
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 61d145ef24d95..d406e8f05b316 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,11 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
-from ml_dtypes import bfloat16
+
+try:
+ import ml_dtypes
+except ModuleNotFoundError:
+ ml_dtypes = None
# Log everything to stderr and flush so that we have a unified stream to match
@@ -542,9 +546,9 @@ def testBF16MemrefAdd():
} """
)
- arg1 = np.array([11.0]).astype(bfloat16)
- arg2 = np.array([12.0]).astype(bfloat16)
- arg3 = np.array([0.0]).astype(bfloat16)
+ arg1 = np.array([11.0]).astype(ml_dtypes.bfloat16)
+ arg2 = np.array([12.0]).astype(ml_dtypes.bfloat16)
+ arg3 = np.array([0.0]).astype(ml_dtypes.bfloat16)
arg1_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg1))
@@ -569,7 +573,8 @@ def testBF16MemrefAdd():
log(npout)
-run(testBF16MemrefAdd)
+if ml_dtypes is not None:
+ run(testBF16MemrefAdd)
# Test addition of two 2d_memref
From 6a1900705692cbfb6e309d91929520fa48beb74c Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Thu, 16 May 2024 21:48:10 -0700
Subject: [PATCH 3/3] [MLIR][Python] fix filecheck
---
mlir/test/python/execution_engine.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index d406e8f05b316..adc3e02dd3fce 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -529,6 +529,11 @@ def testComplexUnrankedMemrefAdd():
# Test addition of two bf16 memrefs
# CHECK-LABEL: TEST: testBF16MemrefAdd
def testBF16MemrefAdd():
+ if ml_dtypes is None:
+ log(
+ "Skipping testBF16MemrefAdd because bfloat16 requires the ml_dtypes package."
+ )
+ return
with Context():
module = Module.parse(
"""
@@ -573,8 +578,7 @@ def testBF16MemrefAdd():
log(npout)
-if ml_dtypes is not None:
- run(testBF16MemrefAdd)
+run(testBF16MemrefAdd)
# Test addition of two 2d_memref
More information about the Mlir-commits mailing list