[Mlir-commits] [mlir] 0371ddf - [mlir] Refactoring the tablegen Tensor types

wren romano llvmlistbot at llvm.org
Wed Jun 8 11:33:58 PDT 2022


Author: wren romano
Date: 2022-06-08T11:33:48-07:00
New Revision: 0371ddf9adbec24d87a14ff21362132e5b598042

URL: https://github.com/llvm/llvm-project/commit/0371ddf9adbec24d87a14ff21362132e5b598042
DIFF: https://github.com/llvm/llvm-project/commit/0371ddf9adbec24d87a14ff21362132e5b598042.diff

LOG: [mlir] Refactoring the tablegen Tensor types

Reduces repetition in tablegen files for defining various tensor types.  In particular the goal is to reduce the repetition when defining new tensor types (e.g., D126994).

Reviewed By: aartbik, rriddle

Differential Revision: https://reviews.llvm.org/D127039

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/IR/OpBase.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index c0c3d4920a1fe..76f408d76c955 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -100,20 +100,12 @@ def IsSparseTensorPred
 // `RankedTensorOf`, `AnyRankedTensor`.
 
 class SparseTensorOf<list<Type> allowedTypes>
-  : ShapedContainerType<
-      allowedTypes,
-      And<[IsTensorTypePred, IsSparseTensorPred]>,
-      "sparse tensor",
-      "::mlir::TensorType">;
+  : TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
 
 def AnySparseTensor : SparseTensorOf<[AnyType]>;
 
 class RankedSparseTensorOf<list<Type> allowedTypes>
-  : ShapedContainerType<
-      allowedTypes,
-      And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>,
-      "ranked sparse tensor",
-      "::mlir::TensorType">;
+  : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "ranked sparse tensor">;
 
 def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
 

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 207ae2a1ef2b1..9f2ae6fd8b804 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -669,34 +669,29 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
 def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
                                    "::mlir::ShapedType">;
 
+//===----------------------------------------------------------------------===//
 // Tensor types.
 
-// Any tensor type whose element type is from the given `allowedTypes` list
-class TensorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
-                      "::mlir::TensorType">;
-
-class RankedTensorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, And<[IsTensorTypePred, HasRankPred]>,
-  "ranked tensor", "::mlir::TensorType">;
+// Unranked tensor type whose element type is from the given
+// `allowedTypes` list.
+class UnrankedTensorOf<list<Type> allowedTypes>
+  : ShapedContainerType<allowedTypes, IsUnrankedTensorTypePred,
+      "unranked.tensor", "::mlir::UnrankedTensorType">;
 
-def AnyTensor : TensorOf<[AnyType]>;
-
-// Unranked Memref type
-class UnrankedTensorOf<list<Type> allowedTypes> :
-    ShapedContainerType<allowedTypes,
-                        IsUnrankedTensorTypePred,
-                        "unranked.tensor", "::mlir::UnrankedTensorType">;
-
-def AnyRankedTensor : RankedTensorOf<[AnyType]>;
-
-// TODO: Have an easy way to add another constraint to a type.
-class StaticShapeTensorOf<list<Type> allowedTypes>
-    : Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
-           "statically shaped " # TensorOf<allowedTypes>.summary,
-           "::mlir::TensorType">;
-
-def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
+// Any tensor type whose element type is from the given `allowedTypes`
+// list, and which additionally satisfies an optional list of predicates.
+//
+// TODO: use `Constraint` instead of `Pred`, so we can generate a better
+// default summary (a la `Confined`).
+class TensorOf<
+    list<Type> allowedTypes,
+    list<Pred> preds = [],
+    string summary = "tensor">
+  : ShapedContainerType<allowedTypes,
+      And<!listconcat([IsTensorTypePred], preds)>,
+      summary, "::mlir::TensorType">;
+
+def AnyTensor  : TensorOf<[AnyType]>;
 
 def I1Tensor   : TensorOf<[I1]>;
 def I8Tensor   : TensorOf<[I8]>;
@@ -710,11 +705,19 @@ def F16Tensor  : TensorOf<[F16]>;
 def F32Tensor  : TensorOf<[F32]>;
 def F64Tensor  : TensorOf<[F64]>;
 
+class RankedTensorOf<
+    list<Type> allowedTypes,
+    list<Pred> preds = [],
+    string summary = "ranked tensor">
+  : TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>;
+
+def AnyRankedTensor : RankedTensorOf<[AnyType]>;
+
 // Ranked tensor type with one of the specified types and ranks.
-class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
-    Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
-         !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
-         TensorOf<allowedTypes>.summary, "::mlir::TensorType">;
+class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
+  : TensorOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
 
 class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
 class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
@@ -722,6 +725,14 @@ class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
 class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
 class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
 
+class StaticShapeTensorOf<list<Type> allowedTypes>
+  : TensorOf<allowedTypes, [HasStaticShapePred], "statically shaped tensor">;
+
+def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
+
+//===----------------------------------------------------------------------===//
+// Memref type.
+
 // Unranked Memref type
 class UnrankedMemRefOf<list<Type> allowedTypes> :
     ShapedContainerType<allowedTypes,
@@ -730,8 +741,6 @@ class UnrankedMemRefOf<list<Type> allowedTypes> :
 
 def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
 
-// Memref type.
-
 // Memrefs are blocks of data with fixed type and rank.
 class MemRefOf<list<Type> allowedTypes> :
     ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",


        


More information about the Mlir-commits mailing list