[Mlir-commits] [mlir] [MLIR][Mem2Reg] Replace pattern based approach with a bulk one (PR #85426)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 15 09:47:31 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-llvm
Author: Christian Ulmann (Dinistro)
<details>
<summary>Changes</summary>
This commit changes MLIR's Mem2Reg implementation from a pattern-based
approach back to a full pass. Using Mem2Reg as a pattern is
wasteful, as each application can invalidate the dominance info.
Applying changes in bulk allows for reuse of the same dominance info.
Unfortunately, this requires some test changes, due to the `IRBuilder`
not simplifying IR.
---
Full diff: https://github.com/llvm/llvm-project/pull/85426.diff
3 Files Affected:
- (modified) mlir/include/mlir/Transforms/Mem2Reg.h (-20)
- (modified) mlir/lib/Transforms/Mem2Reg.cpp (+28-17)
- (modified) mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir (+27-10)
``````````diff
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index 89244feb21754e..d145f7ed437582 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,8 +9,6 @@
#ifndef MLIR_TRANSFORMS_MEM2REG_H
#define MLIR_TRANSFORMS_MEM2REG_H
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/Statistic.h"
@@ -25,24 +23,6 @@ struct Mem2RegStatistics {
llvm::Statistic *newBlockArgumentAmount = nullptr;
};
-/// Pattern applying mem2reg to the regions of the operations on which it
-/// matches.
-class Mem2RegPattern
- : public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
-public:
- using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
- Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
- PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
-
- LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
- PatternRewriter &rewriter) const override;
-
-private:
- Mem2RegStatistics statistics;
-};
-
/// Attempts to promote the memory slots of the provided allocators. Succeeds if
/// at least one memory slot was promoted.
LogicalResult
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index f3a973d9994083..5d39aecb2c7d14 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -14,10 +14,8 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -635,13 +633,6 @@ LogicalResult mlir::tryToPromoteMemorySlots(
return success(promotedAny);
}
-LogicalResult
-Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
- PatternRewriter &rewriter) const {
- hasBoundedRewriteRecursion();
- return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
-}
-
namespace {
struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
@@ -650,17 +641,37 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
void runOnOperation() override {
Operation *scopeOp = getOperation();
- Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+ Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
- GreedyRewriteConfig config;
- config.enableRegionSimplification = enableRegionSimplification;
+ bool changed = false;
- RewritePatternSet rewritePatterns(&getContext());
- rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
- FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+ for (Region &region : scopeOp->getRegions()) {
+ if (region.getBlocks().empty())
+ continue;
- if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
- signalPassFailure();
+ OpBuilder builder(&region.front(), region.front().begin());
+ IRRewriter rewriter(builder);
+
+ // Promoting a slot can allow for further promotion of other slots,
+ // promotion is tried until no promotion succeeds.
+ while (true) {
+ SmallVector<PromotableAllocationOpInterface> allocators;
+ // Build a list of allocators to attempt to promote the slots of.
+ for (Block &block : region)
+ for (Operation &op : block.getOperations())
+ if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
+ allocators.emplace_back(allocator);
+
+ // Attempt promoting until no promotion succeeds.
+ if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
+ break;
+
+ changed = true;
+ getAnalysisManager().invalidate({});
+ }
+ }
+ if (!changed)
+ markAllAnalysesPreserved();
}
};
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index ce6338fb348837..32c30c5bf2b292 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @basic_memset
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
@@ -6,13 +6,13 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_len = llvm.mlir.constant(4 : i32) : i32
- // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
// CHECK-NOT: "llvm.intr.memset"
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
// CHECK-NOT: "llvm.intr.memset"
@@ -31,7 +31,14 @@ llvm.func @basic_memset_constant() -> i32 {
%memset_len = llvm.mlir.constant(4 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
+ // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
+ // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
// CHECK: llvm.return %[[RES]] : i32
llvm.return %2 : i32
}
@@ -44,16 +51,16 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_len = llvm.mlir.constant(5 : i32) : i32
- // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
- // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
- // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
// CHECK-NOT: "llvm.intr.memset"
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
// CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
// CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
// CHECK-NOT: "llvm.intr.memset"
@@ -72,7 +79,17 @@ llvm.func @exotic_target_memset_constant() -> i40 {
%memset_len = llvm.mlir.constant(5 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
- // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
+ // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK: %[[ZEXT_42:.*]] = llvm.zext %[[C42]] : i8 to i40
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
+ // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[ZEXT_42]], %[[C8]] : i40
+ // CHECK: %[[OR_0:.*]] = llvm.or %[[ZEXT_42]], %[[SHIFTED_42]] : i40
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
+ // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[OR_0]], %[[C16]] : i40
+ // CHECK: %[[OR_1:.*]] = llvm.or %[[OR_0]], %[[SHIFTED_COMPL]] : i40
+ // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
+ // CHECK: %[[OR_COMPL:.*]] = llvm.shl %[[OR_1]], %[[C32]] : i40
+ // CHECK: %[[RES:.*]] = llvm.or %[[OR_1]], %[[OR_COMPL]] : i40
// CHECK: llvm.return %[[RES]] : i40
llvm.return %2 : i40
}
@@ -195,7 +212,7 @@ llvm.func @basic_memcpy_dest(%destination: !llvm.ptr) -> i32 {
// CHECK-LABEL: llvm.func @double_memcpy
llvm.func @double_memcpy() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK-NEXT: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
+ // CHECK: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
%data = llvm.mlir.constant(42 : i32) : i32
%is_volatile = llvm.mlir.constant(false) : i1
%memcpy_len = llvm.mlir.constant(4 : i32) : i32
@@ -206,7 +223,7 @@ llvm.func @double_memcpy() -> i32 {
"llvm.intr.memcpy"(%2, %1, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
%res = llvm.load %2 : !llvm.ptr -> i32
- // CHECK-NEXT: llvm.return %[[DATA]] : i32
+ // CHECK: llvm.return %[[DATA]] : i32
llvm.return %res : i32
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/85426
More information about the Mlir-commits
mailing list