[Mlir-commits] [mlir] 3e2992f - [MLIR][Mem2Reg] Replace pattern based approach with a bulk one (#85426)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 18 00:32:53 PDT 2024


Author: Christian Ulmann
Date: 2024-03-18T08:32:50+01:00
New Revision: 3e2992f0706db7089b42c54c4eb40c65afc98ec6

URL: https://github.com/llvm/llvm-project/commit/3e2992f0706db7089b42c54c4eb40c65afc98ec6
DIFF: https://github.com/llvm/llvm-project/commit/3e2992f0706db7089b42c54c4eb40c65afc98ec6.diff

LOG: [MLIR][Mem2Reg] Replace pattern based approach with a bulk one (#85426)

This commit changes MLIR's Mem2Reg implementation from a pattern-based
approach back to a full pass. Using Mem2Reg as a pattern is wasteful,
because each pattern application can invalidate the dominance info.
Applying the changes in bulk allows the same dominance info to be reused.

Unfortunately, this requires some test changes, because the `IRBuilder`,
unlike the previously used greedy pattern driver, does not simplify the IR.

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/Mem2Reg.h
    mlir/lib/Transforms/Mem2Reg.cpp
    mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index 89244feb21754e..d145f7ed437582 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,8 +9,6 @@
 #ifndef MLIR_TRANSFORMS_MEM2REG_H
 #define MLIR_TRANSFORMS_MEM2REG_H
 
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/Statistic.h"
@@ -25,24 +23,6 @@ struct Mem2RegStatistics {
   llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
-/// Pattern applying mem2reg to the regions of the operations on which it
-/// matches.
-class Mem2RegPattern
-    : public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
-public:
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
-  Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
-                 PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
-
-  LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  Mem2RegStatistics statistics;
-};
-
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult

diff  --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index f3a973d9994083..84ac69b4514b4f 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -14,10 +14,8 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -635,13 +633,6 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   return success(promotedAny);
 }
 
-LogicalResult
-Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
-                                PatternRewriter &rewriter) const {
-  hasBoundedRewriteRecursion();
-  return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
-}
-
 namespace {
 
 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
@@ -650,17 +641,36 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
   void runOnOperation() override {
     Operation *scopeOp = getOperation();
 
-    Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
+
+    bool changed = false;
+
+    for (Region &region : scopeOp->getRegions()) {
+      if (region.getBlocks().empty())
+        continue;
 
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = enableRegionSimplification;
+      OpBuilder builder(&region.front(), region.front().begin());
+      IRRewriter rewriter(builder);
 
-    RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
-    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+      // Promoting a slot can allow for further promotion of other slots,
+      // promotion is tried until no promotion succeeds.
+      while (true) {
+        SmallVector<PromotableAllocationOpInterface> allocators;
+        // Build a list of allocators to attempt to promote the slots of.
+        region.walk([&](PromotableAllocationOpInterface allocator) {
+          allocators.emplace_back(allocator);
+        });
 
-    if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
-      signalPassFailure();
+        // Attempt promoting until no promotion succeeds.
+        if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
+          break;
+
+        changed = true;
+        getAnalysisManager().invalidate({});
+      }
+    }
+    if (!changed)
+      markAllAnalysesPreserved();
   }
 };
 

diff  --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index ce6338fb348837..4fc80a87f20df5 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @basic_memset
 // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
@@ -6,13 +6,13 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
   %memset_len = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   // CHECK-NOT: "llvm.intr.memset"
   // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
   // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
   // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
   // CHECK-NOT: "llvm.intr.memset"
@@ -31,7 +31,14 @@ llvm.func @basic_memset_constant() -> i32 {
   %memset_len = llvm.mlir.constant(4 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
+  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
+  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
   // CHECK: llvm.return %[[RES]] : i32
   llvm.return %2 : i32
 }
@@ -44,16 +51,16 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
   %memset_len = llvm.mlir.constant(5 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
-  // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   // CHECK-NOT: "llvm.intr.memset"
   // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
   // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
   // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
   // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
   // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
   // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
   // CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
   // CHECK-NOT: "llvm.intr.memset"
@@ -64,21 +71,6 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
 
 // -----
 
-// CHECK-LABEL: llvm.func @exotic_target_memset_constant
-llvm.func @exotic_target_memset_constant() -> i40 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %memset_value = llvm.mlir.constant(42 : i8) : i8
-  %memset_len = llvm.mlir.constant(5 : i32) : i32
-  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
-  // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
-  // CHECK: llvm.return %[[RES]] : i40
-  llvm.return %2 : i40
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @no_volatile_memset
 llvm.func @no_volatile_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -195,7 +187,7 @@ llvm.func @basic_memcpy_dest(%destination: !llvm.ptr) -> i32 {
 // CHECK-LABEL: llvm.func @double_memcpy
 llvm.func @double_memcpy() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NEXT: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
+  // CHECK: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
   %data = llvm.mlir.constant(42 : i32) : i32
   %is_volatile = llvm.mlir.constant(false) : i1
   %memcpy_len = llvm.mlir.constant(4 : i32) : i32
@@ -206,7 +198,7 @@ llvm.func @double_memcpy() -> i32 {
   "llvm.intr.memcpy"(%2, %1, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
 
   %res = llvm.load %2 : !llvm.ptr -> i32
-  // CHECK-NEXT: llvm.return %[[DATA]] : i32
+  // CHECK: llvm.return %[[DATA]] : i32
   llvm.return %res : i32
 }
 


        


More information about the Mlir-commits mailing list