[Mlir-commits] [mlir] [mlir][ODS] Switch declarative rewrite rules to properties structs (PR #124876)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Mar 10 19:57:45 PDT 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/124876

>From 9b2288bde8bb68e6900a5c41d2b5162c6e50fae8 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 20:25:38 -0800
Subject: [PATCH 1/3] [mlir][ODS] Switch declarative rewrite rules to
 properties structs

Now that we have collective builders that take
`const [RelevantOp]::Properties &` arguments, we don't need to serialize
all the attributes that'll be set during an output pattern into a dictionary
attribute. Similarly, we can use the properties struct to get the attributes
instead of needing to go through the big if statement in getAttrOfType<>().

This also enables us to have declarative rewrite rules that match non-attribute
properties in a future PR.

This commit also adds a basic test for the generated matchers since there
didn't seem to already be one.
---
 .../rewriter-attributes-properties.td         | 47 +++++++++++
 mlir/tools/mlir-tblgen/RewriterGen.cpp        | 81 +++++++++++++------
 2 files changed, 105 insertions(+), 23 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/rewriter-attributes-properties.td

diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
new file mode 100644
index 0000000000000..77869d36cc12e
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -0,0 +1,47 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+class NS_Op<string mnemonic, list<Trait> traits> :
+    Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+  let arguments = (ins
+    I32:$x,
+    I32Attr:$y
+  );
+
+  let results = (outs I32:$z);
+}
+
+def BOp : NS_Op<"b_op", []> {
+  let arguments = (ins
+    I32Attr:$y
+  );
+
+  let results = (outs I32:$z);
+}
+
+def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
+// CHECK-LABEL: struct test1
+// CHECK: ::llvm::LogicalResult matchAndRewrite
+// CHECK: ::mlir::IntegerAttr y;
+// CHECK: test::BOp x;
+// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK: tblgen_ops.push_back(op0);
+// CHECK: x = castedOp1;
+// CHECK: tblgen_attr = castedOp1.getProperties().getY();
+// CHECK: if (!(tblgen_attr))
+// CHECK: y = tblgen_attr;
+// CHECK: tblgen_ops.push_back(op1);
+
+// CHECK: test::AOp tblgen_AOp_0;
+// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
+// CHECK: test::AOp::Properties tblgen_props;
+// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
+// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
+// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index f6eb5bdfe568e..5dd4f87a6d0ce 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -122,7 +122,7 @@ class PatternEmitter {
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+  void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
                           int depth);
 
   // Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                        /*variadicSubIndex=*/std::nullopt);
       ++nextOperand;
     } else if (isa<NamedAttribute *>(opArg)) {
-      emitAttributeMatch(tree, opName, opArgIdx, depth);
+      emitAttributeMatch(tree, castedName, opArgIdx, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
   os.unindent() << "}\n";
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
   const auto &attr = namedAttr->attr;
 
   os << "{\n";
-  os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
-                         "(void)tblgen_attr;\n",
-                         opName, attr.getStorageType(), namedAttr->name);
+  // Keep `(void)tblgen_attr;` on both paths; the value may be unread below.
+  if (op.getDialect().usePropertiesForAttributes())
+    os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();"
+                           "(void)tblgen_attr;\n",
+                           castedName, op.getGetterName(namedAttr->name));
+  else
+    os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>"
+                           "(\"{2}\");(void)tblgen_attr;\n",
+                           castedName, attr.getStorageType(), namedAttr->name);
 
   // TODO: This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
     // That is precisely what getDiscardableAttr() returns on missing
     // attributes.
   } else {
-    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
+    emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
                    formatv("\"expected op '{0}' to have attribute '{1}' "
                            "of type '{2}'\"",
                            op.getOperationName(), namedAttr->name,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
       }
     }
     emitStaticVerifierCall(
-        verifier, opName, "tblgen_attr",
+        verifier, castedName, "tblgen_attr",
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                 "'{2}'\"",
                 op.getOperationName(), namedAttr->name,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
   Operator &resultOp = tree.getDialectOp(opMap);
+  bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
   auto numOpArgs = resultOp.getNumArgs();
   auto numPatArgs = tree.getNumArgs();
 
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
 
     // Then create the op.
-    os.scope("", "\n}\n").os << formatv(
-        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
-        valuePackName, resultOp.getQualCppClassName(), locToUse);
+    os.scope("", "\n}\n").os
+        << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
+                   valuePackName, resultOp.getQualCppClassName(), locToUse,
+                   useProperties ? "tblgen_props" : "tblgen_attrs");
     return resultValue;
   }
 
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     }
   }
   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
