[Mlir-commits] [mlir] [mlir][ODS] Switch declarative rewrite rules to properties structs (PR #124876)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Mar 10 19:57:45 PDT 2025
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/124876
>From 9b2288bde8bb68e6900a5c41d2b5162c6e50fae8 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 20:25:38 -0800
Subject: [PATCH 1/3] [mlir][ODS] Switch declarative rewrite rules to
properties structs
Now that we have collective builders that take
`const [RelevantOp]::Properties &` arguments, we don't need to serialize
all the attributes that'll be set during an output pattern into a dictionary
attribute. Similarly, we can use the properties struct to get the attributes
instead of needing to go through the big if statement in getAttrOfType<>().
This also enables us to have declarative rewrite rules that match non-attribute
properties in a future PR.
This commit also adds a basic test for the generated matchers, since none
appeared to exist yet.
---
.../rewriter-attributes-properties.td | 47 +++++++++++
mlir/tools/mlir-tblgen/RewriterGen.cpp | 81 +++++++++++++------
2 files changed, 105 insertions(+), 23 deletions(-)
create mode 100644 mlir/test/mlir-tblgen/rewriter-attributes-properties.td
diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
new file mode 100644
index 0000000000000..77869d36cc12e
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -0,0 +1,47 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+}
+class NS_Op<string mnemonic, list<Trait> traits> :
+ Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+ let arguments = (ins
+ I32:$x,
+ I32Attr:$y
+ );
+
+ let results = (outs I32:$z);
+}
+
+def BOp : NS_Op<"b_op", []> {
+ let arguments = (ins
+ I32Attr:$y
+ );
+
+ let results = (outs I32:$z);
+}
+
+def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
+// CHECK-LABEL: struct test1
+// CHECK: ::llvm::LogicalResult matchAndRewrite
+// CHECK: ::mlir::IntegerAttr y;
+// CHECK: test::BOp x;
+// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK: tblgen_ops.push_back(op0);
+// CHECK: x = castedOp1;
+// CHECK: tblgen_attr = castedOp1.getProperties().getY();
+// CHECK: if (!(tblgen_attr))
+// CHECK: y = tblgen_attr;
+// CHECK: tblgen_ops.push_back(op1);
+
+// CHECK: test::AOp tblgen_AOp_0;
+// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
+// CHECK: test::AOp::Properties tblgen_props;
+// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
+// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
+// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index f6eb5bdfe568e..5dd4f87a6d0ce 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -122,7 +122,7 @@ class PatternEmitter {
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
- void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+ void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
int depth);
// Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
/*variadicSubIndex=*/std::nullopt);
++nextOperand;
} else if (isa<NamedAttribute *>(opArg)) {
- emitAttributeMatch(tree, opName, opArgIdx, depth);
+ emitAttributeMatch(tree, castedName, opArgIdx, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
os.unindent() << "}\n";
}
-void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
const auto &attr = namedAttr->attr;
os << "{\n";
- os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
- "(void)tblgen_attr;\n",
- opName, attr.getStorageType(), namedAttr->name);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n",
+ castedName, op.getGetterName(namedAttr->name));
+ } else {
+ os.indent() << formatv(
+ "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
+ "(void)tblgen_attr;\n",
+ castedName, attr.getStorageType(), namedAttr->name);
+ }
// TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
// That is precisely what getDiscardableAttr() returns on missing
// attributes.
} else {
- emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
+ emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
formatv("\"expected op '{0}' to have attribute '{1}' "
"of type '{2}'\"",
op.getOperationName(), namedAttr->name,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
}
}
emitStaticVerifierCall(
- verifier, opName, "tblgen_attr",
+ verifier, castedName, "tblgen_attr",
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
"'{2}'\"",
op.getOperationName(), namedAttr->name,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
LLVM_DEBUG(llvm::dbgs() << '\n');
Operator &resultOp = tree.getDialectOp(opMap);
+ bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
// Then create the op.
- os.scope("", "\n}\n").os << formatv(
- "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
- valuePackName, resultOp.getQualCppClassName(), locToUse);
+ os.scope("", "\n}\n").os
+ << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
+ valuePackName, resultOp.getQualCppClassName(), locToUse,
+ useProperties ? "tblgen_props" : "tblgen_attrs");
return resultValue;
}
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
}
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
- "tblgen_values, tblgen_attrs);\n",
- valuePackName, resultOp.getQualCppClassName(), locToUse);
+ "tblgen_values, {3});\n",
+ valuePackName, resultOp.getQualCppClassName(), locToUse,
+ useProperties ? "tblgen_props" : "tblgen_attrs");
os.unindent() << "}\n";
return resultValue;
}
@@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap);
+ bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto scope = os.scope();
os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
"tblgen_values; (void)tblgen_values;\n");
- os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
- "tblgen_attrs; (void)tblgen_attrs;\n");
+ if (useProperties) {
+ os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
+ resultOp.getQualCppClassName());
+ } else {
+ os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
+ "tblgen_attrs; (void)tblgen_attrs;\n");
+ }
+ const char *setPropCmd =
+ "tblgen_props.{0} = "
+ "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
const char *addAttrCmd =
"if (auto tmpAttr = {1}) {\n"
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
@@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
- os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+
+ if (useProperties) {
+ os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
+ } else {
+ os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+ }
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = node.getArgName(argIndex);
- os << formatv(addAttrCmd, opArgName,
- handleOpArgument(leaf, patArgName));
+ if (useProperties) {
+ os << formatv(setPropCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ } else {
+ os << formatv(addAttrCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ }
}
continue;
}
@@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
const auto *sameVariadicSize =
resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
if (!sameVariadicSize) {
- const char *setSizes = R"(
- tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
- rewriter.getDenseI32ArrayAttr({{ {0} }));
- )";
- os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ if (useProperties) {
+ const char *setSizes = R"(
+ tblgen_props.operandSegmentSizes = {{ {0} };
+ )";
+ os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ } else {
+ const char *setSizes = R"(
+ tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+ rewriter.getDenseI32ArrayAttr({{ {0} }));
+ )";
+ os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ }
}
}
}
>From dbd8b51f0d27e3f2ba035c0dc2b27a95cdfea806 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 21:10:01 -0800
Subject: [PATCH 2/3] Test fails on Windows, try to loosen it up
---
mlir/test/mlir-tblgen/rewriter-attributes-properties.td | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
index 77869d36cc12e..fc36a51789ec2 100644
--- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -29,9 +29,9 @@ def BOp : NS_Op<"b_op", []> {
def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
// CHECK-LABEL: struct test1
// CHECK: ::llvm::LogicalResult matchAndRewrite
-// CHECK: ::mlir::IntegerAttr y;
-// CHECK: test::BOp x;
-// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK-DAG: ::mlir::IntegerAttr y;
+// CHECK-DAG: test::BOp x;
+// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
// CHECK: tblgen_ops.push_back(op0);
// CHECK: x = castedOp1;
// CHECK: tblgen_attr = castedOp1.getProperties().getY();
>From 0ce3cecc51d79cc187a3afcd3cdccc4b7cff903b Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Mon, 10 Mar 2025 19:57:27 -0700
Subject: [PATCH 3/3] Review suggestions
---
mlir/tools/mlir-tblgen/RewriterGen.cpp | 16 ++++------------
1 file changed, 4 insertions(+), 12 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 5dd4f87a6d0ce..f921788abdd71 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1819,6 +1819,8 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
"if (auto tmpAttr = {1}) {\n"
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
"tmpAttr);\n}\n";
+ const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;
+
int numVariadic = 0;
bool hasOperandSegmentSizes = false;
std::vector<std::string> sizes;
@@ -1833,22 +1835,12 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
- if (useProperties) {
- os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
- } else {
- os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
- }
+ os << formatv(setterCmd, opArgName, childNodeNames.lookup(argIndex));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = node.getArgName(argIndex);
- if (useProperties) {
- os << formatv(setPropCmd, opArgName,
- handleOpArgument(leaf, patArgName));
- } else {
- os << formatv(addAttrCmd, opArgName,
- handleOpArgument(leaf, patArgName));
- }
+ os << formatv(setterCmd, opArgName, handleOpArgument(leaf, patArgName));
}
continue;
}
More information about the Mlir-commits
mailing list