[Mlir-commits] [mlir] 8a3481b - [mlir] Add AllOfType and ConfinedType constraints
Jeff Niu
llvmlistbot at llvm.org
Fri Aug 12 13:25:43 PDT 2022
Author: Jeff Niu
Date: 2022-08-12T16:25:36-04:00
New Revision: 8a3481b958744aef8f23ade8d5d9b3e4b3230f58
URL: https://github.com/llvm/llvm-project/commit/8a3481b958744aef8f23ade8d5d9b3e4b3230f58
DIFF: https://github.com/llvm/llvm-project/commit/8a3481b958744aef8f23ade8d5d9b3e4b3230f58.diff
LOG: [mlir] Add AllOfType and ConfinedType constraints
`AllOfType` is a type constraint that satisfies all given type
constraints and `ConfinedType` is a type that satisfies additional
predicates. These shorthands simplify type constraint definitions, mostly
by removing the need to manipulate `myType.predicate` directly.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D131788
Added:
Modified:
mlir/include/mlir/IR/OpBase.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 0202f829df325..a2d68672201a3 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -340,13 +340,29 @@ def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
- // Satisfy any of the allowed type's condition
+ // Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypes, t.summary), " or "),
summary),
cppClassName>;
+// A type that satisfies the constraints of all given types.
+class AllOfType<list<Type> allowedTypes, string summary = "",
+ string cppClassName = "::mlir::Type"> : Type<
+ // Satisfy all of the allowed types' conditions.
+ And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
+ !if(!eq(summary, ""),
+ !interleave(!foreach(t, allowedTypes, t.summary), " and "),
+ summary),
+ cppClassName>;
+
+// A type that satisfies the predicate of `type` plus additional predicates.
+class ConfinedType<Type type, list<Pred> predicates, string summary = "",
+ string cppClassName = "::mlir::Type"> : Type<
+ And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
+ summary, cppClassName>;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
@@ -475,12 +491,14 @@ def F128 : F<128>;
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"$_builder.getBF16Type()">;
+def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
+ "complex-type", "::mlir::ComplexType">;
+
class Complex<Type type>
- : Type<And<[
- CPred<"$_self.isa<::mlir::ComplexType>()">,
+ : ConfinedType<AnyComplex, [ // AnyComplex supplies the isa<ComplexType> check
SubstLeaves<"$_self",
"$_self.cast<::mlir::ComplexType>().getElementType()",
- type.predicate>]>,
+ type.predicate>],
"complex type with " # type.summary # " elements",
"::mlir::ComplexType">,
SameBuildabilityAs<type, "::mlir::ComplexType::get($_builder.get" # type #
@@ -488,9 +506,6 @@ class Complex<Type type>
Type elementType = type;
}
-def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
- "complex-type", "::mlir::ComplexType">;
-
class OpaqueType<string dialect, string name, string summary>
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
summary, "::mlir::OpaqueType">,
@@ -572,9 +587,8 @@ class VectorOfRank<list<int> allowedRanks> : Type<
// Any vector where the rank is from the given `allowedRanks` list and the type
// is from the given `allowedTypes` list
class VectorOfRankAndType<list<int> allowedRanks,
- list<Type> allowedTypes> : Type<
- And<[VectorOf<allowedTypes>.predicate,
- VectorOfRank<allowedRanks>.predicate]>,
+ list<Type> allowedTypes> : AllOfType< // all listed constraints must hold
+ [VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;
@@ -630,18 +644,16 @@ class ScalableVectorOfLength<list<int> allowedLengths> : Type<
// `allowedLengths` list and the type is from the given `allowedTypes`
// list
class VectorOfLengthAndType<list<int> allowedLengths,
- list<Type> allowedTypes> : Type<
- And<[VectorOf<allowedTypes>.predicate,
- VectorOfLength<allowedLengths>.predicate]>,
+ list<Type> allowedTypes> : AllOfType< // all listed constraints must hold
+ [VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
- list<Type> allowedTypes> : Type<
- And<[FixedVectorOf<allowedTypes>.predicate,
- FixedVectorOfLength<allowedLengths>.predicate]>,
+ list<Type> allowedTypes> : AllOfType< // all listed constraints must hold
+ [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
FixedVectorOf<allowedTypes>.summary #
FixedVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -649,9 +661,8 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
- list<Type> allowedTypes> : Type<
- And<[ScalableVectorOf<allowedTypes>.predicate,
- ScalableVectorOfLength<allowedLengths>.predicate]>,
+ list<Type> allowedTypes> : AllOfType< // all listed constraints must hold
+ [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
ScalableVectorOf<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -768,34 +779,32 @@ def F64MemRef : MemRefOf<[F64]>;
-// TODO: Have an easy way to add another constraint to a type.
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
- Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+ ConfinedType<MemRefOf<allowedTypes>, [HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
MemRefOf<allowedTypes>.summary,
"::mlir::MemRefType">;
-class StaticShapeMemRefOf<list<Type> allowedTypes>
- : Type<And<[MemRefOf<allowedTypes>.predicate, HasStaticShapePred]>,
- "statically shaped " # MemRefOf<allowedTypes>.summary,
- "::mlir::MemRefType">;
+class StaticShapeMemRefOf<list<Type> allowedTypes> :
+ ConfinedType<MemRefOf<allowedTypes>, [HasStaticShapePred],
+ "statically shaped " # MemRefOf<allowedTypes>.summary,
+ "::mlir::MemRefType">;
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
// For a MemRefType, verify that it has strides.
def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>;
-class StridedMemRefOf<list<Type> allowedTypes>
- : Type<And<[MemRefOf<allowedTypes>.predicate, HasStridesPred]>,
- "strided " # MemRefOf<allowedTypes>.summary>;
+class StridedMemRefOf<list<Type> allowedTypes> :
+ ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],
+ "strided " # MemRefOf<allowedTypes>.summary>;
def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
class AnyStridedMemRefOfRank<int rank> :
- Type<And<[AnyStridedMemRef.predicate,
- MemRefRankOf<[AnyType], [rank]>.predicate]>,
+ AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
AnyStridedMemRef.summary # " of rank " # rank>;
class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
- Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+ ConfinedType<MemRefOf<allowedTypes>, [HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
MemRefOf<allowedTypes>.summary>;
More information about the Mlir-commits
mailing list