[llvm-branch-commits] [mlir] c8fb6ee - [mlir][PatternRewriter] Add a new hook to selectively replace uses of an operation
River Riddle via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 14 12:09:12 PST 2021
Author: River Riddle
Date: 2021-01-14T11:58:21-08:00
New Revision: c8fb6ee34151b18bcc9ed1a6b8f502a0b40a414e
URL: https://github.com/llvm/llvm-project/commit/c8fb6ee34151b18bcc9ed1a6b8f502a0b40a414e
DIFF: https://github.com/llvm/llvm-project/commit/c8fb6ee34151b18bcc9ed1a6b8f502a0b40a414e.diff
LOG: [mlir][PatternRewriter] Add a new hook to selectively replace uses of an operation
This revision adds a new `replaceOpWithIf` hook that replaces uses of an operation that satisfy a given functor. If all uses are replaced, the operation gets erased in a similar manner to `replaceOp`. DialectConversion support will be added in a followup as this requires adjusting how replacements are tracked there.
Differential Revision: https://reviews.llvm.org/D94632
Added:
mlir/test/Transforms/test-pattern-selective-replacement.mlir
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index d97b328cdc01..1a306e6ba58c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/FunctionExtras.h"
namespace mlir {
@@ -447,6 +448,30 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
Region::iterator before);
void cloneRegionBefore(Region ®ion, Block *before);
+ /// This method replaces the uses of the results of `op` with the values in
+ /// `newValues` when the provided `functor` returns true for a specific use.
+ /// The number of values in `newValues` is required to match the number of
+ /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
+ /// the uses of `op` were replaced. The default implementation does not erase
+ /// `op`, even if all of its uses were replaced; callers that want it removed
+ /// must erase it themselves. Note that in some pattern rewriters, the given
+ /// `functor` may be stored beyond the lifetime of the pattern being applied.
+ /// As such, the function should not capture by reference and instead use
+ /// value capture as necessary.
+ virtual void
+ replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
+ llvm::unique_function<bool(OpOperand &) const> functor);
+ /// Same as above, for callers that do not need to know whether all of the
+ /// uses of `op` were replaced.
+ void replaceOpWithIf(Operation *op, ValueRange newValues,
+ llvm::unique_function<bool(OpOperand &) const> functor) {
+ replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
+ std::move(functor));
+ }
+
+ /// This method replaces the uses of the results of `op` with the values in
+ /// `newValues` when a use is nested within the given `block`. The number of
+ /// values in `newValues` is required to match the number of results of `op`.
+ /// `allUsesReplaced`, if non-null, is set to true if all of the uses of `op`
+ /// were replaced. This method does not itself erase `op`, even if all of its
+ /// uses were replaced.
+ void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
+ bool *allUsesReplaced = nullptr);
+
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51c7788ffb14..ca28c175fbdd 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -470,6 +470,12 @@ class ConversionPatternRewriter final : public PatternRewriter {
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//
+ /// PatternRewriter hook for replacing the results of an operation when the
+ /// given functor returns true. This is not yet supported by
+ /// DialectConversion and must not be called from conversion patterns.
+ void replaceOpWithIf(
+ Operation *op, ValueRange newValues, bool *allUsesReplaced,
+ llvm::unique_function<bool(OpOperand &) const> functor) override;
+
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues) override;
using PatternRewriter::replaceOp;
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 6558fcf4606d..44f22ceeb3cf 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -155,6 +155,41 @@ PatternRewriter::~PatternRewriter() {
// Out of line to provide a vtable anchor for the class.
}
+/// Selectively rewrites the uses of `op`'s results: each result's use is
+/// redirected to the corresponding entry of `newValues` only when `functor`
+/// accepts that use. `newValues` must have exactly one value per result of
+/// `op`. When `allUsesReplaced` is non-null it receives whether every result
+/// ended up with no remaining uses. `op` itself is left in place (not erased).
+void PatternRewriter::replaceOpWithIf(
+ Operation *op, ValueRange newValues, bool *allUsesReplaced,
+ llvm::unique_function<bool(OpOperand &) const> functor) {
+ assert(op->getNumResults() == newValues.size() &&
+ "incorrect number of values to replace operation");
+
+ // Give listening subclasses a chance to observe the replacement first.
+ notifyRootReplaced(op);
+
+ // Redirect the accepted uses result-by-result; any result that still has a
+ // use afterwards means the replacement was only partial.
+ bool everyUseGone = true;
+ for (auto entry : llvm::zip(op->getResults(), newValues)) {
+ auto oldResult = std::get<0>(entry);
+ oldResult.replaceUsesWithIf(std::get<1>(entry), functor);
+ if (!oldResult.use_empty())
+ everyUseGone = false;
+ }
+
+ if (allUsesReplaced != nullptr)
+ *allUsesReplaced = everyUseGone;
+}
+
+/// This method replaces the uses of the results of `op` with the values in
+/// `newValues` when a use is nested within the given `block`, i.e. when the
+/// using operation is in `block` or is (transitively) nested inside an
+/// operation that is in `block`. The number of values in `newValues` is
+/// required to match the number of results of `op`. `allUsesReplaced`, if
+/// non-null, is set to true if all of the uses of `op` were replaced. `op` is
+/// not erased by this method, even if all of its uses were replaced.
+void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues,
+ Block *block,
+ bool *allUsesReplaced) {
+ // Walk up from the user to an ancestor living directly in `block`. Testing
+ // against `block->getParentOp()` instead would also accept uses in sibling
+ // blocks/regions of that parent op, and would dereference null for a block
+ // that is not attached to any operation.
+ replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
+ return block->findAncestorOpInBlock(*use.getOwner()) != nullptr;
+ });
+}
+
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f1fa1e250610..a97c461a8e9c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1250,6 +1250,21 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
impl(new detail::ConversionPatternRewriterImpl(*this)) {}
ConversionPatternRewriter::~ConversionPatternRewriter() {}
+/// PatternRewriter hook for replacing the results of an operation when the
+/// given functor returns true. Not yet supported by DialectConversion: any call
+/// aborts with a fatal error.
+void ConversionPatternRewriter::replaceOpWithIf(
+ Operation *op, ValueRange newValues, bool *allUsesReplaced,
+ llvm::unique_function<bool(OpOperand &) const> functor) {
+ // TODO: To support this we will need to rework a bit of how replacements are
+ // tracked, given that this isn't guaranteed to replace all of the uses of an
+ // operation. The main change is that now an operation can be replaced
+ // multiple times, in parts. The current "set" based tracking is mainly useful
+ // for tracking if a replaced operation should be ignored, i.e. if all of the
+ // uses will be replaced.
+ //
+ // This path is reachable from user conversion patterns, so report a fatal
+ // error rather than using llvm_unreachable (which is undefined behavior when
+ // reached in release builds).
+ llvm::report_fatal_error(
+ "replaceOpWithIf is currently not supported by DialectConversion");
+}
+
/// PatternRewriter hook for replacing the results of an operation.
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
LLVM_DEBUG({
diff --git a/mlir/test/Transforms/test-pattern-selective-replacement.mlir b/mlir/test/Transforms/test-pattern-selective-replacement.mlir
new file mode 100644
index 000000000000..d22c439f3905
--- /dev/null
+++ b/mlir/test/Transforms/test-pattern-selective-replacement.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-pattern-selective-replacement -verify-diagnostics %s | FileCheck %s
+
+// Test that operations can be selectively replaced.
+
+// CHECK-LABEL: @test1
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+// CHECK-NOT: test.cast
+func @test1(%arg0: i32, %arg1 : i32) -> () {
+ // The terminator use of %cast is expected to become %arg0 (its first
+ // operand), the non-terminator uses to become %arg1 (its second operand),
+ // and the cast itself to be removed.
+ // CHECK: addi %[[ARG1]], %[[ARG1]]
+ // CHECK-NEXT: "test.return"(%[[ARG0]]
+ %cast = "test.cast"(%arg0, %arg1) : (i32, i32) -> (i32)
+ %non_terminator = addi %cast, %cast : i32
+ "test.return"(%cast, %non_terminator) : (i32, i32) -> ()
+}
+
+// -----
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 34e8b4d34a4f..a4c16a6d533f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -847,6 +847,10 @@ struct TestTypeConversionDriver
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// Test Block Merging
+//===----------------------------------------------------------------------===//
+
namespace {
/// A rewriter pattern that tests that blocks can be merged.
struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
@@ -955,6 +959,46 @@ struct TestMergeBlocksPatternDriver
};
} // namespace
+//===----------------------------------------------------------------------===//
+// Test Selective Replacement
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern that exercises selective replacement on a two-operand
+/// `test.cast`: uses of its result by known terminators are replaced with the
+/// first operand, and all remaining uses are replaced with the second operand.
+struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
+ using OpRewritePattern<TestCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TestCastOp op,
+ PatternRewriter &rewriter) const final {
+ if (op.getNumOperands() != 2)
+ return failure();
+ OperandRange operands = op.getOperands();
+
+ // Replace terminator uses with the first operand.
+ rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+ return operand.getOwner()->isKnownTerminator();
+ });
+ // Replace all remaining (non-terminator) uses with the second operand,
+ // which also removes the op.
+ rewriter.replaceOp(op, operands[1]);
+ return success();
+ }
+};
+
+/// Test pass that greedily applies `TestSelectiveOpReplacementPattern` to the
+/// regions of the operation it runs on.
+struct TestSelectiveReplacementPatternDriver
+ : public PassWrapper<TestSelectiveReplacementPatternDriver,
+ OperationPass<>> {
+ void runOnOperation() override {
+ mlir::OwningRewritePatternList patterns;
+ MLIRContext *context = &getContext();
+ patterns.insert<TestSelectiveOpReplacementPattern>(context);
+ applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+ std::move(patterns));
+ }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//
@@ -992,6 +1036,9 @@ void registerPatternsTestPass() {
PassRegistration<TestMergeBlocksPatternDriver>{
"test-merge-blocks",
"Test Merging operation in ConversionPatternRewriter"};
+ PassRegistration<TestSelectiveReplacementPatternDriver>{
+ "test-pattern-selective-replacement",
+ "Test selective replacement in the PatternRewriter"};
}
} // namespace test
} // namespace mlir
More information about the llvm-branch-commits
mailing list