[Mlir-commits] [mlir] [MLIR][Mem2Reg] Replace pattern based approach with a bulk one (PR #85426)

Christian Ulmann llvmlistbot at llvm.org
Sat Mar 16 07:39:21 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/85426

>From 0c9431168afa054a5d9d493582687e332969dc8a Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 15 Mar 2024 07:21:19 +0000
Subject: [PATCH 1/2] [MLIR][Mem2Reg] Replace pattern based approach with a
 one-shot pass

This commit changes MLIR's Mem2Reg implementation back from being
pattern-based into a full pass. Using Mem2Reg as a pattern is
wasteful, as each pattern application can invalidate the dominance info.
Applying changes in bulk allows the same dominance info to be reused.

Unfortunately, this requires some test changes, because the `IRRewriter`
now used by the pass does not simplify the IR the way the greedy pattern
driver did.
---
 mlir/include/mlir/Transforms/Mem2Reg.h        | 20 ---------
 mlir/lib/Transforms/Mem2Reg.cpp               | 45 ++++++++++++-------
 .../Dialect/LLVMIR/mem2reg-intrinsics.mlir    | 37 ++++++++++-----
 3 files changed, 55 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index 89244feb21754e..d145f7ed437582 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,8 +9,6 @@
 #ifndef MLIR_TRANSFORMS_MEM2REG_H
 #define MLIR_TRANSFORMS_MEM2REG_H
 
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/Statistic.h"
@@ -25,24 +23,6 @@ struct Mem2RegStatistics {
   llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
-/// Pattern applying mem2reg to the regions of the operations on which it
-/// matches.
-class Mem2RegPattern
-    : public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
-public:
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
-  Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
-                 PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
-
-  LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  Mem2RegStatistics statistics;
-};
-
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index f3a973d9994083..5d39aecb2c7d14 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -14,10 +14,8 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -635,13 +633,6 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   return success(promotedAny);
 }
 
-LogicalResult
-Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
-                                PatternRewriter &rewriter) const {
-  hasBoundedRewriteRecursion();
-  return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
-}
-
 namespace {
 
 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
@@ -650,17 +641,37 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
   void runOnOperation() override {
     Operation *scopeOp = getOperation();
 
-    Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
 
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = enableRegionSimplification;
+    bool changed = false;
 
-    RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
-    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+    for (Region &region : scopeOp->getRegions()) {
+      if (region.getBlocks().empty())
+        continue;
 
-    if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
-      signalPassFailure();
+      OpBuilder builder(&region.front(), region.front().begin());
+      IRRewriter rewriter(builder);
+
+      // Promoting a slot can allow for further promotion of other slots,
+      // promotion is tried until no promotion succeeds.
+      while (true) {
+        SmallVector<PromotableAllocationOpInterface> allocators;
+        // Build a list of allocators to attempt to promote the slots of.
+        for (Block &block : region)
+          for (Operation &op : block.getOperations())
+            if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
+              allocators.emplace_back(allocator);
+
+        // Attempt promoting until no promotion succeeds.
+        if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
+          break;
+
+        changed = true;
+        getAnalysisManager().invalidate({});
+      }
+    }
+    if (!changed)
+      markAllAnalysesPreserved();
   }
 };
 
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index ce6338fb348837..32c30c5bf2b292 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @basic_memset
 // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
@@ -6,13 +6,13 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
   %memset_len = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   // CHECK-NOT: "llvm.intr.memset"
   // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
   // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
   // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
   // CHECK-NOT: "llvm.intr.memset"
