[Mlir-commits] [mlir] [mlir][drr] Allow variadic in rewrite side (PR #93340)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 24 13:00:40 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

Allows the `variadic` directive on the rewrite (result) side of a pattern, making it easier to write patterns whose result creates an op with variadic operands.

Signed-off-by: Jacques Pienaar <jpienaar@<!-- -->google.com>

---
Full diff: https://github.com/llvm/llvm-project/pull/93340.diff


3 Files Affected:

- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+6) 
- (modified) mlir/test/mlir-tblgen/pattern.mlir (+8) 
- (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+27) 


``````````diff
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c5d0341b7de77..faf70ad91b06b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1696,6 +1696,12 @@ def : Pat<
   (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
                     ConstantStrAttr<StrAttr, "MatchVariadic">)>;
 
+def : Pat<
+  (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+                    ConstantStrAttr<StrAttr, "MatchInverseVariadic">),
+  (MixedVOperandOp3 $input2, (variadic $input1b), (variadic $input1a),
+                    ConstantAttr<I32Attr, "1">:$attr1)>;
+
 def : Pat<
   (MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a),
                               (MixedVOperandInOutI32Op $input1b)),
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 7f9c450f15b21..6b510abb93294 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -527,6 +527,14 @@ func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
   return
 }
 
+// CHECK-LABEL: @testReplaceVariadic
+func.func @testReplaceVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+  // CHECK: "test.mixed_variadic_in3"(%arg2, %arg1, %arg0) <{count = 1 : i32}>
+  "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchInverseVariadic"}> : (i32, i32, i32) -> ()
+
+  return
+}
+
 // CHECK-LABEL: @testMatchVariadicSubDag
 func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
   // CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index e63a065a07084..d8e16d98fd756 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -159,6 +159,10 @@ class PatternEmitter {
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
 
+  // Emits the C++ statement to replace the matched DAG with an array of
+  // matched values.
+  std::string handleVariadic(DagNode tree, int depth);
+
   // Trailing directives are used at the end of DAG node argument lists to
   // specify additional behaviour for op matchers and creators, etc.
   struct TrailingDirectives {
@@ -1241,6 +1245,9 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   if (resultTree.isReplaceWithValue())
     return handleReplaceWithValue(resultTree).str();
 
+  if (resultTree.isVariadic())
+    return handleVariadic(resultTree, depth);
+
   // Normal op creation.
   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
   if (resultTree.getSymbol().empty()) {
@@ -1251,6 +1258,26 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   return symbol;
 }
 
+std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
+  assert(tree.isVariadic());
+
+  auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++));
+  symbolInfoMap.bindValue(name);
+  os << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    if (auto child = tree.getArgAsNestedDag(i)) {
+      os << name << ".push_back(" << handleResultPattern(child, i, depth + 1)
+         << ");\n";
+    } else {
+      os << name << ".push_back("
+         << handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))
+         << ");\n";
+    }
+  }
+
+  return name;
+}
+
 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
   assert(tree.isReplaceWithValue());
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/93340


More information about the Mlir-commits mailing list