[Mlir-commits] [mlir] [MLIR][Python] add ctype python binding support for bf16 (PR #92489)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 20 21:35:39 PDT 2024


https://github.com/xurui1995 updated https://github.com/llvm/llvm-project/pull/92489

>From 8a36de1a5b080af97bc7af1ae431436ff965956c Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Thu, 16 May 2024 08:49:50 -0700
Subject: [PATCH 1/9] [MLIR][Python] add ctype python binding support for bf16

---
 mlir/python/mlir/runtime/np_to_memref.py | 10 +++++
 mlir/python/requirements.txt             |  3 +-
 mlir/test/python/execution_engine.py     | 51 ++++++++++++++++++++++++
 3 files changed, 63 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index f6b706f9bc8ae..55e6a6cc5ab3e 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import ctypes
+import ml_dtypes
 
 
 class C128(ctypes.Structure):
@@ -25,6 +26,11 @@ class F16(ctypes.Structure):
 
     _fields_ = [("f16", ctypes.c_int16)]
 
+class BF16(ctypes.Structure):
+    """A ctype representation for MLIR's BFloat16."""
+
+    _fields_ = [("bf16", ctypes.c_int16)]
+
 
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
@@ -35,6 +41,8 @@ def as_ctype(dtp):
         return C64
     if dtp == np.dtype(np.float16):
         return F16
+    if dtp == ml_dtypes.bfloat16:
+        return BF16
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -46,6 +54,8 @@ def to_numpy(array):
         return array.view("complex64")
     if array.dtype == F16:
         return array.view("float16")
+    if array.dtype == BF16:
+        return array.view("bfloat16")
     return array
 
 
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index acd6dbb25edaf..90acba8d65f09 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,3 +1,4 @@
 numpy>=1.19.5, <=1.26
 pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
+ml_dtypes
\ No newline at end of file
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e8b47007a8907..61d145ef24d95 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,6 +5,7 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
+from ml_dtypes import bfloat16
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -521,6 +522,56 @@ def testComplexUnrankedMemrefAdd():
 run(testComplexUnrankedMemrefAdd)
 
 
+# Test addition of two bf16 memrefs
+# CHECK-LABEL: TEST: testBF16MemrefAdd
+def testBF16MemrefAdd():
+    with Context():
+        module = Module.parse(
+            """
+    module  {
+      func.func @main(%arg0: memref<1xcomplex<bf16>>,
+                      %arg1: memref<1xcomplex<bf16>>,
+                      %arg2: memref<1xcomplex<bf16>>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xcomplex<bf16>>
+        %2 = memref.load %arg1[%0] : memref<1xcomplex<bf16>>
+        %3 = complex.add %1, %2 : complex<bf16>
+        memref.store %3, %arg2[%0] : memref<1xcomplex<bf16>>
+        return
+      }
+    } """
+        )
+
+        arg1 = np.array([11.0]).astype(bfloat16)
+        arg2 = np.array([12.0]).astype(bfloat16)
+        arg3 = np.array([0.0]).astype(bfloat16)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+        arg3_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg3))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+        )
+        # CHECK: [11.] + [22.] = [33.]
+        log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+        # test to-numpy utility
+        # CHECK: [33.]
+        npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+        log(npout)
+
+
+run(testBF16MemrefAdd)
+
+
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():

>From aac79a38cd9ff31c29825adec99ed76376116b51 Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Thu, 16 May 2024 21:21:16 -0700
Subject: [PATCH 2/9] [MLIR][Python] make ml_dtypes optional

---
 mlir/python/mlir/runtime/np_to_memref.py | 13 +++++++++++--
 mlir/python/requirements.txt             |  3 +--
 mlir/test/python/execution_engine.py     | 15 ++++++++++-----
 3 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 55e6a6cc5ab3e..882b2751921bf 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -6,7 +6,12 @@
 
 import numpy as np
 import ctypes
-import ml_dtypes
+
+try:
+    import ml_dtypes
+except ModuleNotFoundError:
+    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
+    ml_dtypes = None
 
 
 class C128(ctypes.Structure):
