[llvm-branch-commits] [mlir] [mlir][ODS] Switch declarative rewrite rules to properties structs (PR #124876)

Krzysztof Drewniak via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 5 22:31:11 PST 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/124876

>From 5dc001f21e663d3a2e9dfdaa46b29a8731d21af9 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 20:25:38 -0800
Subject: [PATCH 1/2] [mlir][ODS] Switch declarative rewrite rules to
 properties structs

Now that we have collective builders that take
`const [RelevantOp]::Properties &` arguments, we no longer need to serialize
all the attributes that an output pattern sets into a dictionary attribute.
Similarly, when matching, we can read the attributes directly from the
properties struct instead of going through the big if statement in
getAttrOfType<>().

This also paves the way for declarative rewrite rules that match non-attribute
properties, which can be added in a future PR.

This commit also adds a basic test for the generated matchers, since there
did not appear to be one already.
---
 .../rewriter-attributes-properties.td         | 47 +++++++++++
 mlir/tools/mlir-tblgen/RewriterGen.cpp        | 81 +++++++++++++------
 2 files changed, 105 insertions(+), 23 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/rewriter-attributes-properties.td

diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
new file mode 100644
index 000000000000000..77869d36cc12ee4
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -0,0 +1,47 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+class NS_Op<string mnemonic, list<Trait> traits> :
+    Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+  let arguments = (ins
+    I32:$x,
+    I32Attr:$y
+  );
+
+  let results = (outs I32:$z);
+}
+
+def BOp : NS_Op<"b_op", []> {
+  let arguments = (ins
+    I32Attr:$y
+  );
+
+  let results = (outs I32:$z);
+}
+
+def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
+// CHECK-LABEL: struct test1
+// CHECK: ::llvm::LogicalResult matchAndRewrite
+// CHECK: ::mlir::IntegerAttr y;
+// CHECK: test::BOp x;
+// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK: tblgen_ops.push_back(op0);
+// CHECK: x = castedOp1;
+// CHECK: tblgen_attr = castedOp1.getProperties().getY();
+// CHECK: if (!(tblgen_attr))
+// CHECK: y = tblgen_attr;
+// CHECK: tblgen_ops.push_back(op1);
+
+// CHECK: test::AOp tblgen_AOp_0;
+// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
+// CHECK: test::AOp::Properties tblgen_props;
+// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
+// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
+// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a041c4d3277798d..9d8d20798dc8db3 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -122,7 +122,7 @@ class PatternEmitter {
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+  void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
                           int depth);
 
   // Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                        /*variadicSubIndex=*/std::nullopt);
       ++nextOperand;
     } else if (isa<NamedAttribute *>(opArg)) {
-      emitAttributeMatch(tree, opName, opArgIdx, depth);
+      emitAttributeMatch(tree, castedName, opArgIdx, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
   os.unindent() << "}\n";
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
   const auto &attr = namedAttr->attr;
 
   os << "{\n";
-  os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
-                         "(void)tblgen_attr;\n",
-                         opName, attr.getStorageType(), namedAttr->name);
+  if (op.getDialect().usePropertiesForAttributes()) {
+    os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n",
+                           castedName, op.getGetterName(namedAttr->name));
+  } else {
+    os.indent() << formatv(
+        "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
+        "(void)tblgen_attr;\n",
+        castedName, attr.getStorageType(), namedAttr->name);
+  }
 
   // TODO: This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
     // That is precisely what getDiscardableAttr() returns on missing
     // attributes.
   } else {
-    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
+    emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
                    formatv("\"expected op '{0}' to have attribute '{1}' "
                            "of type '{2}'\"",
                            op.getOperationName(), namedAttr->name,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
       }
     }
     emitStaticVerifierCall(
-        verifier, opName, "tblgen_attr",
+        verifier, castedName, "tblgen_attr",
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                 "'{2}'\"",
                 op.getOperationName(), namedAttr->name,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
   Operator &resultOp = tree.getDialectOp(opMap);
+  bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
   auto numOpArgs = resultOp.getNumArgs();
   auto numPatArgs = tree.getNumArgs();
 
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
 
     // Then create the op.
-    os.scope("", "\n}\n").os << formatv(
-        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
-        valuePackName, resultOp.getQualCppClassName(), locToUse);
+    os.scope("", "\n}\n").os
+        << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
+                   valuePackName, resultOp.getQualCppClassName(), locToUse,
+                   useProperties ? "tblgen_props" : "tblgen_attrs");
     return resultValue;
   }
 
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     }
   }
   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
