[llvm-branch-commits] [mlir] c8fb6ee - [mlir][PatternRewriter] Add a new hook to selectively replace uses of an operation

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 14 12:09:12 PST 2021


Author: River Riddle
Date: 2021-01-14T11:58:21-08:00
New Revision: c8fb6ee34151b18bcc9ed1a6b8f502a0b40a414e

URL: https://github.com/llvm/llvm-project/commit/c8fb6ee34151b18bcc9ed1a6b8f502a0b40a414e
DIFF: https://github.com/llvm/llvm-project/commit/c8fb6ee34151b18bcc9ed1a6b8f502a0b40a414e.diff

LOG: [mlir][PatternRewriter] Add a new hook to selectively replace uses of an operation

This revision adds a new `replaceOpWithIf` hook that replaces uses of an operation that satisfy a given functor. If all uses are replaced, the operation gets erased in a similar manner to `replaceOp`. DialectConversion support will be added in a followup as this requires adjusting how replacements are tracked there.

Differential Revision: https://reviews.llvm.org/D94632

Added: 
    mlir/test/Transforms/test-pattern-selective-replacement.mlir

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index d97b328cdc01..1a306e6ba58c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/FunctionExtras.h"
 
 namespace mlir {
 
@@ -447,6 +448,30 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
                          Region::iterator before);
   void cloneRegionBefore(Region &region, Block *before);
 
+  /// This method replaces the uses of the results of `op` with the values in
+  /// `newValues` when the provided `functor` returns true for a specific use.
+  /// The number of values in `newValues` is required to match the number of
+  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
+  /// the uses of `op` were replaced. Note that in some pattern rewriters, the
+  /// given 'functor' may be stored beyond the lifetime of the pattern being
+  /// applied. As such, the function should not capture by reference and instead
+  /// use value capture as necessary.
+  virtual void
+  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
+                  llvm::unique_function<bool(OpOperand &) const> functor);
+  void replaceOpWithIf(Operation *op, ValueRange newValues,
+                       llvm::unique_function<bool(OpOperand &) const> functor) {
+    replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
+                    std::move(functor));
+  }
+
+  /// This method replaces the uses of the results of `op` with the values in
+  /// `newValues` when the user is nested within the parent op of `block`. The
+  /// number of values in `newValues` must match the number of results of `op`.
+  /// `allUsesReplaced`, if non-null, is set to true if all uses were replaced.
+  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
+                            bool *allUsesReplaced = nullptr);
+
   /// This method performs the final replacement for a pattern, where the
   /// results of the operation are updated to use the specified list of SSA
   /// values.

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51c7788ffb14..ca28c175fbdd 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -470,6 +470,12 @@ class ConversionPatternRewriter final : public PatternRewriter {
   // PatternRewriter Hooks
   //===--------------------------------------------------------------------===//
 
+  /// PatternRewriter hook for replacing the results of an operation when the
+  /// given functor returns true.
+  void replaceOpWithIf(
+      Operation *op, ValueRange newValues, bool *allUsesReplaced,
+      llvm::unique_function<bool(OpOperand &) const> functor) override;
+
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ValueRange newValues) override;
   using PatternRewriter::replaceOp;

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 6558fcf4606d..44f22ceeb3cf 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -155,6 +155,41 @@ PatternRewriter::~PatternRewriter() {
   // Out of line to provide a vtable anchor for the class.
 }
 
+/// Replaces the uses of the results of `op` with the values in `newValues`
+/// when the provided `functor` returns true for a specific use. The number of
+/// values in `newValues` is required to match the number of results of `op`.
+/// `allUsesReplaced`, if non-null, is set to true if all uses were replaced.
+void PatternRewriter::replaceOpWithIf(
+    Operation *op, ValueRange newValues, bool *allUsesReplaced,
+    llvm::unique_function<bool(OpOperand &) const> functor) {
+  assert(op->getNumResults() == newValues.size() &&
+         "incorrect number of values to replace operation");
+
+  // Notify the rewriter subclass that we're about to replace this root.
+  notifyRootReplaced(op);
+
+  // Replace each use of the results when the functor is true.
+  bool replacedAllUses = true;
+  for (auto it : llvm::zip(op->getResults(), newValues)) {
+    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
+    replacedAllUses &= std::get<0>(it).use_empty();
+  }
+  if (allUsesReplaced)
+    *allUsesReplaced = replacedAllUses;
+}
+
+/// This method replaces the uses of the results of `op` with the values in
+/// `newValues` when the user is nested within the parent op of `block`. The
+/// number of values in `newValues` must match the number of results of `op`.
+/// `allUsesReplaced`, if non-null, is set to true if all uses were replaced.
+void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues,
+                                           Block *block,
+                                           bool *allUsesReplaced) {
+  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
+    return block->getParentOp()->isProperAncestor(use.getOwner());
+  });
+}
+
 /// This method performs the final replacement for a pattern, where the
 /// results of the operation are updated to use the specified list of SSA
 /// values.

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f1fa1e250610..a97c461a8e9c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1250,6 +1250,21 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
 ConversionPatternRewriter::~ConversionPatternRewriter() {}
 
