[Mlir-commits] [mlir] [mlir] [tblgen-to-irdl] Add attributes to tblgen-to-irdl script (PR #109633)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 23 01:21:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-ods

Author: Alex Rice (alexarice)

<details>
<summary>Changes</summary>

Adds the ability to export to IRDL both the attributes defined by a dialect and the attributes of the dialect's operations.

@math-fehr

---
Full diff: https://github.com/llvm/llvm-project/pull/109633.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/CommonAttrConstraints.td (+3) 
- (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+17) 
- (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+132-7) 


``````````diff
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
         summary)> {
     let returnType = cppType;
     let convertFromStorage = fromStorage;
+    list<Attr> allowedAttributes = allowedAttrs;
 }
 
 def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
   let isOptional = attr.isOptional;
 
   let baseAttr = attr;
+  // Exposes the constraint list so TableGen backends can inspect it.
+  list<AttrConstraint> attrConstraints = constraints;
 }
 
 // An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
   let mnemonic = typeMnemonic;
 }

+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+  let mnemonic = attrMnemonic;
+}
+
 class Test_Op<string mnemonic, list<Trait> traits = []>
     : Op<Test_Dialect, mnemonic, traits>;

@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
 def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
 // CHECK: irdl.type @"!singleton_c"
 def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}


 // Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
 // CHECK-NEXT:    irdl.operands(%[[v0]])
 // CHECK-NEXT:  }

+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+  let arguments = (ins I16Attr:$int_attr,
+                       Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "#builtin.integer"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT:    irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT:  }
+
 // Check confined types are converted correctly.
 def Test_ConfinedOp : Test_Op<"confined"> {
   let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
   return op.getOutput();
 }

-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+  MLIRContext *ctx = builder.getContext();
+  auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                         StringAttr::get(ctx, baseClass));
+  return op.getOutput();
+}

+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   if (predRec.isSubClassOf("I")) {
     auto width = predRec.getValueAsInt("bitwidth");
     return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   return std::nullopt;
 }

-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
   const Record &predRec = constraint.getDef();

   if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
-    return createConstraint(builder, predRec.getValueAsDef("baseType"));
+    return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));

   if (predRec.getName() == "AnyType") {
     auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   // Confined type
   if (predRec.isSubClassOf("ConfinedType")) {
     std::vector<Value> constraints;
-    constraints.push_back(createConstraint(
+    constraints.push_back(createTypeConstraint(
         builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
     for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
       constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   return createPredicate(builder, constraint.getPredicate());
 }

+/// Converts an ODS attribute constraint into an IRDL constraint value.
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+  MLIRContext *ctx = builder.getContext();
+  const Record &predRec = constraint.getDef();
+
+  if (predRec.isSubClassOf("DefaultValuedAttr") ||
+      predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+      predRec.isSubClassOf("OptionalAttr")) {
+    return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+  }
+
+  if (predRec.isSubClassOf("ConfinedAttr")) {
+    std::vector<Value> constraints;
+    constraints.push_back(createAttrConstraint(
+        builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+    for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+      constraints.push_back(createPredicate(
+          builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+    }
+    auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyAttrOf")) {
+    std::vector<Value> constraints;
+    for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+      constraints.push_back(
+          createAttrConstraint(builder, tblgen::Constraint(child)));
+    }
+    auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.getName() == "AnyAttr") {
+    auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+    return op.getOutput();
+  }
+
+  // Builtin attribute bases use IRDL's '#' prefix ('!' names a type base).
+  // Like UnitAttr below, BoolAttr is a single def, so it is matched by name.
+  if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+      predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+      predRec.isSubClassOf("SignedIntegerAttrBase") ||
+      predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+      predRec.getName() == "BoolAttr") {
+    return baseToConstraint(builder, "#builtin.integer");
+  }
+  if (predRec.isSubClassOf("FloatAttrBase")) {
+    return baseToConstraint(builder, "#builtin.float");
+  }
+  if (predRec.isSubClassOf("StringBasedAttr")) {
+    return baseToConstraint(builder, "#builtin.string");
+  }
+
+  if (predRec.getName() == "UnitAttr") {
+    auto op =
+        builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AttrDef")) {
+    auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+    if (dialect == selectedDialect) {
+      // Attributes of the selected dialect are referenced by symbol.
+      std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+      SmallVector<FlatSymbolRefAttr> nested = {
+          SymbolRefAttr::get(ctx, combined)};
+      auto attrSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+      auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), attrSymbol);
+      return op.getOutput();
+    }
+    std::string baseName = ("#" + predRec.getValueAsString("attrName")).str();
+    auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                           StringAttr::get(ctx, baseName));
+    return op.getOutput();
+  }
+  return createPredicate(builder, constraint.getPredicate());
+}
+
 /// Returns the name of the operation without the dialect prefix.
 static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
   StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
   return opName;
 }
 
+/// Returns the attribute mnemonic (its name without the dialect prefix).
+static StringRef getAttrName(tblgen::AttrDef &tblgenAttr) {
+  StringRef mnemonic = tblgenAttr.getDef()->getValueAsString("mnemonic");
+  return mnemonic;
+}
+
 /// Extract an operation to IRDL.
 irdl::OperationOp createIRDLOperation(OpBuilder &builder,
                                       tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
     SmallVector<Value> operands;
     SmallVector<irdl::VariadicityAttr> variadicity;
     for (const NamedTypeConstraint &namedCons : namedCons) {
-      auto operand = createConstraint(consBuilder, namedCons.constraint);
+      auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
       operands.push_back(operand);
 
       irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
   auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
 
+  // Optional attributes are currently skipped and not exported to IRDL.
+  SmallVector<Value> attributes;
+  SmallVector<Attribute> attrNames;
+  for (const auto &namedAttr : tblgenOp.getAttributes()) {
+    if (namedAttr.attr.isOptional())
+      continue;
+    attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+    attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+  }
   // Create the operands and results operations.
   if (!operands.empty())
     consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   if (!results.empty())
     consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
                                         resultVariadicity);
+  if (!attributes.empty())
+    consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+                                           ArrayAttr::get(ctx, attrNames));
 
   return op;
 }
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
   return op;
 }
 
+/// Creates an `irdl.attribute` op named "#<mnemonic>" for the given AttrDef.
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+                                 tblgen::AttrDef &tblgenAttr) {
+  MLIRContext *ctx = builder.getContext();
+  StringRef attrName = getAttrName(tblgenAttr);
+  std::string combined = ("#" + attrName).str();
+
+  irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+      UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+  // AttrDef parameters are not translated; the body is an empty block.
+  op.getBody().emplaceBlock();
+  return op;
+}
+
 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
     createIRDLType(builder, tblgenType);
   }
 
+  for (const Record *attr :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+    tblgen::AttrDef tblgenAttr(attr);
+    if (tblgenAttr.getDialect().getName() != selectedDialect)
+      continue;
+    createIRDLAttr(builder, tblgenAttr);
+  }
+
   for (const Record *def :
        recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
     tblgen::Operator tblgenOp(def);

``````````

</details>


https://github.com/llvm/llvm-project/pull/109633


More information about the Mlir-commits mailing list