[Mlir-commits] [mlir] c26847d - [mlir][drr] Allow variadic in rewrite side (#93340)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 24 15:53:29 PDT 2024


Author: Jacques Pienaar
Date: 2024-05-24T15:53:25-07:00
New Revision: c26847dc814a67527c6395a440abf7e1178ffe40

URL: https://github.com/llvm/llvm-project/commit/c26847dc814a67527c6395a440abf7e1178ffe40
DIFF: https://github.com/llvm/llvm-project/commit/c26847dc814a67527c6395a440abf7e1178ffe40.diff

LOG: [mlir][drr] Allow variadic in rewrite side (#93340)

Allows the `variadic` directive in result patterns, making it easier to
write patterns that create ops with variadic operands.

Signed-off-by: Jacques Pienaar <jpienaar at google.com>

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c5d0341b7de77..faf70ad91b06b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1696,6 +1696,12 @@ def : Pat<
   (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
                     ConstantStrAttr<StrAttr, "MatchVariadic">)>;
 
+def : Pat<
+  (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+                    ConstantStrAttr<StrAttr, "MatchInverseVariadic">),
+  (MixedVOperandOp3 $input2, (variadic $input1b), (variadic $input1a),
+                    ConstantAttr<I32Attr, "1">:$attr1)>;
+
 def : Pat<
   (MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a),
                               (MixedVOperandInOutI32Op $input1b)),

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 7f9c450f15b21..5ff8710b93770 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -527,6 +527,14 @@ func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
   return
 }
 
+// CHECK-LABEL: @testReplaceVariadic
+func.func @testReplaceVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+  // CHECK: "test.mixed_variadic_in3"(%arg2, %arg1, %arg0) <{count = 1 : i32}>
+  "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchInverseVariadic"}> : (i32, i32, i32) -> ()
+
+  return
+}
+
 // CHECK-LABEL: @testMatchVariadicSubDag
 func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
   // CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index e63a065a07084..d8e16d98fd756 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -159,6 +159,10 @@ class PatternEmitter {
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
 
+  // Emits C++ statements that collect the values of a `variadic` result DAG
+  // node into a SmallVector and returns the name of that vector.
+  std::string handleVariadic(DagNode tree, int depth);
+
   // Trailing directives are used at the end of DAG node argument lists to
   // specify additional behaviour for op matchers and creators, etc.
   struct TrailingDirectives {
@@ -1241,6 +1245,9 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   if (resultTree.isReplaceWithValue())
     return handleReplaceWithValue(resultTree).str();
 
+  if (resultTree.isVariadic())
+    return handleVariadic(resultTree, depth);
+
   // Normal op creation.
   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
   if (resultTree.getSymbol().empty()) {
@@ -1251,6 +1258,26 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   return symbol;
 }
 
+std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
+  assert(tree.isVariadic());
+
+  auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++));
+  symbolInfoMap.bindValue(name);
+  os << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    if (auto child = tree.getArgAsNestedDag(i)) {
+      os << name << ".push_back(" << handleResultPattern(child, i, depth + 1)
+         << ");\n";
+    } else {
+      os << name << ".push_back("
+         << handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))
+         << ");\n";
+    }
+  }
+
+  return name;
+}
+
 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
   assert(tree.isReplaceWithValue());
 


        


More information about the Mlir-commits mailing list