[Mlir-commits] [mlir] [MLIR][Mem2Reg] Replace pattern based approach with a bulk one (PR #85426)

Christian Ulmann llvmlistbot at llvm.org
Fri Mar 15 09:47:02 PDT 2024


https://github.com/Dinistro created https://github.com/llvm/llvm-project/pull/85426

This commit changes MLIR's Mem2Reg implementation from a pattern-based
approach back into a full pass. Running Mem2Reg as a rewrite pattern is
wasteful, because each pattern application can invalidate the dominance
info. Applying the changes in bulk allows the same dominance info to be
reused.

Unfortunately, this requires some test changes, because the `IRBuilder`
does not simplify the IR it creates (e.g., constants are no longer folded).

From 0c9431168afa054a5d9d493582687e332969dc8a Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 15 Mar 2024 07:21:19 +0000
Subject: [PATCH] [MLIR][Mem2Reg] Replace pattern based approach with a
 one-shot pass

This commit changes MLIR's Mem2Reg implementation back from being
pattern based into a full pass. Using Mem2Reg as a pattern is
wasteful, as each application can invalidate the dominance info.
Applying changes in bulk allows for reuse of the same dominance info.

Unfortunately, this requires some test changes, due to the `IRBuilder`
not simplifying IR.
---
 mlir/include/mlir/Transforms/Mem2Reg.h        | 20 ---------
 mlir/lib/Transforms/Mem2Reg.cpp               | 45 ++++++++++++-------
 .../Dialect/LLVMIR/mem2reg-intrinsics.mlir    | 37 ++++++++++-----
 3 files changed, 55 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index 89244feb21754e..d145f7ed437582 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,8 +9,6 @@
 #ifndef MLIR_TRANSFORMS_MEM2REG_H
 #define MLIR_TRANSFORMS_MEM2REG_H
 
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/Statistic.h"
@@ -25,24 +23,6 @@ struct Mem2RegStatistics {
   llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
-/// Pattern applying mem2reg to the regions of the operations on which it
-/// matches.
-class Mem2RegPattern
-    : public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
-public:
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
-  Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
-                 PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
-
-  LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  Mem2RegStatistics statistics;
-};
-
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index f3a973d9994083..5d39aecb2c7d14 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -14,10 +14,8 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -635,13 +633,6 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   return success(promotedAny);
 }
 
-LogicalResult
-Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
-                                PatternRewriter &rewriter) const {
-  hasBoundedRewriteRecursion();
-  return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
-}
-
 namespace {
 
 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
@@ -650,17 +641,37 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
   void runOnOperation() override {
     Operation *scopeOp = getOperation();
 
-    Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
 
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = enableRegionSimplification;
+    bool changed = false;
 
-    RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
-    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+    for (Region &region : scopeOp->getRegions()) {
+      if (region.getBlocks().empty())
+        continue;
 
-    if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
-      signalPassFailure();
+      OpBuilder builder(&region.front(), region.front().begin());
+      IRRewriter rewriter(builder);
+
+      // Promoting a slot can allow for further promotion of other slots,
+      // promotion is tried until no promotion succeeds.
+      while (true) {
+        SmallVector<PromotableAllocationOpInterface> allocators;
+        // Build a list of allocators to attempt to promote the slots of.
+        for (Block &block : region)
+          for (Operation &op : block.getOperations())
+            if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
+              allocators.emplace_back(allocator);
+
+        // Attempt promoting until no promotion succeeds.
+        if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
+          break;
+
+        changed = true;
+        getAnalysisManager().invalidate({});
+      }
+    }
+    if (!changed)
+      markAllAnalysesPreserved();
   }
 };
 
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index ce6338fb348837..32c30c5bf2b292 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @basic_memset
 // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
@@ -6,13 +6,13 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
   %memset_len = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   // CHECK-NOT: "llvm.intr.memset"
   // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
   // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
   // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
   // CHECK-NOT: "llvm.intr.memset"
@@ -31,7 +31,14 @@ llvm.func @basic_memset_constant() -> i32 {
   %memset_len = llvm.mlir.constant(4 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
+  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
+  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
   // CHECK: llvm.return %[[RES]] : i32
   llvm.return %2 : i32
 }
@@ -44,16 +51,16 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
   %memset_len = llvm.mlir.constant(5 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
-  // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   // CHECK-NOT: "llvm.intr.memset"
   // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
   // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
   // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
   // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
   // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
   // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
   // CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
   // CHECK-NOT: "llvm.intr.memset"
@@ -72,7 +79,17 @@ llvm.func @exotic_target_memset_constant() -> i40 {
   %memset_len = llvm.mlir.constant(5 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
-  // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
+  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: %[[ZEXT_42:.*]] = llvm.zext %[[C42]] : i8 to i40
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
+  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[ZEXT_42]], %[[C8]]  : i40
+  // CHECK: %[[OR_0:.*]] = llvm.or %[[ZEXT_42]], %[[SHIFTED_42]]  : i40
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
+  // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[OR_0]], %[[C16]]  : i40
+  // CHECK: %[[OR_1:.*]] = llvm.or %[[OR_0]], %[[SHIFTED_COMPL]]  : i40
+  // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
+  // CHECK: %[[OR_COMPL:.*]] = llvm.shl %[[OR_1]], %[[C32]]  : i40
+  // CHECK: %[[RES:.*]] = llvm.or %[[OR_1]], %[[OR_COMPL]]  : i40
   // CHECK: llvm.return %[[RES]] : i40
   llvm.return %2 : i40
 }
@@ -195,7 +212,7 @@ llvm.func @basic_memcpy_dest(%destination: !llvm.ptr) -> i32 {
 // CHECK-LABEL: llvm.func @double_memcpy
 llvm.func @double_memcpy() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NEXT: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
+  // CHECK: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
   %data = llvm.mlir.constant(42 : i32) : i32
   %is_volatile = llvm.mlir.constant(false) : i1
   %memcpy_len = llvm.mlir.constant(4 : i32) : i32
@@ -206,7 +223,7 @@ llvm.func @double_memcpy() -> i32 {
   "llvm.intr.memcpy"(%2, %1, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
 
   %res = llvm.load %2 : !llvm.ptr -> i32
-  // CHECK-NEXT: llvm.return %[[DATA]] : i32
+  // CHECK: llvm.return %[[DATA]] : i32
   llvm.return %res : i32
 }
 



More information about the Mlir-commits mailing list