[llvm-branch-commits] [mlir] 52ef986 - Revert "[MLIR][Python] add ctype python binding support for bf16 (#92489)"

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed May 29 22:20:38 PDT 2024


Author: Mehdi Amini
Date: 2024-05-29T23:20:35-06:00
New Revision: 52ef9864abecea0cf8d20e7eaf49c256248af5f7

URL: https://github.com/llvm/llvm-project/commit/52ef9864abecea0cf8d20e7eaf49c256248af5f7
DIFF: https://github.com/llvm/llvm-project/commit/52ef9864abecea0cf8d20e7eaf49c256248af5f7.diff

LOG: Revert "[MLIR][Python] add ctype python binding support for bf16 (#92489)"

This reverts commit 89801c74c3e25f5a1eaa3999863be398f6a82abb.

Added: 
    

Modified: 
    mlir/python/mlir/runtime/np_to_memref.py
    mlir/python/requirements.txt
    mlir/test/python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 882b2751921bf..f6b706f9bc8ae 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -7,12 +7,6 @@
 import numpy as np
 import ctypes
 
-try:
-    import ml_dtypes
-except ModuleNotFoundError:
-    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
-    ml_dtypes = None
-
 
 class C128(ctypes.Structure):
     """A ctype representation for MLIR's Double Complex."""
@@ -32,12 +26,6 @@ class F16(ctypes.Structure):
     _fields_ = [("f16", ctypes.c_int16)]
 
 
-class BF16(ctypes.Structure):
-    """A ctype representation for MLIR's BFloat16."""
-
-    _fields_ = [("bf16", ctypes.c_int16)]
-
-
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
     """Converts dtype to ctype."""
@@ -47,8 +35,6 @@ def as_ctype(dtp):
         return C64
     if dtp == np.dtype(np.float16):
         return F16
-    if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
-        return BF16
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -60,11 +46,6 @@ def to_numpy(array):
         return array.view("complex64")
     if array.dtype == F16:
         return array.view("float16")
-    assert not (
-        array.dtype == BF16 and ml_dtypes is None
-    ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
-    if array.dtype == BF16:
-        return array.view("bfloat16")
     return array
 
 

diff  --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 6ec63e43adf89..acd6dbb25edaf 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,3 @@
 numpy>=1.19.5, <=1.26
 pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
-ml_dtypes   # provides several NumPy dtype extensions, including the bf16
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 8125bf3fb8fc9..e8b47007a8907 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,6 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
-from ml_dtypes import bfloat16
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -522,45 +521,6 @@ def testComplexUnrankedMemrefAdd():
 run(testComplexUnrankedMemrefAdd)
 
 
-# Test bf16 memrefs
-# CHECK-LABEL: TEST: testBF16Memref
-def testBF16Memref():
-    with Context():
-        module = Module.parse(
-            """
-    module  {
-      func.func @main(%arg0: memref<1xbf16>,
-                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
-        %0 = arith.constant 0 : index
-        %1 = memref.load %arg0[%0] : memref<1xbf16>
-        memref.store %1, %arg1[%0] : memref<1xbf16>
-        return
-      }
-    } """
-        )
-
-        arg1 = np.array([0.5]).astype(bfloat16)
-        arg2 = np.array([0.0]).astype(bfloat16)
-
-        arg1_memref_ptr = ctypes.pointer(
-            ctypes.pointer(get_ranked_memref_descriptor(arg1))
-        )
-        arg2_memref_ptr = ctypes.pointer(
-            ctypes.pointer(get_ranked_memref_descriptor(arg2))
-        )
-
-        execution_engine = ExecutionEngine(lowerToLLVM(module))
-        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
-
-        # test to-numpy utility
-        # CHECK: [0.5]
-        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        log(npout)
-
-
-run(testBF16Memref)
-
-
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():


        


More information about the llvm-branch-commits mailing list