[Mlir-commits] [mlir] 89801c7 - [MLIR][Python] add ctype python binding support for bf16 (#92489)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 29 22:01:43 PDT 2024


Author: Bimo
Date: 2024-05-29T22:01:40-07:00
New Revision: 89801c74c3e25f5a1eaa3999863be398f6a82abb

URL: https://github.com/llvm/llvm-project/commit/89801c74c3e25f5a1eaa3999863be398f6a82abb
DIFF: https://github.com/llvm/llvm-project/commit/89801c74c3e25f5a1eaa3999863be398f6a82abb.diff

LOG: [MLIR][Python] add ctype python binding support for bf16 (#92489)

Since MLIR supports bf16, just as it supports
complex128/complex64/float16, the Python bindings need a ctypes
implementation of bf16. Furthermore, because NumPy has no native bf16
support, the third-party package [ml_dtypes](https://github.com/jax-ml/ml_dtypes)
is introduced to provide a bf16 extension; the same approach is used in
the `torch-mlir` project.

See motivation and discussion in:
https://discourse.llvm.org/t/how-to-run-executionengine-with-bf16-dtype-in-mlir-python-bindings/79025

Added: 
    

Modified: 
    mlir/python/mlir/runtime/np_to_memref.py
    mlir/python/requirements.txt
    mlir/test/python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index f6b706f9bc8ae..882b2751921bf 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -7,6 +7,12 @@
 import numpy as np
 import ctypes
 
+try:
+    import ml_dtypes
+except ModuleNotFoundError:
+    # ml_dtypes is an optional third-party package that adds low-precision data types (e.g. bfloat16) to NumPy.
+    ml_dtypes = None
+
 
 class C128(ctypes.Structure):
     """A ctype representation for MLIR's Double Complex."""
@@ -26,6 +32,12 @@ class F16(ctypes.Structure):
     _fields_ = [("f16", ctypes.c_int16)]
 
 
+class BF16(ctypes.Structure):
+    """A ctype representation for MLIR's BFloat16, held as its raw 16-bit pattern."""
+
+    _fields_ = [("bf16", ctypes.c_int16)]  # same 16-bit layout as F16; stores the bits only
+
+
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
     """Converts dtype to ctype."""
@@ -35,6 +47,8 @@ def as_ctype(dtp):
         return C64
     if dtp == np.dtype(np.float16):
         return F16
+    if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
+        return BF16
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -46,6 +60,11 @@ def to_numpy(array):
         return array.view("complex64")
     if array.dtype == F16:
         return array.view("float16")
+    assert not (
+        array.dtype == BF16 and ml_dtypes is None
+    ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == BF16:
+        return array.view("bfloat16")
     return array
 
 

diff  --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index acd6dbb25edaf..6ec63e43adf89 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,3 +1,4 @@
 numpy>=1.19.5, <=1.26
 pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
+ml_dtypes   # provides several NumPy dtype extensions, including bf16
\ No newline at end of file

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e8b47007a8907..8125bf3fb8fc9 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,6 +5,7 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
+from ml_dtypes import bfloat16
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -521,6 +522,45 @@ def testComplexUnrankedMemrefAdd():
 run(testComplexUnrankedMemrefAdd)
 
 
+# Test bf16 memrefs
+# CHECK-LABEL: TEST: testBF16Memref
+def testBF16Memref():  # @main copies arg0[0] into arg1[0]; checks a bf16 round trip through memrefs
+    with Context():
+        module = Module.parse(
+            """
+    module  {
+      func.func @main(%arg0: memref<1xbf16>,
+                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xbf16>
+        memref.store %1, %arg1[%0] : memref<1xbf16>
+        return
+      }
+    } """
+        )
+
+        arg1 = np.array([0.5]).astype(bfloat16)  # 0.5 is a power of two, so exact in bf16
+        arg2 = np.array([0.0]).astype(bfloat16)  # output buffer; @main overwrites element 0
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)  # arg2[0] = arg1[0]
+
+        # Convert the output memref back to NumPy and check the copied value.
+        # CHECK: [0.5]
+        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        log(npout)
+
+
+run(testBF16Memref)
+
+
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():


        


More information about the Mlir-commits mailing list