-                "tblgen_values, tblgen_attrs);\n",
-                valuePackName, resultOp.getQualCppClassName(), locToUse);
+                "tblgen_values, {3});\n",
+                valuePackName, resultOp.getQualCppClassName(), locToUse,
+                useProperties ? "tblgen_props" : "tblgen_attrs");
   os.unindent() << "}\n";
   return resultValue;
 }
@@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
 
+  bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
   auto scope = os.scope();
   os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
                 "tblgen_values; (void)tblgen_values;\n");
-  os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
-                "tblgen_attrs; (void)tblgen_attrs;\n");
+  if (useProperties) {
+    os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
+                  resultOp.getQualCppClassName());
+  } else {
+    os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
+                  "tblgen_attrs; (void)tblgen_attrs;\n");
+  }
 
+  const char *setPropCmd =
+      "tblgen_props.{0} = "
+      "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
   const char *addAttrCmd =
       "if (auto tmpAttr = {1}) {\n"
       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
@@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
         if (!subTree.isNativeCodeCall())
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");
-        os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+
+        if (useProperties) {
+          os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
+        } else {
+          os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+        }
       } else {
         auto leaf = node.getArgAsLeaf(argIndex);
         // The argument in the result DAG pattern.
         auto patArgName = node.getArgName(argIndex);
-        os << formatv(addAttrCmd, opArgName,
-                      handleOpArgument(leaf, patArgName));
+        if (useProperties) {
+          os << formatv(setPropCmd, opArgName,
+                        handleOpArgument(leaf, patArgName));
+        } else {
+          os << formatv(addAttrCmd, opArgName,
+                        handleOpArgument(leaf, patArgName));
+        }
       }
       continue;
     }
@@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     const auto *sameVariadicSize =
         resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
     if (!sameVariadicSize) {
-      const char *setSizes = R"(
-        tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
-          rewriter.getDenseI32ArrayAttr({{ {0} }));
-          )";
-      os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      if (useProperties) {
+        const char *setSizes = R"(
+          tblgen_props.operandSegmentSizes = {{ {0} };
+        )";
+        os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      } else {
+        const char *setSizes = R"(
+          tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+            rewriter.getDenseI32ArrayAttr({{ {0} }));
+            )";
+        os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      }
     }
   }
 }

>From 4e1dc86ab5743cf896b1d0abc2ded8c4540c86fa Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 21:10:01 -0800
Subject: [PATCH 2/2] Test fails on Windows, try to loosen it up

---
 mlir/test/mlir-tblgen/rewriter-attributes-properties.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
index 77869d36cc12ee4..fc36a51789ec28d 100644
--- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -29,9 +29,9 @@ def BOp : NS_Op<"b_op", []> {
 def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
 // CHECK-LABEL: struct test1
 // CHECK: ::llvm::LogicalResult matchAndRewrite
-// CHECK: ::mlir::IntegerAttr y;
-// CHECK: test::BOp x;
-// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK-DAG: ::mlir::IntegerAttr y;
+// CHECK-DAG: test::BOp x;
+// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
 // CHECK: tblgen_ops.push_back(op0);
 // CHECK: x = castedOp1;
 // CHECK: tblgen_attr = castedOp1.getProperties().getY();



More information about the llvm-branch-commits mailing list