[Mlir-commits] [mlir] [mlir][drr] Allow variadic in rewrite side (PR #93340)
Jacques Pienaar
llvmlistbot at llvm.org
Fri May 24 13:00:11 PDT 2024
https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/93340
Allows the `variadic` directive on the rewrite (result) side of DRR patterns, making it easier to write patterns whose result creates an op with variadic operands.
Signed-off-by: Jacques Pienaar <jpienaar at google.com>
>From ab6cd1a11f8e923d06884568b9a4b51f10fb7f16 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 24 May 2024 19:10:51 +0000
Subject: [PATCH] [mlir][drr] Allow variadic in rewrite side
Allows the `variadic` directive on the rewrite (result) side of DRR patterns, making it easier to write patterns whose result creates an op with variadic operands.
Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
mlir/test/lib/Dialect/Test/TestOps.td | 6 ++++++
mlir/test/mlir-tblgen/pattern.mlir | 8 ++++++++
mlir/tools/mlir-tblgen/RewriterGen.cpp | 27 ++++++++++++++++++++++++++
3 files changed, 41 insertions(+)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c5d0341b7de77..faf70ad91b06b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1696,6 +1696,12 @@ def : Pat<
(MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
ConstantStrAttr<StrAttr, "MatchVariadic">)>;
+def : Pat<  // `variadic` on the result side; $input1a/$input1b are swapped.
+ (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+ ConstantStrAttr<StrAttr, "MatchInverseVariadic">),
+ (MixedVOperandOp3 $input2, (variadic $input1b), (variadic $input1a),
+ ConstantAttr<I32Attr, "1">:$attr1)>;
+
def : Pat<
(MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a),
(MixedVOperandInOutI32Op $input1b)),
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 7f9c450f15b21..6b510abb93294 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -527,6 +527,14 @@ func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
return
}
+// CHECK-LABEL: @testReplaceVariadic
+func.func @testReplaceVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+  // MatchInverseVariadic rebuilds the op with `variadic` operand groups.
+  // CHECK: "test.mixed_variadic_in3"(%arg2, %arg1, %arg0) <{count = 1 : i32}>
+  "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchInverseVariadic"}> : (i32, i32, i32) -> ()
+  return
+}
+
// CHECK-LABEL: @testMatchVariadicSubDag
func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
// CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index e63a065a07084..d8e16d98fd756 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -159,6 +159,10 @@ class PatternEmitter {
// Returns the symbol of the old value serving as the replacement.
StringRef handleReplaceWithValue(DagNode tree);
+ // Emits a local SmallVector of Values built from the arguments of the
+ // `variadic` result DAG `tree`, and returns the name of that local.
+ std::string handleVariadic(DagNode tree, int depth);
+
// Trailing directives are used at the end of DAG node argument lists to
// specify additional behaviour for op matchers and creators, etc.
struct TrailingDirectives {
@@ -1241,6 +1245,9 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree).str();
+ if (resultTree.isVariadic())
+ return handleVariadic(resultTree, depth);
+
// Normal op creation.
auto symbol = handleOpCreation(resultTree, resultIndex, depth);
if (resultTree.getSymbol().empty()) {
@@ -1251,6 +1258,32 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
return symbol;
}
+// Emits a local `::llvm::SmallVector<::mlir::Value, 4>` and appends to it, in
+// order, the value of each argument of the `variadic` result DAG `tree`:
+// nested DAGs are generated by handleResultPattern() at `depth + 1` and leaves
+// by handleOpArgument(). Returns the local's name (bound in symbolInfoMap).
+std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
+  assert(tree.isVariadic());
+
+  auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++));
+  symbolInfoMap.bindValue(name);
+  os << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    // Compute the element before emitting the push_back: for a nested DAG,
+    // handleResultPattern() may itself write statements (e.g. op creation) to
+    // `os`. C++17 sequences the left operand of `<<` first, so calling it
+    // inside the push_back chain would splice those statements into it.
+    std::string element;
+    if (auto child = tree.getArgAsNestedDag(i))
+      element = handleResultPattern(child, i, depth + 1);
+    else
+      element = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+    os << name << ".push_back(" << element << ");\n";
+  }
+
+  return name;
+}
+
StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
assert(tree.isReplaceWithValue());
More information about the Mlir-commits
mailing list