[Mlir-commits] [mlir] 4bba8bd - [mlir] Add RewriterBase::replaceAllUsesWith for Blocks.

Ingo Müller llvmlistbot at llvm.org
Tue Feb 14 23:23:27 PST 2023


Author: Ingo Müller
Date: 2023-02-15T07:23:21Z
New Revision: 4bba8bd33efdfb43c840e2dfa7ef5263254facdb

URL: https://github.com/llvm/llvm-project/commit/4bba8bd33efdfb43c840e2dfa7ef5263254facdb
DIFF: https://github.com/llvm/llvm-project/commit/4bba8bd33efdfb43c840e2dfa7ef5263254facdb.diff

LOG: [mlir] Add RewriterBase::replaceAllUsesWith for Blocks.

When changing IR in a RewritePattern, all changes must go through the
rewriter. There are several convenience functions in RewriterBase that
help with high-level modifications, such as replaceAllUsesWith for
Values, but there is currently none to do the same task for Blocks.

Reviewed By: mehdi_amini, ingomueller-net

Differential Revision: https://reviews.llvm.org/D142525

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/IR/PatternMatch.cpp
    mlir/test/Transforms/test-strict-pattern-driver.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 64eb66bf53b73..187ce060f7ebb 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -505,7 +505,16 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
   /// Find uses of `from` and replace them with `to`. It also marks every
   /// modified uses and notifies the rewriter that an in-place operation
   /// modification is about to happen.
-  void replaceAllUsesWith(Value from, Value to);
+  void replaceAllUsesWith(Value from, Value to) {
+    return replaceAllUsesWith(from.getImpl(), to);
+  }
+  template <typename OperandType, typename ValueT>
+  void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
+    for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
+      Operation *op = operand.getOwner();
+      updateRootInPlace(op, [&]() { operand.set(to); });
+    }
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index b082b0d4cd6ef..1ca86cdcba1cc 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -309,14 +309,6 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
   source->erase();
 }
 
-/// Find uses of `from` and replace it with `to`
-void RewriterBase::replaceAllUsesWith(Value from, Value to) {
-  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
-    Operation *op = operand.getOwner();
-    updateRootInPlace(op, [&]() { operand.set(to); });
-  }
-}
-
 /// Find uses of `from` and replace them with `to` if the `functor` returns
 /// true. It also marks every modified uses and notifies the rewriter that an
 /// in-place operation modification is about to happen.

diff  --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 9dbaea18967f8..5df2d6d1fdeeb 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -1,3 +1,7 @@
+// RUN: mlir-opt \
+// RUN:     -test-strict-pattern-driver="strictness=AnyOp" \
+// RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-AN
+
 // RUN: mlir-opt \
 // RUN:     -test-strict-pattern-driver="strictness=ExistingAndNewOps" \
 // RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN
@@ -58,3 +62,24 @@ func.func @test_replace_with_erase_op() {
   "test.replace_with_new_op"() {create_erase_op} : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN-LABEL: func @test_trigger_rewrite_through_block
+//       CHECK-AN: "test.change_block_op"()[^[[BB0:.*]], ^[[BB0]]]
+//       CHECK-AN: return
+//       CHECK-AN: ^[[BB1:[^:]*]]:
+//       CHECK-AN: "test.implicit_change_op"()[^[[BB1]]]
+func.func @test_trigger_rewrite_through_block() {
+  return
+^bb1:
+  // Uses bb1. ChangeBlockOp replaces that and all other usages of bb1 with bb2.
+  "test.change_block_op"() [^bb1, ^bb2] : () -> ()
+^bb2:
+  return
+^bb3:
+  // Also uses bb1. ChangeBlockOp replaces that usage with bb2. This triggers
+  // this op being put on the worklist, which triggers ImplicitChangeOp, which,
+  // in turn, replaces the successor with bb3.
+  "test.implicit_change_op"() [^bb1] : () -> ()
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a15d30725c1ad..4bfbb3496ec3a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -256,11 +256,19 @@ struct TestStrictPatternDriver
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     mlir::RewritePatternSet patterns(ctx);
-    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
+    patterns.add<
+        // clang-format off
+        InsertSameOp,
+        ReplaceWithNewOp,
+        EraseOp,
+        ChangeBlockOp,
+        ImplicitChangeOp
+        // clang-format on
+        >(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
-      if (opName == "test.insert_same_op" ||
+      if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
           opName == "test.replace_with_new_op" || opName == "test.erase_op") {
         ops.push_back(op);
       }
@@ -342,7 +350,7 @@ struct TestStrictPatternDriver
     }
   };
 
-  // Remove an operation may introduce the re-visiting of its opreands.
+  // Remove an operation may introduce the re-visiting of its operands.
   class EraseOp : public RewritePattern {
   public:
     EraseOp(MLIRContext *context)
@@ -353,6 +361,55 @@ struct TestStrictPatternDriver
       return success();
     }
   };
+
+  // The following two patterns test RewriterBase::replaceAllUsesWith.
+  //
+  // That function replaces all usages of a Block (or a Value) with another one
+  // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
+  // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
+  // worklist: when an op is modified, it is added to the worklist. The two
+  // patterns below make the tracking observable: ChangeBlockOp replaces all
+  // usages of a block and that pattern is applied because the corresponding ops
+  // are put on the initial worklist (see above). ImplicitChangeOp does an
+  // unrelated change but ops of the corresponding type are *not* on the initial
+  // worklist, so the effect of the second pattern is only visible if the
+  // tracking and subsequent adding to the worklist actually works.
+
+  // Replace all usages of the first successor with the second successor.
+  class ChangeBlockOp : public RewritePattern {
+  public:
+    ChangeBlockOp(MLIRContext *context)
+        : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      if (op->getNumSuccessors() < 2)
+        return failure();
+      Block *firstSuccessor = op->getSuccessor(0);
+      Block *secondSuccessor = op->getSuccessor(1);
+      if (firstSuccessor == secondSuccessor)
+        return failure();
+      // This is the function being tested:
+      rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
+      // Using the following line instead would make the test fail:
+      // firstSuccessor->replaceAllUsesWith(secondSuccessor);
+      return success();
+    }
+  };
+
+  // Changes the first successor of the op to the op's own parent block.
+  class ImplicitChangeOp : public RewritePattern {
+  public:
+    ImplicitChangeOp(MLIRContext *context)
+        : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
+        return failure();
+      rewriter.updateRootInPlace(
+          op, [&]() { op->setSuccessor(op->getBlock(), 0); });
+      return success();
+    }
+  };
 };
 
 } // namespace


        


More information about the Mlir-commits mailing list