[Mlir-commits] [mlir] [mlir][irdl] Add support for basic structural constraints in tblgen-to-irdl (PR #82862)
Fehr Mathieu
llvmlistbot at llvm.org
Fri Feb 23 18:58:39 PST 2024
https://github.com/math-fehr created https://github.com/llvm/llvm-project/pull/82862
Adds tblgen-to-irdl support for `TypeDef`, `AnyType`, `AnyTypeOf`, and `AllOfType` ODS constraints.
This is done by introspecting the TableGen records directly, rather than only translating each constraint's C++ predicate into an `irdl.c_pred`.
For instance, `shape.add` is now translated to:
```
irdl.operation @add {
%0 = irdl.base "!shape.size"
%1 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))"
%2 = irdl.any_of(%0, %1)
%3 = irdl.base "!shape.size"
%4 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))"
%5 = irdl.any_of(%3, %4)
%6 = irdl.base "!shape.size"
%7 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))"
%8 = irdl.any_of(%6, %7)
irdl.operands(%2, %5)
irdl.results(%8)
}
```
instead of the previous output:
```
irdl.operation @add {
%0 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))"
%1 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))"
%2 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))"
irdl.operands(%0, %1)
irdl.results(%2)
}
```
From 15d244bffc72606efc118342fef7dad882d8f5da Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Tue, 26 Dec 2023 22:28:19 +0000
Subject: [PATCH] [mlir][irdl] Add support for structural constraints in
tblgen-to-irdl
---
mlir/include/mlir/IR/CommonTypeConstraints.td | 22 +++---
mlir/test/tblgen-to-irdl/CMathDialect.td | 12 +--
mlir/test/tblgen-to-irdl/TestDialect.td | 74 +++++++++++++++++++
.../tools/tblgen-to-irdl/OpDefinitionsGen.cpp | 48 ++++++++++--
4 files changed, 134 insertions(+), 22 deletions(-)
create mode 100644 mlir/test/tblgen-to-irdl/TestDialect.td
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 03180a687523bf..af4f13dc09360d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -168,24 +168,28 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
BuildableType<"$_builder.getType<::mlir::NoneType>()">;
// Any type from the given list
-class AnyTypeOf<list<Type> allowedTypes, string summary = "",
+class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
- Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
+ Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
- !interleave(!foreach(t, allowedTypes, t.summary), " or "),
+ !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
- cppClassName>;
+ cppClassName> {
+ list<Type> allowedTypes = allowedTypeList;
+}
// A type that satisfies the constraints of all given types.
-class AllOfType<list<Type> allowedTypes, string summary = "",
+class AllOfType<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
- // Satisfy all of the allowedf types' conditions.
- And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
+ // Satisfy all of the allowed types' conditions.
+ And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
- !interleave(!foreach(t, allowedTypes, t.summary), " and "),
+ !interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
- cppClassName>;
+ cppClassName> {
+ list<Type> allowedTypes = allowedTypeList;
+}
// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 57ae8afbba5eeb..5b9e756727cb36 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
}
// CHECK: irdl.operation @identity {
-// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
+// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
@@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
}
// CHECK: irdl.operation @mul {
-// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
-// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
-// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
+// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
+// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
+// CHECK-NEXT: %2 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0, %1)
// CHECK-NEXT: irdl.results(%2)
// CHECK-NEXT: }
@@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> {
}
// CHECK: irdl.operation @norm {
-// CHECK-NEXT: %0 = irdl.c_pred "(true)"
-// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
+// CHECK-NEXT: %0 = irdl.any
+// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0)
// CHECK-NEXT: irdl.results(%1)
// CHECK-NEXT: }
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
new file mode 100644
index 00000000000000..fc40da527db00a
--- /dev/null
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -0,0 +1,74 @@
+// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+
+// CHECK-LABEL: irdl.dialect @test {
+def Test_Dialect : Dialect {
+ let name = "test";
+}
+
+class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
+: TypeDef<Test_Dialect, name, traits> {
+ let mnemonic = typeMnemonic;
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []>
+ : Op<Test_Dialect, mnemonic, traits>;
+
+def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
+def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
+def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+
+
+// Check that AllOfType is converted correctly.
+def Test_AndOp : Test_Op<"and"> {
+ let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
+}
+// CHECK-LABEL: irdl.operation @and {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
+// CHECK-NEXT: irdl.operands(%[[v2]])
+// CHECK-NEXT: irdl.results()
+// CHECK-NEXT: }
+
+
+// Check that AnyType is converted correctly.
+def Test_AnyOp : Test_Op<"any"> {
+ let arguments = (ins AnyType:$in);
+}
+// CHECK-LABEL: irdl.operation @any {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
+// CHECK-NEXT: irdl.operands(%[[v0]])
+// CHECK-NEXT: irdl.results()
+// CHECK-NEXT: }
+
+
+// Check that AnyTypeOf is converted correctly.
+def Test_OrOp : Test_Op<"or"> {
+ let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
+}
+// CHECK-LABEL: irdl.operation @or {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
+// CHECK-NEXT: irdl.operands(%[[v3]])
+// CHECK-NEXT: irdl.results()
+// CHECK-NEXT: }
+
+
+// Check that variadics and optionals are converted correctly.
+def Test_VariadicityOp : Test_Op<"variadicity"> {
+ let arguments = (ins Variadic<Test_SingletonAType>:$variadic,
+ Optional<Test_SingletonBType>:$optional,
+ Test_SingletonCType:$required);
+}
+// CHECK-LABEL: irdl.operation @variadicity {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
+// CHECK-NEXT: irdl.results()
+// CHECK-NEXT: }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index ba5bf4d9d4abbc..a55f3539f31db0 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -39,15 +39,49 @@ llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
-irdl::CPredOp createConstraint(OpBuilder &builder,
- NamedTypeConstraint namedConstraint) {
+Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
- // Build the constraint as a string.
- std::string constraint =
- namedConstraint.constraint.getPredicate().getCondition();
+ const Record &predRec = constraint.getDef();
+
+ if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
+ return createConstraint(builder, predRec.getValueAsDef("baseType"));
+
+ if (predRec.getName() == "AnyType") {
+ auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("TypeDef")) {
+ std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, typeName));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyTypeOf")) {
+ std::vector<Value> constraints;
+ for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
+ constraints.push_back(
+ createConstraint(builder, tblgen::Constraint(child)));
+ }
+ auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AllOfType")) {
+ std::vector<Value> constraints;
+ for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
+ constraints.push_back(
+ createConstraint(builder, tblgen::Constraint(child)));
+ }
+ auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ std::string condition = constraint.getPredicate().getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
- UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
+ UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
return op;
}
@@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
- auto operand = createConstraint(consBuilder, namedCons);
+ auto operand = createConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);
irdl::VariadicityAttr var;
More information about the Mlir-commits
mailing list