[Mlir-commits] [mlir] 8a3481b - [mlir] Add AllOfType and ConfinedType constraints

Jeff Niu llvmlistbot at llvm.org
Fri Aug 12 13:25:43 PDT 2022


Author: Jeff Niu
Date: 2022-08-12T16:25:36-04:00
New Revision: 8a3481b958744aef8f23ade8d5d9b3e4b3230f58

URL: https://github.com/llvm/llvm-project/commit/8a3481b958744aef8f23ade8d5d9b3e4b3230f58
DIFF: https://github.com/llvm/llvm-project/commit/8a3481b958744aef8f23ade8d5d9b3e4b3230f58.diff

LOG: [mlir] Add AllOfType and ConfinedType constraints

`AllOfType` is a type constraint that satisfies all given type
constraints and `ConfinedType` is a type that satisfies additional
predicates. These shorthands simplify type constraint definition mostly
by removing the need to deal with `myType.predicate` manipulation.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D131788

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 0202f829df325..a2d68672201a3 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -340,13 +340,29 @@ def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
 // Any type from the given list
 class AnyTypeOf<list<Type> allowedTypes, string summary = "",
                 string cppClassName = "::mlir::Type"> : Type<
-    // Satisfy any of the allowed type's condition
+    // Satisfy any of the allowed types' conditions.
     Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedTypes, t.summary), " or "),
         summary),
     cppClassName>;
 
+// A type that satisfies the constraints of all given types.
+class AllOfType<list<Type> allowedTypes, string summary = "",
+                string cppClassName = "::mlir::Type"> : Type<
+    // Satisfy all of the allowed types' conditions.
+    And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
+    !if(!eq(summary, ""),
+        !interleave(!foreach(t, allowedTypes, t.summary), " and "),
+        summary),
+    cppClassName>;
+
+// A type that satisfies additional predicates.
+class ConfinedType<Type type, list<Pred> predicates, string summary = "",
+                   string cppClassName = "::mlir::Type"> : Type<
+    And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
+    summary, cppClassName>;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
@@ -475,12 +491,14 @@ def F128 : F<128>;
 def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
            BuildableType<"$_builder.getBF16Type()">;
 
+def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
+                      "complex-type", "::mlir::ComplexType">;
+
 class Complex<Type type>
-    : Type<And<[
-          CPred<"$_self.isa<::mlir::ComplexType>()">,
+    : ConfinedType<AnyComplex, [
           SubstLeaves<"$_self",
                       "$_self.cast<::mlir::ComplexType>().getElementType()",
-           type.predicate>]>,
+           type.predicate>],
            "complex type with " # type.summary # " elements",
            "::mlir::ComplexType">,
       SameBuildabilityAs<type, "::mlir::ComplexType::get($_builder.get" # type #
@@ -488,9 +506,6 @@ class Complex<Type type>
   Type elementType = type;
 }
 
-def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
-                      "complex-type", "::mlir::ComplexType">;
-
 class OpaqueType<string dialect, string name, string summary>
   : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
          summary, "::mlir::OpaqueType">,
@@ -572,9 +587,8 @@ class VectorOfRank<list<int> allowedRanks> : Type<
 // Any vector where the rank is from the given `allowedRanks` list and the type
 // is from the given `allowedTypes` list
 class VectorOfRankAndType<list<int> allowedRanks,
-                          list<Type> allowedTypes> : Type<
-  And<[VectorOf<allowedTypes>.predicate,
-       VectorOfRank<allowedRanks>.predicate]>,
+                          list<Type> allowedTypes> : AllOfType<
+  [VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
   VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
   "::mlir::VectorType">;
 
@@ -630,18 +644,16 @@ class ScalableVectorOfLength<list<int> allowedLengths> : Type<
 // `allowedLengths` list and the type is from the given `allowedTypes`
 // list
 class VectorOfLengthAndType<list<int> allowedLengths,
-                            list<Type> allowedTypes> : Type<
-  And<[VectorOf<allowedTypes>.predicate,
-       VectorOfLength<allowedLengths>.predicate]>,
+                            list<Type> allowedTypes> : AllOfType<
+  [VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
   VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
 // Any fixed-length vector where the number of elements is from the given
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class FixedVectorOfLengthAndType<list<int> allowedLengths,
-                                    list<Type> allowedTypes> : Type<
-  And<[FixedVectorOf<allowedTypes>.predicate,
-       FixedVectorOfLength<allowedLengths>.predicate]>,
+                                    list<Type> allowedTypes> : AllOfType<
+  [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
   FixedVectorOf<allowedTypes>.summary #
   FixedVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
@@ -649,9 +661,8 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
 // Any scalable vector where the number of elements is from the given
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class ScalableVectorOfLengthAndType<list<int> allowedLengths,
-                                    list<Type> allowedTypes> : Type<
-  And<[ScalableVectorOf<allowedTypes>.predicate,
-       ScalableVectorOfLength<allowedLengths>.predicate]>,
+                                    list<Type> allowedTypes> : AllOfType<
+  [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
   ScalableVectorOf<allowedTypes>.summary #
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
@@ -768,34 +779,33 @@ def F64MemRef  : MemRefOf<[F64]>;
 
 // TODO: Have an easy way to add another constraint to a type.
 class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
-    Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+    ConfinedType<MemRefOf<allowedTypes>, [HasAnyRankOfPred<ranks>],
          !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
          MemRefOf<allowedTypes>.summary,
          "::mlir::MemRefType">;
 
-class StaticShapeMemRefOf<list<Type> allowedTypes>
-    : Type<And<[MemRefOf<allowedTypes>.predicate, HasStaticShapePred]>,
-           "statically shaped " # MemRefOf<allowedTypes>.summary,
-           "::mlir::MemRefType">;
+class StaticShapeMemRefOf<list<Type> allowedTypes> :
+    ConfinedType<MemRefOf<allowedTypes>, [HasStaticShapePred],
+         "statically shaped " # MemRefOf<allowedTypes>.summary,
+         "::mlir::MemRefType">;
 
 def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
 
 // For a MemRefType, verify that it has strides.
 def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>;
 
-class StridedMemRefOf<list<Type> allowedTypes>
-    : Type<And<[MemRefOf<allowedTypes>.predicate, HasStridesPred]>,
-           "strided " # MemRefOf<allowedTypes>.summary>;
+class StridedMemRefOf<list<Type> allowedTypes> :
+    ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],
+         "strided " # MemRefOf<allowedTypes>.summary>;
 
 def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
 
 class AnyStridedMemRefOfRank<int rank> :
-  Type<And<[AnyStridedMemRef.predicate,
-            MemRefRankOf<[AnyType], [rank]>.predicate]>,
+  AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
        AnyStridedMemRef.summary # " of rank " # rank>;
 
 class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
-    Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+    ConfinedType<MemRefOf<allowedTypes>, [HasAnyRankOfPred<ranks>],
          !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
          MemRefOf<allowedTypes>.summary>;
 


        


More information about the Mlir-commits mailing list