[Mlir-commits] [mlir] 58c296a - [mlir] Use BaseMemRefType for ranked/unranked memrefs
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 5 21:34:42 PDT 2023
Author: Matthias Springer
Date: 2023-04-06T13:29:44+09:00
New Revision: 58c296a418cb2a3dbd542a39b5077eb32e1f1895
URL: https://github.com/llvm/llvm-project/commit/58c296a418cb2a3dbd542a39b5077eb32e1f1895
DIFF: https://github.com/llvm/llvm-project/commit/58c296a418cb2a3dbd542a39b5077eb32e1f1895.diff
LOG: [mlir] Use BaseMemRefType for ranked/unranked memrefs
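This makes `RankedOrUnrankedMemRefOf` consistent with `TensorOf`: it is now a single `ShapedContainerType` constraint (predicate `IsBaseMemRefTypePred`, C++ type `::mlir::BaseMemRefType`) that accepts an optional list of extra predicates and a summary, instead of an `AnyTypeOf` union of the unranked and ranked memref constraints. As a result, verifier diagnostics now read "ranked or unranked memref of ..." rather than listing the two alternatives separately.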
Differential Revision: https://reviews.llvm.org/D147160
Added:
Modified:
mlir/include/mlir/IR/OpBase.td
mlir/test/Dialect/MemRef/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index f7f009cce3177..98866c83b4b4e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -801,7 +801,7 @@ def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
//===----------------------------------------------------------------------===//
// Memref type.
-// Unranked Memref type
+// Any unranked memref whose element type is from the given `allowedTypes` list.
class UnrankedMemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes,
IsUnrankedMemRefTypePred, "unranked.memref",
@@ -809,10 +809,11 @@ class UnrankedMemRefOf<list<Type> allowedTypes> :
def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
-// Memrefs are blocks of data with fixed type and rank.
+// Any ranked memref whose element type is from the given `allowedTypes` list.
class MemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
"::mlir::MemRefType">;
+
class Non0RankedMemRefOf<list<Type> allowedTypes> :
ConfinedType<MemRefOf<allowedTypes>, [HasRankGreaterOrEqualPred<1>],
"non-0-ranked." # MemRefOf<allowedTypes>.summary,
@@ -821,10 +822,18 @@ class Non0RankedMemRefOf<list<Type> allowedTypes> :
def AnyMemRef : MemRefOf<[AnyType]>;
def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>;
-class RankedOrUnrankedMemRefOf<list<Type> allowedTypes>:
- AnyTypeOf<[UnrankedMemRefOf<allowedTypes>, MemRefOf<allowedTypes>]>;
+// Any memref (ranked or unranked) whose element type is from the given
+// `allowedTypes` list, and which additionally satisfies an optional list of
+// predicates.
+class RankedOrUnrankedMemRefOf<
+ list<Type> allowedTypes,
+ list<Pred> preds = [],
+ string summary = "ranked or unranked memref">
+ : ShapedContainerType<allowedTypes,
+ And<!listconcat([IsBaseMemRefTypePred], preds)>,
+ summary, "::mlir::BaseMemRefType">;
-def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
+def AnyRankedOrUnrankedMemRef : RankedOrUnrankedMemRefOf<[AnyType]>;
def AnyNon0RankedOrUnrankedMemRef:
AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 19874f08cd2fe..37ac1ca77328b 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -910,7 +910,7 @@ func.func @test_alloc_memref_map_rank_mismatch() {
// -----
func.func @rank(%0: f32) {
- // expected-error@+1 {{'memref.rank' op operand #0 must be unranked.memref of any type values or memref of any type values}}
+ // expected-error@+1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any type values}}
"memref.rank"(%0): (f32)->index
return
}
More information about the Mlir-commits
mailing list