@@ -26,6 +31,7 @@ class F16(ctypes.Structure):
 
     _fields_ = [("f16", ctypes.c_int16)]
 
+
 class BF16(ctypes.Structure):
     """A ctype representation for MLIR's BFloat16."""
 
@@ -41,7 +47,7 @@ def as_ctype(dtp):
         return C64
     if dtp == np.dtype(np.float16):
         return F16
-    if dtp == ml_dtypes.bfloat16:
+    if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
         return BF16
     return np.ctypeslib.as_ctypes_type(dtp)
 
@@ -54,6 +60,9 @@ def to_numpy(array):
         return array.view("complex64")
     if array.dtype == F16:
         return array.view("float16")
+    assert not (
+        array.dtype == BF16 and ml_dtypes is None
+    ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == BF16:
         return array.view("bfloat16")
     return array
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 90acba8d65f09..acd6dbb25edaf 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,3 @@
 numpy>=1.19.5, <=1.26
 pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
-ml_dtypes
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 61d145ef24d95..d406e8f05b316 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,11 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
-from ml_dtypes import bfloat16
+
+try:
+    import ml_dtypes
+except ModuleNotFoundError:
+    ml_dtypes = None
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -542,9 +546,9 @@ def testBF16MemrefAdd():
     } """
         )
 
-        arg1 = np.array([11.0]).astype(bfloat16)
-        arg2 = np.array([12.0]).astype(bfloat16)
-        arg3 = np.array([0.0]).astype(bfloat16)
+        arg1 = np.array([11.0]).astype(ml_dtypes.bfloat16)
+        arg2 = np.array([12.0]).astype(ml_dtypes.bfloat16)
+        arg3 = np.array([0.0]).astype(ml_dtypes.bfloat16)
 
         arg1_memref_ptr = ctypes.pointer(
             ctypes.pointer(get_ranked_memref_descriptor(arg1))
@@ -569,7 +573,8 @@ def testBF16MemrefAdd():
         log(npout)
 
 
-run(testBF16MemrefAdd)
+if ml_dtypes is not None:
+    run(testBF16MemrefAdd)
 
 
 #  Test addition of two 2d_memref

>From 6a1900705692cbfb6e309d91929520fa48beb74c Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Thu, 16 May 2024 21:48:10 -0700
Subject: [PATCH 3/9] [MLIR][Python] fix filecheck

---
 mlir/test/python/execution_engine.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index d406e8f05b316..adc3e02dd3fce 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -529,6 +529,11 @@ def testComplexUnrankedMemrefAdd():
 # Test addition of two bf16 memrefs
 # CHECK-LABEL: TEST: testBF16MemrefAdd
 def testBF16MemrefAdd():
+    if ml_dtypes is None:
+        log(
+            "Skipping testBF16MemrefAdd because bfloat16 requires the ml_dtypes package."
+        )
+        return
     with Context():
         module = Module.parse(
             """
@@ -573,8 +578,7 @@ def testBF16MemrefAdd():
         log(npout)
 
 
-if ml_dtypes is not None:
-    run(testBF16MemrefAdd)
+run(testBF16MemrefAdd)
 
 
 #  Test addition of two 2d_memref

>From f66cf78601c4c18588d17e70f7f9effcc66f8f0d Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Thu, 16 May 2024 23:09:40 -0700
Subject: [PATCH 4/9] [MLIR][Python] fix filecheck when ml_dtypes is not
 installed

---
 mlir/test/python/execution_engine.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index adc3e02dd3fce..9afb81ab2266b 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -529,11 +529,6 @@ def testComplexUnrankedMemrefAdd():
 # Test addition of two bf16 memrefs
 # CHECK-LABEL: TEST: testBF16MemrefAdd
 def testBF16MemrefAdd():