-                "tblgen_values, tblgen_attrs);\n",
-                valuePackName, resultOp.getQualCppClassName(), locToUse);
+                "tblgen_values, {3});\n",
+                valuePackName, resultOp.getQualCppClassName(), locToUse,
+                useProperties ? "tblgen_props" : "tblgen_attrs");
   os.unindent() << "}\n";
   return resultValue;
 }
@@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
 
+  bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
   auto scope = os.scope();
   os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
                 "tblgen_values; (void)tblgen_values;\n");
-  os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
-                "tblgen_attrs; (void)tblgen_attrs;\n");
+  if (useProperties) {
+    os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
+                  resultOp.getQualCppClassName());
+  } else {
+    os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
+                  "tblgen_attrs; (void)tblgen_attrs;\n");
+  }
 
+  const char *setPropCmd =
+      "tblgen_props.{0} = "
+      "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
   const char *addAttrCmd =
       "if (auto tmpAttr = {1}) {\n"
       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
@@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
         if (!subTree.isNativeCodeCall())
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");
-        os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+
+        if (useProperties) {
+          os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
+        } else {
+          os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+        }
       } else {
         auto leaf = node.getArgAsLeaf(argIndex);
         // The argument in the result DAG pattern.
         auto patArgName = node.getArgName(argIndex);
-        os << formatv(addAttrCmd, opArgName,
-                      handleOpArgument(leaf, patArgName));
+        if (useProperties) {
+          os << formatv(setPropCmd, opArgName,
+                        handleOpArgument(leaf, patArgName));
+        } else {
+          os << formatv(addAttrCmd, opArgName,
+                        handleOpArgument(leaf, patArgName));
+        }
       }
       continue;
     }
@@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     const auto *sameVariadicSize =
         resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
     if (!sameVariadicSize) {
-      const char *setSizes = R"(
-        tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
-          rewriter.getDenseI32ArrayAttr({{ {0} }));
-          )";
-      os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      if (useProperties) {
+        const char *setSizes = R"(
+          tblgen_props.operandSegmentSizes = {{ {0} };
+        )";
+        os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      } else {
+        const char *setSizes = R"(
+          tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+            rewriter.getDenseI32ArrayAttr({{ {0} }));
+            )";
+        os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      }
     }
   }
 }

>From dbd8b51f0d27e3f2ba035c0dc2b27a95cdfea806 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 21:10:01 -0800
Subject: [PATCH 2/3] Loosen test ordering with CHECK-DAG; it fails on Windows

---
 mlir/test/mlir-tblgen/rewriter-attributes-properties.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
index 77869d36cc12e..fc36a51789ec2 100644
--- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -29,9 +29,9 @@ def BOp : NS_Op<"b_op", []> {
 def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
 // CHECK-LABEL: struct test1
 // CHECK: ::llvm::LogicalResult matchAndRewrite
-// CHECK: ::mlir::IntegerAttr y;
-// CHECK: test::BOp x;
-// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK-DAG: ::mlir::IntegerAttr y;
+// CHECK-DAG: test::BOp x;
+// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
 // CHECK: tblgen_ops.push_back(op0);
 // CHECK: x = castedOp1;
 // CHECK: tblgen_attr = castedOp1.getProperties().getY();

>From 0ce3cecc51d79cc187a3afcd3cdccc4b7cff903b Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Mon, 10 Mar 2025 19:57:27 -0700
Subject: [PATCH 3/3] Review suggestions: select the attribute setter format once

---
 mlir/tools/mlir-tblgen/RewriterGen.cpp | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 5dd4f87a6d0ce..f921788abdd71 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1819,6 +1819,8 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
       "if (auto tmpAttr = {1}) {\n"
       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
       "tmpAttr);\n}\n";
+  const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;
+
   int numVariadic = 0;
   bool hasOperandSegmentSizes = false;
   std::vector<std::string> sizes;
@@ -1833,22 +1835,12 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");
 
-        if (useProperties) {
-          os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
-        } else {
-          os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
-        }
+        os << formatv(setterCmd, opArgName, childNodeNames.lookup(argIndex));
       } else {
         auto leaf = node.getArgAsLeaf(argIndex);
         // The argument in the result DAG pattern.
         auto patArgName = node.getArgName(argIndex);
-        if (useProperties) {
-          os << formatv(setPropCmd, opArgName,
-                        handleOpArgument(leaf, patArgName));
-        } else {
-          os << formatv(addAttrCmd, opArgName,
-                        handleOpArgument(leaf, patArgName));
-        }
+        os << formatv(setterCmd, opArgName, handleOpArgument(leaf, patArgName));
       }
       continue;
     }



More information about the Mlir-commits mailing list