[Mlir-commits] [mlir] 616c86a - [mlir][drr] Set operand segment in rewrite
Jacques Pienaar
llvmlistbot at llvm.org
Thu Oct 19 13:06:25 PDT 2023
Author: Jacques Pienaar
Date: 2023-10-19T13:06:17-07:00
New Revision: 616c86accbf4c9ada37da6fb6b04554dec0fffee
URL: https://github.com/llvm/llvm-project/commit/616c86accbf4c9ada37da6fb6b04554dec0fffee
DIFF: https://github.com/llvm/llvm-project/commit/616c86accbf4c9ada37da6fb6b04554dec0fffee.diff
LOG: [mlir][drr] Set operand segment in rewrite
This adds basic support for variadic operands in rewrite patterns. Some workarounds were previously employed (such as "aliasing" the attribute). I couldn't find a way to do this directly with properties.
Added:
Modified:
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index edb63924b3553f2..1add9bd3c329438 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2495,6 +2495,29 @@ def TestDefaultStrAttrHasValueOp : TEST_Op<"has_str_value"> {
def : Pat<(TestDefaultStrAttrNoValueOp $value),
(TestDefaultStrAttrHasValueOp ConstantStrAttr<StrAttr, "foo">)>;
+//===----------------------------------------------------------------------===//
+// Test Ops with variadics
+//===----------------------------------------------------------------------===//
+
+def TestVariadicRewriteSrcOp : TEST_Op<"variadic_rewrite_src_op", [AttrSizedOperandSegments]> {
+ let arguments = (ins
+ Variadic<AnyType>:$arg,
+ AnyType:$brg,
+ Variadic<AnyType>:$crg
+ );
+}
+
+def TestVariadicRewriteDstOp : TEST_Op<"variadic_rewrite_dst_op", [AttrSizedOperandSegments]> {
+ let arguments = (ins
+ AnyType:$brg,
+ Variadic<AnyType>:$crg,
+ Variadic<AnyType>:$arg
+ );
+}
+
+def : Pat<(TestVariadicRewriteSrcOp $arg, $brg, $crg),
+ (TestVariadicRewriteDstOp $brg, $crg, $arg)>;
+
//===----------------------------------------------------------------------===//
// Test Ops with Default-Valued Attributes and Differing Print Settings
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5f776338bd40be8..7f9c450f15b21e8 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -683,3 +683,16 @@ func.func @testConstantStrAttr() -> () {
test.no_str_value {value = "bar"}
return
}
+
+//===----------------------------------------------------------------------===//
+// Test that patterns with variadics propagate sizes
+//===----------------------------------------------------------------------===//
+
+func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64,
+ %crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () {
+ // CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 4, 2>}> : (i64, f32, f32, f32, f32, i32, i32) -> ()
+ "test.variadic_rewrite_src_op"(%arg_0, %arg_1, %brg,
+ %crg_0, %crg_1, %crg_2, %crg_3) {operandSegmentSizes = array<i32: 2, 1, 4>} :
+ (i32, i32, i64, f32, f32, f32, f32) -> ()
+ return
+}
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 9f36a3b430274a3..77c34cb03e987ea 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1743,10 +1743,15 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
"if (auto tmpAttr = {1}) {\n"
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
"tmpAttr);\n}\n";
+ int numVariadic = 0;
+ bool hasOperandSegmentSizes = false;
+ std::vector<std::string> sizes;
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
+ hasOperandSegmentSizes =
+ hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
if (auto subTree = node.getArgAsNestedDag(argIndex)) {
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
@@ -1766,6 +1771,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
std::string varName;
if (operand->isVariadic()) {
+ ++numVariadic;
std::string range;
if (node.isNestedDagArg(argIndex)) {
range = childNodeNames.lookup(argIndex);
@@ -1777,7 +1783,9 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
range = symbolInfoMap.getValueAndRangeUse(range);
os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
range);
+ sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range));
} else {
+ sizes.push_back("1");
os << formatv("tblgen_values.push_back(");
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(
@@ -1804,6 +1812,19 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
os << ");\n";
}
}
+
+ if (numVariadic > 1 && !hasOperandSegmentSizes) {
+ // Only set size if it can't be computed.
+ const auto *sameVariadicSize =
+ resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
+ if (!sameVariadicSize) {
+ const char *setSizes = R"(
+ tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+ rewriter.getDenseI32ArrayAttr({{ {0} }));
+ )";
+ os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ }
+ }
}
StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
More information about the Mlir-commits mailing list