[Mlir-commits] [mlir] 58c8b25 - Replacing `is` with `==` for the dtype check.

Prashant Kumar llvmlistbot at llvm.org
Thu Dec 1 18:34:38 PST 2022


Author: Prashant Kumar
Date: 2022-12-02T02:34:31Z
New Revision: 58c8b253cdd5ccd4e61d5854a8a614b35498276f

URL: https://github.com/llvm/llvm-project/commit/58c8b253cdd5ccd4e61d5854a8a614b35498276f
DIFF: https://github.com/llvm/llvm-project/commit/58c8b253cdd5ccd4e61d5854a8a614b35498276f.diff

LOG: Replacing `is` with `==` for the dtype check.

>>> a = np.ndarray([1,1]).astype(np.half)
>>> a
array([[0.007812]], dtype=float16)
>>> a.dtype
dtype('float16')
>>> a.dtype == np.half
True
>>> a.dtype == np.float16
True
>>> a.dtype is np.float16
False

Checking with `is` tests object identity rather than dtype equality, which makes the check inconsistent; `==` compares the dtypes by value.
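The minimal sketch below reproduces the difference deterministically, assuming NumPy is installed. It uses `np.zeros` instead of `np.ndarray([1,1])`, whose contents are uninitialized. It also uses `np.dtype(..., copy=True)` to get a dtype object that is equal to the canonical float16 dtype but is not the same object.

```python
import numpy as np

# np.zeros gives defined contents; np.ndarray([1, 1]) leaves memory uninitialized.
a = np.zeros((1, 1), dtype=np.half)

# dtype.__eq__ converts its other operand to a dtype before comparing, so a
# scalar type, a dtype instance and a type string all compare equal.
eq_results = [
    a.dtype == np.half,
    a.dtype == np.float16,
    a.dtype == np.dtype(np.float16),
    a.dtype == "float16",
]

# `is` checks object identity only. A dtype instance is never the scalar type
# class itself, and a copied dtype is a distinct object even though it is equal.
copied = np.dtype(np.float16, copy=True)
is_results = [
    a.dtype is np.float16,
    copied is np.dtype(np.float16),
]

print(all(eq_results))  # True
print(any(is_results))  # False
print(copied == np.dtype(np.float16))  # True
```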

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D139121

Added: 
    

Modified: 
    mlir/python/mlir/runtime/np_to_memref.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 5b3c3c4ae322d..d70967983c45d 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -23,13 +23,14 @@ class F16(ctypes.Structure):
   _fields_ = [("f16", ctypes.c_int16)]
 
 
+# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
   """Converts dtype to ctype."""
-  if dtp is np.dtype(np.complex128):
+  if dtp == np.dtype(np.complex128):
     return C128
-  if dtp is np.dtype(np.complex64):
+  if dtp == np.dtype(np.complex64):
     return C64
-  if dtp is np.dtype(np.float16):
+  if dtp == np.dtype(np.float16):
     return F16
   return np.ctypeslib.as_ctypes_type(dtp)
 
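The sketch below exercises the patched helper in isolation. The `F16` structure and the body of `as_ctype` follow the diff. The `C128` and `C64` definitions are hypothetical stand-ins, because their real definitions in np_to_memref.py are not part of this hunk. The sketch passes a copied float16 dtype, which the `is`-based version would not have matched. That version would instead have fallen through to `np.ctypeslib.as_ctypes_type`, which has no ctypes counterpart for float16.

```python
import ctypes

import numpy as np


# Hypothetical stand-ins for the C128/C64 structs defined elsewhere in
# np_to_memref.py: a complex value laid out as two consecutive reals.
class C128(ctypes.Structure):
    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


class C64(ctypes.Structure):
    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]


class F16(ctypes.Structure):
    # As in the diff: half precision carried as a 16-bit integer payload.
    _fields_ = [("f16", ctypes.c_int16)]


def as_ctype(dtp):
    """Converts dtype to ctype, comparing dtypes by value rather than identity."""
    if dtp == np.dtype(np.complex128):
        return C128
    if dtp == np.dtype(np.complex64):
        return C64
    if dtp == np.dtype(np.float16):
        return F16
    return np.ctypeslib.as_ctypes_type(dtp)


# A float16 dtype that is equal to, but not the same object as, the canonical one.
half = np.dtype(np.float16, copy=True)
print(as_ctype(half) is F16)  # True
print(as_ctype(np.zeros(2, np.complex64).dtype) is C64)  # True
print(as_ctype(np.dtype(np.float32)))
```

`dtype.__eq__` converts its other operand to a dtype first. As a result, the `==` version also accepts a scalar type such as `np.float16` in place of a dtype object.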