@@ -31,7 +31,14 @@ llvm.func @basic_memset_constant() -> i32 {
   %memset_len = llvm.mlir.constant(4 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
+  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
+  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
   // CHECK: llvm.return %[[RES]] : i32
   llvm.return %2 : i32
 }
@@ -44,16 +51,16 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
   %memset_len = llvm.mlir.constant(5 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
-  // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   // CHECK-NOT: "llvm.intr.memset"
   // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
   // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
   // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
   // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
   // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
   // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
   // CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
   // CHECK-NOT: "llvm.intr.memset"
@@ -72,7 +79,17 @@ llvm.func @exotic_target_memset_constant() -> i40 {
   %memset_len = llvm.mlir.constant(5 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
-  // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
+  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: %[[ZEXT_42:.*]] = llvm.zext %[[C42]] : i8 to i40
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
+  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[ZEXT_42]], %[[C8]]  : i40
+  // CHECK: %[[OR_0:.*]] = llvm.or %[[ZEXT_42]], %[[SHIFTED_42]]  : i40
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
+  // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[OR_0]], %[[C16]]  : i40
+  // CHECK: %[[OR_1:.*]] = llvm.or %[[OR_0]], %[[SHIFTED_COMPL]]  : i40
+  // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
+  // CHECK: %[[OR_COMPL:.*]] = llvm.shl %[[OR_1]], %[[C32]]  : i40
+  // CHECK: %[[RES:.*]] = llvm.or %[[OR_1]], %[[OR_COMPL]]  : i40
   // CHECK: llvm.return %[[RES]] : i40
   llvm.return %2 : i40
 }
@@ -195,7 +212,7 @@ llvm.func @basic_memcpy_dest(%destination: !llvm.ptr) -> i32 {
 // CHECK-LABEL: llvm.func @double_memcpy
 llvm.func @double_memcpy() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NEXT: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
+  // CHECK: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
   %data = llvm.mlir.constant(42 : i32) : i32
   %is_volatile = llvm.mlir.constant(false) : i1
   %memcpy_len = llvm.mlir.constant(4 : i32) : i32
@@ -206,7 +223,7 @@ llvm.func @double_memcpy() -> i32 {
   "llvm.intr.memcpy"(%2, %1, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
 
   %res = llvm.load %2 : !llvm.ptr -> i32
-  // CHECK-NEXT: llvm.return %[[DATA]] : i32
+  // CHECK: llvm.return %[[DATA]] : i32
   llvm.return %res : i32
 }
 

>From 0e241e759a649f5ef39201c85fc7d1c4a90e62b4 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Sat, 16 Mar 2024 14:39:10 +0000
Subject: [PATCH 2/2] address review comments & fix region walk

---
 mlir/lib/Transforms/Mem2Reg.cpp               |  7 +++---
 .../Dialect/LLVMIR/mem2reg-intrinsics.mlir    | 25 -------------------
 2 files changed, 3 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 5d39aecb2c7d14..84ac69b4514b4f 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -657,10 +657,9 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
       while (true) {
         SmallVector<PromotableAllocationOpInterface> allocators;
         // Build a list of allocators to attempt to promote the slots of.
-        for (Block &block : region)
-          for (Operation &op : block.getOperations())
-            if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
-              allocators.emplace_back(allocator);
+        region.walk([&](PromotableAllocationOpInterface allocator) {
+          allocators.emplace_back(allocator);
+        });
 
         // Attempt promoting until no promotion succeeds.
         if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 32c30c5bf2b292..4fc80a87f20df5 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -71,31 +71,6 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
 
 // -----
 
-// CHECK-LABEL: llvm.func @exotic_target_memset_constant
-llvm.func @exotic_target_memset_constant() -> i40 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %memset_value = llvm.mlir.constant(42 : i8) : i8
-  %memset_len = llvm.mlir.constant(5 : i32) : i32
-  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
-  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // CHECK: %[[ZEXT_42:.*]] = llvm.zext %[[C42]] : i8 to i40
-  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
-  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[ZEXT_42]], %[[C8]]  : i40
-  // CHECK: %[[OR_0:.*]] = llvm.or %[[ZEXT_42]], %[[SHIFTED_42]]  : i40
-  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
-  // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[OR_0]], %[[C16]]  : i40
-  // CHECK: %[[OR_1:.*]] = llvm.or %[[OR_0]], %[[SHIFTED_COMPL]]  : i40
-  // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
-  // CHECK: %[[OR_COMPL:.*]] = llvm.shl %[[OR_1]], %[[C32]]  : i40
-  // CHECK: %[[RES:.*]] = llvm.or %[[OR_1]], %[[OR_COMPL]]  : i40
-  // CHECK: llvm.return %[[RES]] : i40
-  llvm.return %2 : i40
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @no_volatile_memset
 llvm.func @no_volatile_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32



More information about the Mlir-commits mailing list