[Mlir-commits] [mlir] Refactor tblgen-to-irdl script and support more types (PR #105505)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 21 05:05:00 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-ods
Author: Alex Rice (alexarice)
<details>
<summary>Changes</summary>
Refactors the tblgen-to-irdl script slightly and adds support for:
- Various integer types
- Various float types
- Confined types
- Complex types (with a fixed element type)
It also no longer emits the `irdl.operands` and `irdl.results` ops when they would be empty.
I could split this into smaller PRs if that would be helpful (refactor + integer/float/complex types, confined types, and omitting empty operand/result ops).
@<!-- -->math-fehr
---
Full diff: https://github.com/llvm/llvm-project/pull/105505.diff
4 Files Affected:
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4-1)
- (modified) mlir/test/tblgen-to-irdl/CMathDialect.td (-1)
- (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+51-6)
- (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+169-9)
``````````diff
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4536d781ef674f..0e076413d0d9f3 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -198,7 +198,10 @@ class AllOfType<list<Type> allowedTypeList, string summary = "",
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
string cppType = type.cppType> : Type<
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
- summary, cppType>;
+ summary, cppType> {
+ Type baseType = type; // The unconstrained base type, exposed as a field so tools such as tblgen-to-irdl can read it back.
+ list<Pred> predicateList = predicates; // The extra predicates, kept separate from the combined And<> predicate above.
+}
// Integer types.
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 5b9e756727cb36..454543e074c489 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
-// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
def CMath_IdentityOp : CMath_Op<"identity"> {
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index fc40da527db00a..a86dcb5b3b66e2 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> {
// CHECK-LABEL: irdl.operation @and {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
-// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: irdl.operands(%[[v2]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
@@ -41,9 +40,37 @@ def Test_AnyOp : Test_Op<"any"> {
// CHECK-LABEL: irdl.operation @any {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
// CHECK-NEXT: irdl.operands(%[[v0]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
+// Check that confined types are converted correctly.
+def Test_ConfinedOp : Test_Op<"confined"> {
+ let arguments = (ins ConfinedType<I32, [IntNonNegative.predicate]>:$confined, // base irdl.is i32 and one irdl.c_pred, joined by irdl.all_of
+ ConfinedType<I8, [And<[IntMinValue<1>.predicate, IntMaxValue<2>.predicate]>]>:$bounded); // the inner And<> becomes its own nested irdl.all_of
+}
+// CHECK-LABEL: irdl.operation @confined {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i32
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
+// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.is i8
+// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT: %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]])
+// CHECK-NEXT: %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]])
+// CHECK-NEXT: irdl.operands(%[[v2]], %[[v7]])
+// CHECK-NEXT: }
+// Check that integer type constraints are converted correctly.
+def Test_Integers : Test_Op<"integers"> { // NOTE(review): sibling defs use an "Op" suffix (e.g. Test_TypesOp); consider Test_IntegersOp.
+ let arguments = (ins AnyI8:$any_int, // expected: irdl.any_of over i8, si8 and ui8
+ AnyInteger:$any_integer); // expected: irdl.base "!builtin.integer"
+}
+// CHECK-LABEL: irdl.operation @integers {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i8
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.is si8
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.is ui8
+// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
+// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT: irdl.operands(%[[v3]], %[[v4]])
+// CHECK-NEXT: }
// Check that AnyTypeOf is converted correctly.
def Test_OrOp : Test_Op<"or"> {
@@ -53,11 +80,30 @@ def Test_OrOp : Test_Op<"or"> {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
-// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
+// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.operands(%[[v3]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
+// Check that various types are converted correctly.
+def Test_TypesOp : Test_Op<"types"> { // each argument should lower to a single irdl.is on the matching builtin type
+ let arguments = (ins I32:$a,
+ SI64:$b,
+ UI8:$c,
+ Index:$d,
+ F32:$e,
+ NoneType:$f,
+ Complex<F8E4M3FN>); // NOTE(review): unlike $a-$f this operand is unnamed; consider naming it (e.g. :$g) for consistency.
+}
+// CHECK-LABEL: irdl.operation @types {
+// CHECK-NEXT: %{{.*}} = irdl.is i32
+// CHECK-NEXT: %{{.*}} = irdl.is si64
+// CHECK-NEXT: %{{.*}} = irdl.is ui8
+// CHECK-NEXT: %{{.*}} = irdl.is index
+// CHECK-NEXT: %{{.*}} = irdl.is f32
+// CHECK-NEXT: %{{.*}} = irdl.is none
+// CHECK-NEXT: %{{.*}} = irdl.is complex<f8E4M3FN>
+// CHECK-NEXT: irdl.operands({{.*}})
+// CHECK-NEXT: }
// Check that variadics and optionals are converted correctly.
def Test_VariadicityOp : Test_Op<"variadicity"> {
@@ -70,5 +116,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index a55f3539f31db0..181d02c6608bdb 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -39,6 +39,130 @@ llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
+Value createPredicate(OpBuilder &builder, tblgen::Pred pred) { // Lowers a TableGen predicate to IRDL constraint ops and returns the constraint value.
+ MLIRContext *ctx = builder.getContext();
+ // And/Or combiners are lowered structurally: each child is lowered recursively and the results are joined with irdl.all_of / irdl.any_of.
+ if (pred.isCombined()) {
+ auto combiner = pred.getDef().getValueAsDef("kind")->getName(); // name of the combiner def, e.g. "PredCombinerAnd"
+ if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
+ std::vector<Value> constraints;
+ for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
+ constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+ }
+ if (combiner == "PredCombinerAnd") {
+ auto op =
+ builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+ auto op =
+ builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+ }
+ // Leaf predicates and any other combiner kind are emitted as an opaque irdl.c_pred over the C++ condition string from getCondition().
+ std::string condition = pred.getCondition();
+ // Build a CPredOp to match the C constraint built.
+ irdl::CPredOp op = builder.create<irdl::CPredOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
+ return op; // relies on CPredOp's implicit conversion to its single result Value
+}
+
+Value typeToConstraint(OpBuilder &builder, MLIRContext *ctx, Type type) { // Emits an irdl.is constraint on the concrete `type` and returns its result.
+ auto op = // ctx is only used here to build the (unknown) location
+ builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
+ return op.getOutput();
+}
+
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) { // Maps a TableGen type record to the concrete builtin Type it denotes, if any.
+ // Returns std::nullopt when the record does not name one fixed type; createConstraint then tries its other lowerings.
+ if (predRec.isSubClassOf("I")) { // I<n>: signless iN
+ auto width = predRec.getValueAsInt("bitwidth");
+ return IntegerType::get(ctx, width, IntegerType::Signless);
+ }
+
+ if (predRec.isSubClassOf("SI")) { // SI<n>: signed siN
+ auto width = predRec.getValueAsInt("bitwidth");
+ return IntegerType::get(ctx, width, IntegerType::Signed);
+ }
+
+ if (predRec.isSubClassOf("UI")) { // UI<n>: unsigned uiN
+ auto width = predRec.getValueAsInt("bitwidth");
+ return IntegerType::get(ctx, width, IntegerType::Unsigned);
+ }
+
+ // Index type
+ if (predRec.getName() == "Index") {
+ return IndexType::get(ctx);
+ }
+
+ // Float types
+ if (predRec.isSubClassOf("F")) { // F<n>: widths without a case below fall out of the switch and on to the remaining checks
+ auto width = predRec.getValueAsInt("bitwidth");
+ switch (width) {
+ case 16:
+ return FloatType::getF16(ctx);
+ case 32:
+ return FloatType::getF32(ctx);
+ case 64:
+ return FloatType::getF64(ctx);
+ case 80:
+ return FloatType::getF80(ctx);
+ case 128:
+ return FloatType::getF128(ctx);
+ }
+ }
+ // The remaining scalar types are recognized by their TableGen def name.
+ if (predRec.getName() == "NoneType") {
+ return NoneType::get(ctx);
+ }
+
+ if (predRec.getName() == "BF16") {
+ return FloatType::getBF16(ctx);
+ }
+
+ if (predRec.getName() == "TF32") {
+ return FloatType::getTF32(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3FN") {
+ return FloatType::getFloat8E4M3FN(ctx);
+ }
+
+ if (predRec.getName() == "F8E5M2") {
+ return FloatType::getFloat8E5M2(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3") {
+ return FloatType::getFloat8E4M3(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3FNUZ") {
+ return FloatType::getFloat8E4M3FNUZ(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3B11FNUZ") {
+ return FloatType::getFloat8E4M3B11FNUZ(ctx);
+ }
+
+ if (predRec.getName() == "F8E5M2FNUZ") {
+ return FloatType::getFloat8E5M2FNUZ(ctx);
+ }
+
+ if (predRec.getName() == "F8E3M4") {
+ return FloatType::getFloat8E3M4(ctx);
+ }
+ // Complex<T>: only representable when the element type T itself maps to a concrete Type.
+ if (predRec.isSubClassOf("Complex")) {
+ const Record *elementRec = predRec.getValueAsDef("elementType");
+ auto elementType = recordToType(ctx, *elementRec);
+ if (elementType.has_value()) {
+ return ComplexType::get(elementType.value());
+ }
+ }
+ // Any other record (or a Complex whose element type is unmappable) has no single fixed Type.
+ return std::nullopt;
+}
+
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
@@ -78,11 +202,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return op.getOutput();
}
- std::string condition = constraint.getPredicate().getCondition();
- // Build a CPredOp to match the C constraint built.
- irdl::CPredOp op = builder.create<irdl::CPredOp>(
- UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
- return op;
+ // Integer types
+ if (predRec.getName() == "AnyInteger") {
+ auto op = builder.create<irdl::BaseOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyI")) {
+ auto width = predRec.getValueAsInt("bitwidth");
+ std::vector<Value> types = {
+ typeToConstraint(builder, ctx,
+ IntegerType::get(ctx, width, IntegerType::Signless)),
+ typeToConstraint(builder, ctx,
+ IntegerType::get(ctx, width, IntegerType::Signed)),
+ typeToConstraint(builder, ctx,
+ IntegerType::get(ctx, width, IntegerType::Unsigned))};
+ auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
+ return op.getOutput();
+ }
+
+ auto type = recordToType(ctx, predRec);
+
+ if (type.has_value()) {
+ return typeToConstraint(builder, ctx, type.value());
+ }
+
+ // Confined type
+ if (predRec.isSubClassOf("ConfinedType")) {
+ std::vector<Value> constraints;
+ constraints.push_back(createConstraint(
+ builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
+ for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
+ constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+ }
+ auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ return createPredicate(builder, constraint.getPredicate());
}
/// Returns the name of the operation without the dialect prefix.
@@ -131,10 +289,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
// Create the operands and results operations.
- consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
- operandVariadicity);
- consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
- resultVariadicity);
+ if (!operands.empty())
+ consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
+ operandVariadicity);
+ if (!results.empty())
+ consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
+ resultVariadicity);
return op;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/105505
More information about the Mlir-commits mailing list