[Mlir-commits] [mlir] 3ba79a3 - [mlir][mem2reg] Add mem2reg rewrite pattern.
Tobias Gysi
llvmlistbot at llvm.org
Tue May 9 07:13:10 PDT 2023
Author: Théo Degioanni
Date: 2023-05-09T14:01:45Z
New Revision: 3ba79a368122ec3779ba8199ca1b84f5ad57e71a
URL: https://github.com/llvm/llvm-project/commit/3ba79a368122ec3779ba8199ca1b84f5ad57e71a
DIFF: https://github.com/llvm/llvm-project/commit/3ba79a368122ec3779ba8199ca1b84f5ad57e71a.diff
LOG: [mlir][mem2reg] Add mem2reg rewrite pattern.
This revision introduces the ability to invoke mem2reg as a rewrite pattern. It also modifies the canonical mem2reg pass to use the rewrite-pattern approach.
Depends on D149825
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D149958
Added:
Modified:
mlir/include/mlir/Transforms/Mem2Reg.h
mlir/lib/Transforms/Mem2Reg.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index 1593b12ff1ac2..a34ea68e750bf 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
namespace mlir {
@@ -117,6 +118,19 @@ class MemorySlotPromoter {
MemorySlotPromotionInfo info;
};
+/// Pattern applying mem2reg to the regions of the operations on which it
+/// matches. The pattern matches any operation type; operations without
+/// regions are rejected in `matchAndRewrite`.
+class Mem2RegPattern : public RewritePattern {
+public:
+  using RewritePattern::RewritePattern;
+
+  /// Creates a pattern that matches any operation type with the given
+  /// `benefit`.
+  Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {
+    // A successful rewrite only updates `op` in place, so the same operation
+    // can match again (promoting a slot can enable further promotions).
+    // Declare the recursion as bounded. This must happen here: the setter is
+    // not callable from the const `matchAndRewrite`, and calling the
+    // `hasBoundedRewriteRecursion()` getter there has no effect.
+    setHasBoundedRewriteRecursion();
+  }
+
+  /// Attempts to promote the memory slots of the allocators located in the
+  /// regions of `op`. Succeeds if at least one slot was promoted.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
/// Attempts to promote the memory slots of the provided allocators. Succeeds if
/// at least one memory slot was promoted.
LogicalResult
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 633813f8ef3d1..a4bf97ca5ffdb 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
@@ -22,6 +23,8 @@ namespace mlir {
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "mem2reg"
+
using namespace mlir;
/// mem2reg
@@ -422,6 +425,9 @@ void MemorySlotPromoter::promoteSlot() {
}
}
+ LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
+ << "\n");
+
allocator.handlePromotionComplete(slot, defaultValue);
}
@@ -450,39 +456,49 @@ LogicalResult mlir::tryToPromoteMemorySlots(
return success(!toPromote.empty());
}
-namespace {
+/// Promotes the memory slots of the allocators found directly in the regions
+/// of `op`. Fails (leaving `op` untouched) if `op` has no regions, contains no
+/// promotable allocator, or no slot could be promoted.
+LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
+                                              PatternRewriter &rewriter) const {
-struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
-  void runOnOperation() override {
-    Operation *scopeOp = getOperation();
-    bool changed = false;
+  // Only operations owning regions can contain allocators to promote.
+  if (op->getNumRegions() == 0)
+    return failure();
+
-    for (Region &region : scopeOp->getRegions()) {
-      if (region.getBlocks().empty())
-        continue;
-      OpBuilder builder(&region.front(), region.front().begin());
+  SmallVector<PromotableAllocationOpInterface> allocators;
+  // Build a list of allocators to attempt to promote the slots of.
+  for (Region &region : op->getRegions())
+    for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
+      allocators.emplace_back(allocator);
+
+  // Nothing to promote: bail out before opening an in-place update on `op`.
+  if (allocators.empty())
+    return failure();
+
+  DominanceInfo dominance;
+
-      // Promoting a slot can allow for further promotion of other slots,
-      // promotion is tried until no promotion succeeds.
-      while (true) {
-        DominanceInfo &dominance = getAnalysis<DominanceInfo>();
+  // Because pattern rewriters are normally not expressive enough to support a
+  // transformation like mem2reg, this uses an escape hatch to mark modified
+  // operations manually and operate outside of its context.
+  // NOTE(review): `builder` below carries no listener, so operations it
+  // creates (and anything the promotion erases directly) are not reported to
+  // the driver. This is presumably only safe with drivers that do not track
+  // operations nested under `op`; confirm before using this pattern with
+  // applyPatternsAndFoldGreedily.
+  rewriter.startRootUpdate(op);
+
-        SmallVector<PromotableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to promote the slots of.
-        for (Block &block : region)
-          for (Operation &op : block.getOperations())
-            if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
-              allocators.emplace_back(allocator);
+  OpBuilder builder(rewriter.getContext());
+
-        // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, builder, dominance)))
-          break;
+  if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
+    rewriter.cancelRootUpdate(op);
+    return failure();
+  }
+
-        changed = true;
-        getAnalysisManager().invalidate({});
-      }
-    }
+  rewriter.finalizeRootUpdate(op);
+  return success();
+}
+
+namespace {
+
+struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+ void runOnOperation() override {
+ Operation *scopeOp = getOperation();
+ bool changed = false;
+
+ RewritePatternSet rewritePatterns(&getContext());
+ rewritePatterns.add<Mem2RegPattern>(&getContext());
+ FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+ (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
+ &changed);
if (!changed)
markAllAnalysesPreserved();
More information about the Mlir-commits
mailing list