[Mlir-commits] [mlir] [mlir][drr] Allow variadic in rewrite side (PR #93340)

Jacques Pienaar llvmlistbot at llvm.org
Fri May 24 15:33:22 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/93340

>From ab6cd1a11f8e923d06884568b9a4b51f10fb7f16 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 24 May 2024 19:10:51 +0000
Subject: [PATCH 1/2] [mlir][drr] Allow variadic in rewrite side

Makes it easier to write patterns whose result pattern creates an op with variadic operands, by allowing the `variadic` directive on the rewrite side.

Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
 mlir/test/lib/Dialect/Test/TestOps.td  |  6 ++++++
 mlir/test/mlir-tblgen/pattern.mlir     |  8 ++++++++
 mlir/tools/mlir-tblgen/RewriterGen.cpp | 27 ++++++++++++++++++++++++++
 3 files changed, 41 insertions(+)

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c5d0341b7de77..faf70ad91b06b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1696,6 +1696,12 @@ def : Pat<
   (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
                     ConstantStrAttr<StrAttr, "MatchVariadic">)>;
 
+def : Pat<
+  (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+                    ConstantStrAttr<StrAttr, "MatchInverseVariadic">),
+  (MixedVOperandOp3 $input2, (variadic $input1b), (variadic $input1a),
+                    ConstantAttr<I32Attr, "1">:$attr1)>;
+
 def : Pat<
   (MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a),
                               (MixedVOperandInOutI32Op $input1b)),
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 7f9c450f15b21..6b510abb93294 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -527,6 +527,14 @@ func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
   return
 }
 
+// CHECK-LABEL: @testReplaceVariadic
+func.func @testReplaceVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+  // CHECK" "test.mixed_variadic_in3"(%arg2, %arg1, %arg0) <{count = 1 : i32}>
+  "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchInverseVariadic"}> : (i32, i32, i32) -> ()
+
+  return
+}
+
 // CHECK-LABEL: @testMatchVariadicSubDag
 func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
   // CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index e63a065a07084..d8e16d98fd756 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -159,6 +159,10 @@ class PatternEmitter {
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
 
+  // Emits C++ statements that collect each element of a `variadic` result DAG
+  // into a SmallVector<Value>; returns the name of that vector.
+  std::string handleVariadic(DagNode tree, int depth);
+
   // Trailing directives are used at the end of DAG node argument lists to
   // specify additional behaviour for op matchers and creators, etc.
   struct TrailingDirectives {
@@ -1241,6 +1245,9 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   if (resultTree.isReplaceWithValue())
     return handleReplaceWithValue(resultTree).str();
 
+  if (resultTree.isVariadic())
+    return handleVariadic(resultTree, depth);
+
   // Normal op creation.
   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
   if (resultTree.getSymbol().empty()) {
@@ -1251,6 +1258,26 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   return symbol;
 }
 
+// Emits a SmallVector<Value> holding every element of a `variadic` result DAG
+// and returns the vector's name as the symbol for this result pattern.
+std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
+  assert(tree.isVariadic());
+  auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++));
+  symbolInfoMap.bindValue(name);
+  os << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    // Compute the element first: a nested DAG writes its op-creation code to
+    // `os`, which must not end up inside the `push_back(` emitted below.
+    auto child = tree.getArgAsNestedDag(i);
+    std::string value =
+        child ? handleResultPattern(child, i, depth + 1)
+              : handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+    os << name << ".push_back(" << value << ");\n";
+  }
+
+  return name;
+}
+
 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
   assert(tree.isReplaceWithValue());
 

>From 093965534d3e26e382d91e49397cd318785ee535 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 24 May 2024 22:33:11 +0000
Subject: [PATCH 2/2] Fix malformed CHECK directive in testReplaceVariadic

Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
 mlir/test/mlir-tblgen/pattern.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 6b510abb93294..5ff8710b93770 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -529,7 +529,7 @@ func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
 
 // CHECK-LABEL: @testReplaceVariadic
 func.func @testReplaceVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
-  // CHECK" "test.mixed_variadic_in3"(%arg2, %arg1, %arg0) <{count = 1 : i32}>
+  // CHECK: "test.mixed_variadic_in3"(%arg2, %arg1, %arg0) <{count = 1 : i32}>
   "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchInverseVariadic"}> : (i32, i32, i32) -> ()
 
   return



More information about the Mlir-commits mailing list