[Mlir-commits] [mlir] 58c8b25 - Replacing `is` with `==` for the dtype check.
Prashant Kumar
llvmlistbot at llvm.org
Thu Dec 1 18:34:38 PST 2022
Author: Prashant Kumar
Date: 2022-12-02T02:34:31Z
New Revision: 58c8b253cdd5ccd4e61d5854a8a614b35498276f
URL: https://github.com/llvm/llvm-project/commit/58c8b253cdd5ccd4e61d5854a8a614b35498276f
DIFF: https://github.com/llvm/llvm-project/commit/58c8b253cdd5ccd4e61d5854a8a614b35498276f.diff
LOG: Replacing `is` with `==` for the dtype check.
>>> a = np.ndarray([1,1]).astype(np.half)
>>> a
array([[0.007812]], dtype=float16)
>>> a.dtype
dtype('float16')
>>> a.dtype == np.half
True
>>> a.dtype == np.float16
True
>>> a.dtype is np.float16
False
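Checking with `is` compares object identity rather than dtype equality, so the result depends on which object happens to represent the dtype and the check behaves inconsistently; `==` compares the dtypes themselves.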
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D139121
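The following small, self-contained check shows the difference described in the log message. It assumes only NumPy. The `copy=True` argument of `np.dtype` is used here to get a dtype that is equal to the canonical `float16` dtype but is not the same object, which is exactly the case an identity check misses. The comparison against the scalar type `np.float16` mirrors the REPL session above.

```python
import numpy as np

# The canonical float16 dtype object.
canonical = np.dtype(np.float16)

# copy=True is documented to build a new dtype object: same type, different
# Python object.
copied = np.dtype(np.float16, copy=True)

# np.zeros avoids the uninitialized memory that np.ndarray(...) would expose.
arr = np.zeros((1, 1), dtype=np.half)

# Identity: both are False, even though every pair describes float16.
identity_checks = {
    "copied is canonical": copied is canonical,
    "arr.dtype is np.float16": arr.dtype is np.float16,
}

# Equality: a dtype compares equal to another equal dtype and to anything
# NumPy can convert to that dtype (scalar type, type name string).
equality_checks = {
    "copied == canonical": copied == canonical,
    "arr.dtype == np.float16": arr.dtype == np.float16,
    "arr.dtype == 'float16'": arr.dtype == "float16",
}

for name, result in {**identity_checks, **equality_checks}.items():
    print(f"{name}: {result}")
```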
Added:
Modified:
mlir/python/mlir/runtime/np_to_memref.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 5b3c3c4ae322d..d70967983c45d 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -23,13 +23,14 @@ class F16(ctypes.Structure):
_fields_ = [("f16", ctypes.c_int16)]
+# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
"""Converts dtype to ctype."""
- if dtp is np.dtype(np.complex128):
+ if dtp == np.dtype(np.complex128):
return C128
- if dtp is np.dtype(np.complex64):
+ if dtp == np.dtype(np.complex64):
return C64
- if dtp is np.dtype(np.float16):
+ if dtp == np.dtype(np.float16):
return F16
return np.ctypeslib.as_ctypes_type(dtp)
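The patched function can be exercised on its own with a sketch like the following. `F16` matches the diff context above. `C128` and `C64` are stand-ins whose field layouts are assumed here for illustration. The table-driven loop is a restructuring of the three `if` statements for this example, not the upstream code.

```python
import ctypes

import numpy as np


# Assumed layouts: stand-ins for the classes defined in np_to_memref.py.
class C128(ctypes.Structure):
    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


class C64(ctypes.Structure):
    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]


# Same field layout as F16 in the diff context.
class F16(ctypes.Structure):
    _fields_ = [("f16", ctypes.c_int16)]


# Special cases, matched by value (==) rather than identity (is).
_SPECIAL_CTYPES = (
    (np.dtype(np.complex128), C128),
    (np.dtype(np.complex64), C64),
    (np.dtype(np.float16), F16),
)


def as_ctype(dtp):
    """Converts a NumPy dtype to a ctypes type (sketch)."""
    for special_dtype, ctype in _SPECIAL_CTYPES:
        if dtp == special_dtype:
            return ctype
    # Everything else is delegated to NumPy's own converter.
    return np.ctypeslib.as_ctypes_type(dtp)
```

With equality, both a separately constructed dtype such as `np.dtype(np.float16, copy=True)` and the scalar type `np.float16` resolve to `F16`. An `is` check against `np.dtype(np.float16)` would miss both.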