[Mlir-commits] [mlir] 58c296a - [mlir] Use BaseMemRefType for ranked/unranked memrefs

Matthias Springer llvmlistbot at llvm.org
Wed Apr 5 21:34:42 PDT 2023


Author: Matthias Springer
Date: 2023-04-06T13:29:44+09:00
New Revision: 58c296a418cb2a3dbd542a39b5077eb32e1f1895

URL: https://github.com/llvm/llvm-project/commit/58c296a418cb2a3dbd542a39b5077eb32e1f1895
DIFF: https://github.com/llvm/llvm-project/commit/58c296a418cb2a3dbd542a39b5077eb32e1f1895.diff

LOG: [mlir] Use BaseMemRefType for ranked/unranked memrefs

This makes `RankedOrUnrankedMemRefOf` consistent with `TensorOf`.

Differential Revision: https://reviews.llvm.org/D147160

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/test/Dialect/MemRef/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index f7f009cce3177..98866c83b4b4e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -801,7 +801,7 @@ def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 //===----------------------------------------------------------------------===//
 // Memref type.
 
-// Unranked Memref type
+// Any unranked memref whose element type is from the given `allowedTypes` list.
 class UnrankedMemRefOf<list<Type> allowedTypes> :
     ShapedContainerType<allowedTypes,
                         IsUnrankedMemRefTypePred, "unranked.memref",
@@ -809,10 +809,11 @@ class UnrankedMemRefOf<list<Type> allowedTypes> :
 
 def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
 
-// Memrefs are blocks of data with fixed type and rank.
+// Any ranked memref whose element type is from the given `allowedTypes` list.
 class MemRefOf<list<Type> allowedTypes> :
     ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
                         "::mlir::MemRefType">;
+
 class Non0RankedMemRefOf<list<Type> allowedTypes> :
     ConfinedType<MemRefOf<allowedTypes>, [HasRankGreaterOrEqualPred<1>],
          "non-0-ranked." # MemRefOf<allowedTypes>.summary,
@@ -821,10 +822,18 @@ class Non0RankedMemRefOf<list<Type> allowedTypes> :
 def AnyMemRef : MemRefOf<[AnyType]>;
 def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>;
 
-class RankedOrUnrankedMemRefOf<list<Type> allowedTypes>:
-    AnyTypeOf<[UnrankedMemRefOf<allowedTypes>, MemRefOf<allowedTypes>]>;
+// Any memref (ranked or unranked) whose element type is from the given
+// `allowedTypes` list, and which additionally satisfies an optional list of
+// predicates.
+class RankedOrUnrankedMemRefOf<
+    list<Type> allowedTypes,
+    list<Pred> preds = [],
+    string summary = "ranked or unranked memref">
+  : ShapedContainerType<allowedTypes,
+      And<!listconcat([IsBaseMemRefTypePred], preds)>,
+      summary, "::mlir::BaseMemRefType">;
 
-def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
+def AnyRankedOrUnrankedMemRef  : RankedOrUnrankedMemRefOf<[AnyType]>;
 def AnyNon0RankedOrUnrankedMemRef:
     AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>;
 

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 19874f08cd2fe..37ac1ca77328b 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -910,7 +910,7 @@ func.func @test_alloc_memref_map_rank_mismatch() {
 // -----
 
 func.func @rank(%0: f32) {
-  // expected-error @+1 {{'memref.rank' op operand #0 must be unranked.memref of any type values or memref of any type values}}
+  // expected-error @+1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any type values}}
   "memref.rank"(%0): (f32)->index
   return
 }


        


More information about the Mlir-commits mailing list