[Mlir-commits] [mlir] [mlir] [tblgen-to-irdl] Add attributes to tblgen-to-irdl script (PR #109633)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 23 01:21:05 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Alex Rice (alexarice)
<details>
<summary>Changes</summary>
Adds the ability to export attributes from the dialect and attributes of operations in the dialect
@<!-- -->math-fehr
---
Full diff: https://github.com/llvm/llvm-project/pull/109633.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/CommonAttrConstraints.td (+3)
- (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+17)
- (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+132-7)
``````````diff
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
summary)> {
let returnType = cppType;
let convertFromStorage = fromStorage;
+ list<Attr> allowedAttributes = allowedAttrs;
}
def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
let isOptional = attr.isOptional;
let baseAttr = attr;
+
+ list<AttrConstraint> attrConstraints = constraints;
}
// An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}
+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+ let mnemonic = attrMnemonic;
+}
+
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;
@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
// CHECK: irdl.type @"!singleton_c"
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}
// Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: }
+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+ let arguments = (ins I16Attr:$int_attr,
+ Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT: irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT: }
+
// Check confined types are converted correctly.
def Test_ConfinedOp : Test_Op<"confined"> {
let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
return op.getOutput();
}
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+ MLIRContext *ctx = builder.getContext();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, baseClass));
+ return op.getOutput();
+}
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
if (predRec.isSubClassOf("I")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
return std::nullopt;
}
-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
- return createConstraint(builder, predRec.getValueAsDef("baseType"));
+ return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
// Confined type
if (predRec.isSubClassOf("ConfinedType")) {
std::vector<Value> constraints;
- constraints.push_back(createConstraint(
+ constraints.push_back(createTypeConstraint(
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return createPredicate(builder, constraint.getPredicate());
}
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+ MLIRContext *ctx = builder.getContext();
+ const Record &predRec = constraint.getDef();
+
+ if (predRec.isSubClassOf("DefaultValuedAttr") ||
+ predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+ predRec.isSubClassOf("OptionalAttr")) {
+ return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+ }
+
+ if (predRec.isSubClassOf("ConfinedAttr")) {
+ std::vector<Value> constraints;
+ constraints.push_back(createAttrConstraint(
+ builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+ for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+ constraints.push_back(createPredicate(
+ builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+ }
+ auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyAttrOf")) {
+ std::vector<Value> constraints;
+ for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+ constraints.push_back(
+ createAttrConstraint(builder, tblgen::Constraint(child)));
+ }
+ auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.getName() == "AnyAttr") {
+ auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+ predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+ predRec.isSubClassOf("SignedIntegerAttrBase") ||
+ predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+ predRec.isSubClassOf("BoolAttr")) {
+ return baseToConstraint(builder, "!builtin.integer");
+ }
+
+ if (predRec.isSubClassOf("FloatAttrBase")) {
+ return baseToConstraint(builder, "!builtin.float");
+ }
+
+ if (predRec.isSubClassOf("StringBasedAttr")) {
+ return baseToConstraint(builder, "!builtin.string");
+ }
+
+ if (predRec.getName() == "UnitAttr") {
+ auto op =
+ builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AttrDef")) {
+ auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+ if (dialect == selectedDialect) {
+ std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+ SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
+
+ };
+ auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+ return op.getOutput();
+ }
+ std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, typeName));
+ return op.getOutput();
+ }
+
+ return createPredicate(builder, constraint.getPredicate());
+}
+
/// Returns the name of the operation without the dialect prefix.
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
return opName;
}
+/// Returns the name of the attr without the dialect prefix.
+static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
+ StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+ return opName;
+}
+
/// Extract an operation to IRDL.
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
- auto operand = createConstraint(consBuilder, namedCons.constraint);
+ auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);
irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
+ SmallVector<Value> attributes;
+ SmallVector<Attribute> attrNames;
+ for (auto namedAttr : tblgenOp.getAttributes()) {
+ if (namedAttr.attr.isOptional())
+ continue;
+ attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+ attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+ }
+
// Create the operands and results operations.
if (!operands.empty())
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
if (!results.empty())
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);
+ if (!attributes.empty())
+ consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+ ArrayAttr::get(ctx, attrNames));
return op;
}
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
return op;
}
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+ tblgen::AttrDef &tblgenAttr) {
+ MLIRContext *ctx = builder.getContext();
+ StringRef attrName = getAttrName(tblgenAttr);
+ std::string combined = ("#" + attrName).str();
+
+ irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+ op.getBody().emplaceBlock();
+
+ return op;
+}
+
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
createIRDLType(builder, tblgenType);
}
+ for (const Record *attr :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+ tblgen::AttrDef tblgenAttr(attr);
+ if (tblgenAttr.getDialect().getName() != selectedDialect)
+ continue;
+ createIRDLAttr(builder, tblgenAttr);
+ }
+
for (const Record *def :
recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
tblgen::Operator tblgenOp(def);
``````````
</details>
https://github.com/llvm/llvm-project/pull/109633
More information about the Mlir-commits
mailing list