-    if ml_dtypes is None:
-        log(
-            "Skipping testBF16MemrefAdd because bfloat16 requires the ml_dtypes package."
-        )
-        return
     with Context():
         module = Module.parse(
             """
@@ -552,7 +547,7 @@ def testBF16MemrefAdd():
         )
 
         arg1 = np.array([11.0]).astype(ml_dtypes.bfloat16)
-        arg2 = np.array([12.0]).astype(ml_dtypes.bfloat16)
+        arg2 = np.array([22.0]).astype(ml_dtypes.bfloat16)
         arg3 = np.array([0.0]).astype(ml_dtypes.bfloat16)
 
         arg1_memref_ptr = ctypes.pointer(
@@ -569,16 +564,22 @@ def testBF16MemrefAdd():
         execution_engine.invoke(
             "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
         )
-        # CHECK: [11.] + [22.] = [33.]
+        # CHECK: [11] + [22] = [33]
         log("{0} + {1} = {2}".format(arg1, arg2, arg3))
 
         # test to-numpy utility
-        # CHECK: [33.]
+        # CHECK: [33]
         npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
         log(npout)
 
 
-run(testBF16MemrefAdd)
+if ml_dtypes is None:
+    log("BF16 execution requires the ml_dtypes package, skipping the test..")
+    log(
+        "Skip BF16 execution check: \nTEST: testBF16MemrefAdd; \n[11] + [22] = [33]; \n[33]"
+    )
+else:
+    run(testBF16MemrefAdd)
 
 
 #  Test addition of two 2d_memref

>From f0fddd04a3cbee6cccf3d3bb3adc18fdbc33723c Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Sat, 18 May 2024 19:31:49 -0700
Subject: [PATCH 5/9] [MLIR][Python] add dep in requirements.txt

---
 mlir/python/requirements.txt         |  3 ++-
 mlir/test/python/execution_engine.py | 20 +++++---------------
 2 files changed, 7 insertions(+), 16 deletions(-)

diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index acd6dbb25edaf..6ec63e43adf89 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,3 +1,4 @@
 numpy>=1.19.5, <=1.26
 pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
+ml_dtypes   # provides several NumPy dtype extensions, including the bf16
\ No newline at end of file
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 9afb81ab2266b..63b70b600388f 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,11 +5,7 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
-
-try:
-    import ml_dtypes
-except ModuleNotFoundError:
-    ml_dtypes = None
+from ml_dtypes import bfloat16
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -546,9 +542,9 @@ def testBF16MemrefAdd():
     } """
         )
 