+/// PatternRewriter hook for replacing the results of an operation when the
+/// given functor returns true.
+void ConversionPatternRewriter::replaceOpWithIf(
+    Operation *op, ValueRange newValues, bool *allUsesReplaced,
+    llvm::unique_function<bool(OpOperand &) const> functor) {
+  // TODO: To support this we will need to rework a bit of how replacements are
+  // tracked, given that this isn't guaranteed to replace all of the uses of an
+  // operation. The main change is that now an operation can be replaced
+  // multiple times, in parts. The current "set" based tracking is mainly useful
+  // for tracking if a replaced operation should be ignored, i.e. if all of the
+  // uses will be replaced.
+  llvm_unreachable(
+      "replaceOpWithIf is currently not supported by DialectConversion");
+}
+
 /// PatternRewriter hook for replacing the results of an operation.
 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
   LLVM_DEBUG({

diff  --git a/mlir/test/Transforms/test-pattern-selective-replacement.mlir b/mlir/test/Transforms/test-pattern-selective-replacement.mlir
new file mode 100644
index 000000000000..d22c439f3905
--- /dev/null
+++ b/mlir/test/Transforms/test-pattern-selective-replacement.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-pattern-selective-replacement -verify-diagnostics %s | FileCheck %s
+
+// Test that operations can be selectively replaced.
+
+// CHECK-LABEL: @test1
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @test1(%arg0: i32, %arg1 : i32) -> () {
+  // CHECK: addi %[[ARG1]], %[[ARG1]]
+  // CHECK-NEXT: "test.return"(%[[ARG0]]
+  %cast = "test.cast"(%arg0, %arg1) : (i32, i32) -> (i32)
+  %non_terminator = addi %cast, %cast : i32
+  "test.return"(%cast, %non_terminator) : (i32, i32) -> ()
+}
+
+// -----

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 34e8b4d34a4f..a4c16a6d533f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -847,6 +847,10 @@ struct TestTypeConversionDriver
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Test Block Merging
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// A rewriter pattern that tests that blocks can be merged.
 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
@@ -955,6 +959,46 @@ struct TestMergeBlocksPatternDriver
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Test Selective Replacement
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A rewrite pattern that tests selective replacement of an operation's uses:
+/// terminator uses get the first operand, all other uses get the second.
+struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
+  using OpRewritePattern<TestCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestCastOp op,
+                                PatternRewriter &rewriter) const final {
+    if (op.getNumOperands() != 2)
+      return failure();
+    OperandRange operands = op.getOperands();
+
+    // Replace terminator uses with the first operand.
+    rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+      return operand.getOwner()->isKnownTerminator();
+    });
+    // Replace everything else with the second operand if the operation isn't
+    // dead.
+    rewriter.replaceOp(op, op.getOperand(1));
+    return success();
+  }
+};
+
+struct TestSelectiveReplacementPatternDriver
+    : public PassWrapper<TestSelectiveReplacementPatternDriver,
+                         OperationPass<>> {
+  void runOnOperation() override {
+    mlir::OwningRewritePatternList patterns;
+    MLIRContext *context = &getContext();
+    patterns.insert<TestSelectiveOpReplacementPattern>(context);
+    applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+                                 std::move(patterns));
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // PassRegistration
 //===----------------------------------------------------------------------===//
@@ -992,6 +1036,9 @@ void registerPatternsTestPass() {
   PassRegistration<TestMergeBlocksPatternDriver>{
       "test-merge-blocks",
       "Test Merging operation in ConversionPatternRewriter"};
+  PassRegistration<TestSelectiveReplacementPatternDriver>{
+      "test-pattern-selective-replacement",
+      "Test selective replacement in the PatternRewriter"};
 }
 } // namespace test
 } // namespace mlir


        


More information about the llvm-branch-commits mailing list