[llvm-branch-commits] [mlir] [mlir][ODS] Switch declarative rewrite rules to properties structs (PR #124876)
Krzysztof Drewniak via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Feb 5 22:31:11 PST 2025
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/124876
>From 5dc001f21e663d3a2e9dfdaa46b29a8731d21af9 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 20:25:38 -0800
Subject: [PATCH 1/2] [mlir][ODS] Switch declarative rewrite rules to
properties structs
Now that we have collective builders that take
`const [RelevantOp]::Properties &` arguments, we no longer need to serialize
all the attributes that will be set by an output pattern into a dictionary
attribute. Similarly, we can read the attributes from the properties struct
instead of going through the large if statement in getAttrOfType<>().
This also enables declarative rewrite rules that match non-attribute
properties in a future PR.
This commit also adds a basic test for the generated matchers, since there
did not appear to be one already.
---
.../rewriter-attributes-properties.td | 47 +++++++++++
mlir/tools/mlir-tblgen/RewriterGen.cpp | 81 +++++++++++++------
2 files changed, 105 insertions(+), 23 deletions(-)
create mode 100644 mlir/test/mlir-tblgen/rewriter-attributes-properties.td
diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
new file mode 100644
index 000000000000000..77869d36cc12ee4
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -0,0 +1,47 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+}
+class NS_Op<string mnemonic, list<Trait> traits> :
+ Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+ let arguments = (ins
+ I32:$x,
+ I32Attr:$y
+ );
+
+ let results = (outs I32:$z);
+}
+
+def BOp : NS_Op<"b_op", []> {
+ let arguments = (ins
+ I32Attr:$y
+ );
+
+ let results = (outs I32:$z);
+}
+
+def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
+// CHECK-LABEL: struct test1
+// CHECK: ::llvm::LogicalResult matchAndRewrite
+// CHECK: ::mlir::IntegerAttr y;
+// CHECK: test::BOp x;
+// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK: tblgen_ops.push_back(op0);
+// CHECK: x = castedOp1;
+// CHECK: tblgen_attr = castedOp1.getProperties().getY();
+// CHECK: if (!(tblgen_attr))
+// CHECK: y = tblgen_attr;
+// CHECK: tblgen_ops.push_back(op1);
+
+// CHECK: test::AOp tblgen_AOp_0;
+// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
+// CHECK: test::AOp::Properties tblgen_props;
+// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
+// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
+// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a041c4d3277798d..9d8d20798dc8db3 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -122,7 +122,7 @@ class PatternEmitter {
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
- void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+ void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
int depth);
// Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
/*variadicSubIndex=*/std::nullopt);
++nextOperand;
} else if (isa<NamedAttribute *>(opArg)) {
- emitAttributeMatch(tree, opName, opArgIdx, depth);
+ emitAttributeMatch(tree, castedName, opArgIdx, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
os.unindent() << "}\n";
}
-void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
const auto &attr = namedAttr->attr;
os << "{\n";
- os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
- "(void)tblgen_attr;\n",
- opName, attr.getStorageType(), namedAttr->name);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n",
+ castedName, op.getGetterName(namedAttr->name));
+ } else {
+ os.indent() << formatv(
+ "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
+ "(void)tblgen_attr;\n",
+ castedName, attr.getStorageType(), namedAttr->name);
+ }
// TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
// That is precisely what getDiscardableAttr() returns on missing
// attributes.
} else {
- emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
+ emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
formatv("\"expected op '{0}' to have attribute '{1}' "
"of type '{2}'\"",
op.getOperationName(), namedAttr->name,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
}
}
emitStaticVerifierCall(
- verifier, opName, "tblgen_attr",
+ verifier, castedName, "tblgen_attr",
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
"'{2}'\"",
op.getOperationName(), namedAttr->name,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
LLVM_DEBUG(llvm::dbgs() << '\n');
Operator &resultOp = tree.getDialectOp(opMap);
+ bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
// Then create the op.
- os.scope("", "\n}\n").os << formatv(
- "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
- valuePackName, resultOp.getQualCppClassName(), locToUse);
+ os.scope("", "\n}\n").os
+ << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
+ valuePackName, resultOp.getQualCppClassName(), locToUse,
+ useProperties ? "tblgen_props" : "tblgen_attrs");
return resultValue;
}
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
}
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
- "tblgen_values, tblgen_attrs);\n",
- valuePackName, resultOp.getQualCppClassName(), locToUse);
+ "tblgen_values, {3});\n",
+ valuePackName, resultOp.getQualCppClassName(), locToUse,
+ useProperties ? "tblgen_props" : "tblgen_attrs");
os.unindent() << "}\n";
return resultValue;
}
@@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap);
+ bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto scope = os.scope();
os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
"tblgen_values; (void)tblgen_values;\n");
- os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
- "tblgen_attrs; (void)tblgen_attrs;\n");
+ if (useProperties) {
+ os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
+ resultOp.getQualCppClassName());
+ } else {
+ os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
+ "tblgen_attrs; (void)tblgen_attrs;\n");
+ }
+ const char *setPropCmd =
+ "tblgen_props.{0} = "
+ "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
const char *addAttrCmd =
"if (auto tmpAttr = {1}) {\n"
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
@@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
- os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+
+ if (useProperties) {
+ os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
+ } else {
+ os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+ }
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = node.getArgName(argIndex);
- os << formatv(addAttrCmd, opArgName,
- handleOpArgument(leaf, patArgName));
+ if (useProperties) {
+ os << formatv(setPropCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ } else {
+ os << formatv(addAttrCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ }
}
continue;
}
@@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
const auto *sameVariadicSize =
resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
if (!sameVariadicSize) {
- const char *setSizes = R"(
- tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
- rewriter.getDenseI32ArrayAttr({{ {0} }));
- )";
- os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ if (useProperties) {
+ const char *setSizes = R"(
+ tblgen_props.operandSegmentSizes = {{ {0} };
+ )";
+ os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ } else {
+ const char *setSizes = R"(
+ tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+ rewriter.getDenseI32ArrayAttr({{ {0} }));
+ )";
+ os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ }
}
}
}
>From 4e1dc86ab5743cf896b1d0abc2ded8c4540c86fa Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 28 Jan 2025 21:10:01 -0800
Subject: [PATCH 2/2] Test fails on Windows; try to loosen it up
---
mlir/test/mlir-tblgen/rewriter-attributes-properties.td | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
index 77869d36cc12ee4..fc36a51789ec28d 100644
--- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -29,9 +29,9 @@ def BOp : NS_Op<"b_op", []> {
def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
// CHECK-LABEL: struct test1
// CHECK: ::llvm::LogicalResult matchAndRewrite
-// CHECK: ::mlir::IntegerAttr y;
-// CHECK: test::BOp x;
-// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK-DAG: ::mlir::IntegerAttr y;
+// CHECK-DAG: test::BOp x;
+// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
// CHECK: tblgen_ops.push_back(op0);
// CHECK: x = castedOp1;
// CHECK: tblgen_attr = castedOp1.getProperties().getY();
More information about the llvm-branch-commits
mailing list