[llvm-branch-commits] [mlir] 52ef986 - Revert "[MLIR][Python] add ctype python binding support for bf16 (#92489)"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed May 29 22:20:38 PDT 2024
Author: Mehdi Amini
Date: 2024-05-29T23:20:35-06:00
New Revision: 52ef9864abecea0cf8d20e7eaf49c256248af5f7
URL: https://github.com/llvm/llvm-project/commit/52ef9864abecea0cf8d20e7eaf49c256248af5f7
DIFF: https://github.com/llvm/llvm-project/commit/52ef9864abecea0cf8d20e7eaf49c256248af5f7.diff
LOG: Revert "[MLIR][Python] add ctype python binding support for bf16 (#92489)"
This reverts commit 89801c74c3e25f5a1eaa3999863be398f6a82abb.
Added:
Modified:
mlir/python/mlir/runtime/np_to_memref.py
mlir/python/requirements.txt
mlir/test/python/execution_engine.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 882b2751921bf..f6b706f9bc8ae 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -7,12 +7,6 @@
import numpy as np
import ctypes
-try:
- import ml_dtypes
-except ModuleNotFoundError:
- # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
- ml_dtypes = None
-
class C128(ctypes.Structure):
"""A ctype representation for MLIR's Double Complex."""
@@ -32,12 +26,6 @@ class F16(ctypes.Structure):
_fields_ = [("f16", ctypes.c_int16)]
-class BF16(ctypes.Structure):
- """A ctype representation for MLIR's BFloat16."""
-
- _fields_ = [("bf16", ctypes.c_int16)]
-
-
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
"""Converts dtype to ctype."""
@@ -47,8 +35,6 @@ def as_ctype(dtp):
return C64
if dtp == np.dtype(np.float16):
return F16
- if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
- return BF16
return np.ctypeslib.as_ctypes_type(dtp)
@@ -60,11 +46,6 @@ def to_numpy(array):
return array.view("complex64")
if array.dtype == F16:
return array.view("float16")
- assert not (
- array.dtype == BF16 and ml_dtypes is None
- ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
- if array.dtype == BF16:
- return array.view("bfloat16")
return array
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 6ec63e43adf89..acd6dbb25edaf 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,3 @@
numpy>=1.19.5, <=1.26
pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
-ml_dtypes # provides several NumPy dtype extensions, including the bf16
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 8125bf3fb8fc9..e8b47007a8907 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,6 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
-from ml_dtypes import bfloat16
# Log everything to stderr and flush so that we have a unified stream to match
@@ -522,45 +521,6 @@ def testComplexUnrankedMemrefAdd():
run(testComplexUnrankedMemrefAdd)
-# Test bf16 memrefs
-# CHECK-LABEL: TEST: testBF16Memref
-def testBF16Memref():
- with Context():
- module = Module.parse(
- """
- module {
- func.func @main(%arg0: memref<1xbf16>,
- %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
- %0 = arith.constant 0 : index
- %1 = memref.load %arg0[%0] : memref<1xbf16>
- memref.store %1, %arg1[%0] : memref<1xbf16>
- return
- }
- } """
- )
-
- arg1 = np.array([0.5]).astype(bfloat16)
- arg2 = np.array([0.0]).astype(bfloat16)
-
- arg1_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg1))
- )
- arg2_memref_ptr = ctypes.pointer(
- ctypes.pointer(get_ranked_memref_descriptor(arg2))
- )
-
- execution_engine = ExecutionEngine(lowerToLLVM(module))
- execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
-
- # test to-numpy utility
- # CHECK: [0.5]
- npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
- log(npout)
-
-
-run(testBF16Memref)
-
-
# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
More information about the llvm-branch-commits mailing list