[Mlir-commits] [mlir] a04c0b7 - [mlir][python] Fix MemRefType IsAFunction in Python bindings

Alex Zinenko llvmlistbot at llvm.org
Thu Oct 14 04:12:44 PDT 2021


Author: Alex Zinenko
Date: 2021-10-14T13:12:37+02:00
New Revision: a04c0b7ed2f92456558af2833f64cd494d161905

URL: https://github.com/llvm/llvm-project/commit/a04c0b7ed2f92456558af2833f64cd494d161905
DIFF: https://github.com/llvm/llvm-project/commit/a04c0b7ed2f92456558af2833f64cd494d161905.diff

LOG: [mlir][python] Fix MemRefType IsAFunction in Python bindings

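MemRefType was using the wrong `isa` function in the bindings code
(`mlirTypeIsARankedTensor` instead of `mlirTypeIsAMemRef`), which could
lead to invalid IR being constructed. The custom builder in
`_memref_ops_ext.py` now uses the memref's element type as the result
type instead of the memref type itself. Also run the verifier in the
memref dialect tests.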

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D111784
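
Illustration (not part of the commit): a minimal, self-contained toy model
of why the wrong `isa` predicate matters. Every name below (`ToyType`,
`ConcreteType`, `BuggyMemRefType`, `FixedMemRefType`, `load_result_type`)
is hypothetical and only mimics the pattern visible in the diff; none of
it is the MLIR Python API. It assumes the concrete type wrapper refuses to
wrap a type that fails its `isa` check.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyType:
    """Hypothetical stand-in for a generic IR type."""
    kind: str          # e.g. "memref" or "ranked_tensor"
    element_type: str  # e.g. "f32"


def is_a_memref(t: ToyType) -> bool:
    return t.kind == "memref"


def is_a_ranked_tensor(t: ToyType) -> bool:
    return t.kind == "ranked_tensor"


class ConcreteType:
    """Wraps a ToyType only if the subclass's isa predicate accepts it."""
    isa_function = staticmethod(lambda t: True)

    def __init__(self, t: ToyType):
        if not type(self).isa_function(t):
            raise ValueError(f"cannot cast {t.kind} to {type(self).__name__}")
        self._type = t

    @property
    def element_type(self) -> str:
        return self._type.element_type


class BuggyMemRefType(ConcreteType):
    # Models the old state: the memref wrapper checks for a ranked tensor.
    isa_function = staticmethod(is_a_ranked_tensor)


class FixedMemRefType(ConcreteType):
    # Models the fixed state: the memref wrapper checks for a memref.
    isa_function = staticmethod(is_a_memref)


def load_result_type(memref_type: ToyType) -> str:
    """A load yields one element, so its result is the element type."""
    return FixedMemRefType(memref_type).element_type
```

In this model, `BuggyMemRefType` rejects an actual memref type and accepts
a ranked tensor type, while `FixedMemRefType` does the opposite, and
`load_result_type` returns the element type rather than the memref type.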

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/python/mlir/dialects/_memref_ops_ext.py
    mlir/test/python/dialects/memref.py

Removed: 
    


################################################################################
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 568cca160a595..fd9f3efe7405f 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -406,7 +406,7 @@ class PyMemRefLayoutMapList;
 /// Ranked MemRef Type subclass - MemRefType.
 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
   static constexpr const char *pyClassName = "MemRefType";
   using PyConcreteType::PyConcreteType;
 

diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
index cb25ef105d73f..9cc22a21c6283 100644
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -33,5 +33,5 @@ def __init__(self,
     memref_resolved = _get_op_result_or_value(memref)
     indices_resolved = [] if indices is None else _get_op_results_or_values(
         indices)
-    return_type = memref_resolved.type
+    return_type = MemRefType(memref_resolved.type).element_type
     super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)

diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index e421f9b2fde95..f2eda0a620610 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -71,3 +71,4 @@ def testCustomBuidlers():
     # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
     # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
     print(module)
+    assert module.operation.verify()


        

