[Mlir-commits] [mlir] a04c0b7 - [mlir][python] Fix MemRefType IsAFunction in Python bindings
Alex Zinenko
llvmlistbot at llvm.org
Thu Oct 14 04:12:44 PDT 2021
Author: Alex Zinenko
Date: 2021-10-14T13:12:37+02:00
New Revision: a04c0b7ed2f92456558af2833f64cd494d161905
URL: https://github.com/llvm/llvm-project/commit/a04c0b7ed2f92456558af2833f64cd494d161905
DIFF: https://github.com/llvm/llvm-project/commit/a04c0b7ed2f92456558af2833f64cd494d161905.diff
LOG: [mlir][python] Fix MemRefType IsAFunction in Python bindings
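MemRefType was using the wrong `isa` function in the bindings code
(`mlirTypeIsARankedTensor` instead of `mlirTypeIsAMemRef`), which could lead
to invalid IR being constructed. Also make the `memref.load` builder use the
memref's element type as its result type (it previously used the memref type
itself), and run the verifier in the memref dialect tests.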
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D111784
Added:
Modified:
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/python/mlir/dialects/_memref_ops_ext.py
mlir/test/python/dialects/memref.py
Removed:
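The one-line change in `IRTypes.cpp` matters because the `isaFunction` of a concrete Python type class decides which MLIR types that class accepts. The sketch below is a pure-Python toy model, not the real `mlir.ir` API. Its names (`ToyType`, `make_cast`, `accepts`) are hypothetical, and it assumes that constructing a concrete type class performs a checked cast guarded by `isaFunction`. With the old predicate, a memref is rejected and a ranked tensor is accepted. With `mlirTypeIsAMemRef`, the cast behaves as intended.

```python
from dataclasses import dataclass
from typing import Callable, Tuple


# Hypothetical stand-in for an MLIR type; NOT the real mlir.ir.Type.
@dataclass(frozen=True)
class ToyType:
    kind: str                # "memref" or "tensor"
    shape: Tuple[int, ...]   # -1 models a dynamic dimension ("?")
    element: str             # element type name, e.g. "f32"


# Toy versions of the two C API predicates involved in the fix.
def is_a_memref(t: ToyType) -> bool:
    return t.kind == "memref"


def is_a_ranked_tensor(t: ToyType) -> bool:
    return t.kind == "tensor"


def make_cast(name: str, isa: Callable[[ToyType], bool]) -> Callable[[ToyType], ToyType]:
    """Model a checked downcast: accept the type only if `isa` holds."""
    def cast(t: ToyType) -> ToyType:
        if not isa(t):
            raise ValueError(f"Cannot cast type to {name}")
        return t
    return cast


def accepts(cast: Callable[[ToyType], ToyType], t: ToyType) -> bool:
    """Return True if `cast` accepts `t`, False if it raises ValueError."""
    try:
        cast(t)
    except ValueError:
        return False
    return True


# Before the fix: MemRefType was guarded by the ranked-tensor predicate.
buggy_memref_cast = make_cast("MemRefType", is_a_ranked_tensor)
# After the fix: MemRefType is guarded by the memref predicate.
fixed_memref_cast = make_cast("MemRefType", is_a_memref)

memref_ty = ToyType("memref", (-1, -1), "f32")
tensor_ty = ToyType("tensor", (4, 4), "f32")
```

Here `accepts(buggy_memref_cast, tensor_ty)` is `True` and `accepts(buggy_memref_cast, memref_ty)` is `False`. This mismatch would also have made the `MemRefType(memref_resolved.type)` cast in the new `memref.load` builder fail.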
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 568cca160a595..fd9f3efe7405f 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -406,7 +406,7 @@ class PyMemRefLayoutMapList;
/// Ranked MemRef Type subclass - MemRefType.
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
static constexpr const char *pyClassName = "MemRefType";
using PyConcreteType::PyConcreteType;
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
index cb25ef105d73f..9cc22a21c6283 100644
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -33,5 +33,5 @@ def __init__(self,
memref_resolved = _get_op_result_or_value(memref)
indices_resolved = [] if indices is None else _get_op_results_or_values(
indices)
- return_type = memref_resolved.type
+ return_type = MemRefType(memref_resolved.type).element_type
super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index e421f9b2fde95..f2eda0a620610 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -71,3 +71,4 @@ def testCustomBuidlers():
# CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
print(module)
+ assert module.operation.verify()
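The wrong `memref.load` result type was easy to miss because the `CHECK` line above only matches the op name and its operands, not the result type. The added `verify()` call closes that gap. The sketch below is a toy model of the builder change and of the kind of check the verifier performs. It is not MLIR code. `ToyMemRef`, `ToyLoad`, `build_load_buggy`, `build_load_fixed` and `verify_load` are hypothetical. The two checks in `verify_load` (one index per dimension, and a result type equal to the element type) are an assumed simplification of the real `memref.load` verifier.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple, Union


@dataclass(frozen=True)
class ToyMemRef:
    shape: Tuple[int, ...]   # -1 models a dynamic dimension ("?")
    element: str             # element type name, e.g. "f32"


@dataclass
class ToyLoad:
    memref: ToyMemRef
    indices: List[str]
    result_type: Union[ToyMemRef, str]


def build_load_buggy(memref: ToyMemRef, indices: Sequence[str]) -> ToyLoad:
    # Mirrors the pre-fix builder: result type is the memref type itself.
    return ToyLoad(memref, list(indices), memref)


def build_load_fixed(memref: ToyMemRef, indices: Sequence[str]) -> ToyLoad:
    # Mirrors the fix: result type is the memref's element type.
    return ToyLoad(memref, list(indices), memref.element)


def verify_load(op: ToyLoad) -> bool:
    # Simplified verifier: one index per dimension, and the result
    # type must equal the memref's element type.
    return (len(op.indices) == len(op.memref.shape)
            and op.result_type == op.memref.element)


# Same shape as the test: memref<?x?xf32> loaded with two index values.
m = ToyMemRef((-1, -1), "f32")
buggy = build_load_buggy(m, ["%arg1", "%arg2"])
fixed = build_load_fixed(m, ["%arg1", "%arg2"])
```

Both ops would print with the same operand list. Only the verifier-style check tells them apart: `verify_load(buggy)` is `False` and `verify_load(fixed)` is `True`.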