-        arg1 = np.array([11.0]).astype(ml_dtypes.bfloat16)
-        arg2 = np.array([22.0]).astype(ml_dtypes.bfloat16)
-        arg3 = np.array([0.0]).astype(ml_dtypes.bfloat16)
+        arg1 = np.array([11.0]).astype(bfloat16)
+        arg2 = np.array([22.0]).astype(bfloat16)
+        arg3 = np.array([0.0]).astype(bfloat16)
 
         arg1_memref_ptr = ctypes.pointer(
             ctypes.pointer(get_ranked_memref_descriptor(arg1))
@@ -573,13 +569,7 @@ def testBF16MemrefAdd():
         log(npout)
 
 
-if ml_dtypes is None:
-    log("BF16 execution requires the ml_dtypes package, skipping the test..")
-    log(
-        "Skip BF16 execution check: \nTEST: testBF16MemrefAdd; \n[11] + [22] = [33]; \n[33]"
-    )
-else:
-    run(testBF16MemrefAdd)
+run(testBF16MemrefAdd)
 
 
 #  Test addition of two 2d_memref

>From 1266bfb11dceb917bf3bed48ce9ca91753009586 Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Mon, 20 May 2024 18:56:02 -0700
Subject: [PATCH 6/9] [MLIR][Python] fix test case

---
 mlir/test/python/execution_engine.py | 30 +++++++++-------------------
 1 file changed, 9 insertions(+), 21 deletions(-)

diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 63b70b600388f..947875ca6f34b 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -521,7 +521,6 @@ def testComplexUnrankedMemrefAdd():
 
 run(testComplexUnrankedMemrefAdd)
 
-
 # Test addition of two bf16 memrefs
 # CHECK-LABEL: TEST: testBF16MemrefAdd
 def testBF16MemrefAdd():
@@ -529,22 +528,18 @@ def testBF16MemrefAdd():
         module = Module.parse(
             """
     module  {
-      func.func @main(%arg0: memref<1xcomplex<bf16>>,
-                      %arg1: memref<1xcomplex<bf16>>,
-                      %arg2: memref<1xcomplex<bf16>>) attributes { llvm.emit_c_interface } {
+      func.func @main(%arg0: memref<1xbf16>,
+                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
         %0 = arith.constant 0 : index
-        %1 = memref.load %arg0[%0] : memref<1xcomplex<bf16>>
-        %2 = memref.load %arg1[%0] : memref<1xcomplex<bf16>>
-        %3 = complex.add %1, %2 : complex<bf16>
-        memref.store %3, %arg2[%0] : memref<1xcomplex<bf16>>
+        %1 = memref.load %arg0[%0] : memref<1xbf16>
+        memref.store %1, %arg1[%0] : memref<1xbf16>
         return
       }
     } """
         )
 
-        arg1 = np.array([11.0]).astype(bfloat16)
-        arg2 = np.array([22.0]).astype(bfloat16)
-        arg3 = np.array([0.0]).astype(bfloat16)
+        arg1 = np.array([0.5]).astype(bfloat16)
+        arg2 = np.array([0.0]).astype(bfloat16)
 
         arg1_memref_ptr = ctypes.pointer(
             ctypes.pointer(get_ranked_memref_descriptor(arg1))
@@ -552,20 +547,13 @@ def testBF16MemrefAdd():
         arg2_memref_ptr = ctypes.pointer(
             ctypes.pointer(get_ranked_memref_descriptor(arg2))
         )
-        arg3_memref_ptr = ctypes.pointer(
-            ctypes.pointer(get_ranked_memref_descriptor(arg3))
-        )
 
         execution_engine = ExecutionEngine(lowerToLLVM(module))
-        execution_engine.invoke(
-            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
-        )
-        # CHECK: [11] + [22] = [33]
-        log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
 
         # test to-numpy utility
-        # CHECK: [33]
-        npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+        # CHECK: [0.5]
+        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
         log(npout)
 
 

>From cf75f2968b6770d938df319f39683d4c0fd01ac1 Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Mon, 20 May 2024 21:08:55 -0700
Subject: [PATCH 7/9] [MLIR][Python] rename test case

---
 mlir/test/python/execution_engine.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 947875ca6f34b..d6e9008dc924a 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -521,9 +521,10 @@ def testComplexUnrankedMemrefAdd():
 
 run(testComplexUnrankedMemrefAdd)
 
+
 # Test addition of two bf16 memrefs
 # CHECK-LABEL: TEST: testBF16MemrefAdd
-def testBF16MemrefAdd():
+def testBF16Memref():
     with Context():
         module = Module.parse(
             """
@@ -557,7 +558,7 @@ def testBF16MemrefAdd():
         log(npout)
 
 
-run(testBF16MemrefAdd)
+run(testBF16Memref)
 
 
 #  Test addition of two 2d_memref

>From 909b8963b3115fed0aa7d12ab98697f9d9f202f8 Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Mon, 20 May 2024 21:12:01 -0700
Subject: [PATCH 8/9] [MLIR][Python] fix comment

---
 mlir/test/python/execution_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index d6e9008dc924a..d1f0f836d712e 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -522,7 +522,7 @@ def testComplexUnrankedMemrefAdd():
 run(testComplexUnrankedMemrefAdd)
 
 
-# Test addition of two bf16 memrefs
+# Test bf16 memrefs
 # CHECK-LABEL: TEST: testBF16MemrefAdd
 def testBF16Memref():
     with Context():

>From 08bce72071f94e5f04f78d38f62a64bf042eab79 Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Mon, 20 May 2024 21:35:14 -0700
Subject: [PATCH 9/9] [MLIR][Python] fix check label

---
 mlir/test/python/execution_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index d1f0f836d712e..8125bf3fb8fc9 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -523,7 +523,7 @@ def testComplexUnrankedMemrefAdd():
 
 
 # Test bf16 memrefs
-# CHECK-LABEL: TEST: testBF16MemrefAdd
+# CHECK-LABEL: TEST: testBF16Memref
 def testBF16Memref():
     with Context():
         module = Module.parse(



More information about the Mlir-commits mailing list