[Mlir-commits] [mlir] Fix block merging (PR #96871)

Giuseppe Rossini llvmlistbot at llvm.org
Fri Jun 28 01:25:58 PDT 2024


https://github.com/giuseros updated https://github.com/llvm/llvm-project/pull/96871

>From b54f74a6992a8adbad75f2979d590c36cfbf620c Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 24 Jun 2024 09:36:32 +0100
Subject: [PATCH 1/2] Fix block merging

---
 .../BufferDeallocationSimplification.cpp      |   9 +-
 mlir/lib/Transforms/Utils/RegionUtils.cpp     | 139 ++++++++++++++++--
 .../Linalg/detensorize_entry_block.mlir       |   6 +-
 3 files changed, 137 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+    // We don't want the block structure to change and invalidate the
+    // `BufferOriginAnalysis`, so we apply the rewrites with a `Normal` level
+    // of region simplification.
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config)))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..31c4bfe1f6b2f 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
+#include <iterator>
 
 using namespace mlir;
 
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(
-        1 + blocksToMerge.size(),
-        SmallVector<Value, 8>(operandsToMerge.size()));
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+                                                       SmallVector<Value, 8>());
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        newArguments[i][it.index()] = operand.get();
-
-        // Update the operand and insert an argument if this is the leader.
-        if (i == 0) {
-          Value operandVal = operand.get();
-          operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                               operandVal.getLoc()));
+        Value operandVal = operand.get();
+        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
+                              operandVal);
+        if (it == newArguments[i].end()) {
+          newArguments[i].push_back(operandVal);
+          // Update the operand and insert an argument if this is the leader.
+          if (i == 0) {
+            operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                                 operandVal.getLoc()));
+          }
+        } else if (i == 0) {
+          // If this is the leader, update the operand but do not insert a new
+          // argument. Instead, the operand should point to one of the
+          // arguments we already passed (and that contained `operandVal`).
+          operand.set(leaderBlock->getArgument(
+              std::distance(newArguments[i].begin(), it)));
         }
       }
     }
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
+/// Drops every argument of `block` that receives the same value from all of
+/// its predecessors, replacing its uses with that value. Arguments are kept
+/// when a predecessor terminator is not a `BranchOpInterface` or produces the
+/// argument itself. Returns success if at least one argument was erased.
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  // Go through the arguments of the block.
+  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+    bool sameArg = true;
+    Value commonValue;
+
+    // Go through the block predecessors and flag if they pass to the block
+    // different values for the same argument.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      // Index in block-argument space rather than into the forwarded operands:
+      // the two differ when the terminator produces leading operands itself,
+      // and a produced operand (null here) can be neither compared nor erased.
+      unsigned succIndex = predIt.getSuccessorIndex();
+      Value operand = branch.getSuccessorOperands(succIndex)[argIdx];
+      if (!operand || (commonValue && operand != commonValue)) {
+        sameArg = false;
+        break;
+      }
+      commonValue = operand;
+    }
+
+    // If they all pass the same value, use it directly and drop the argument.
+    if (commonValue && sameArg) {
+      argsToErase.push_back(argIdx);
+      rewriter.replaceAllUsesWith(block.getArgument(argIdx), commonValue);
+    }
+  }
+
+  // Remove the arguments, last first so that pending indices stay valid.
+  for (auto argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+
+    // Remove the matching operand from each predecessor's branch op.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant arguments to blocks in `regions` and in
+/// any region nested within them, returning success if anything was dropped.
+/// If an argument to a block receives the same value from each of the block's
+/// predecessors, we can remove it and use the original value directly. E.g.:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0, %val1 : i64, i64),
+///                     ^bb1(%val0, %val2 : i64, i64)
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+///    llvm.call @foo(%arg0, %arg1)
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb1(%val2 : i64)
+/// ^bb1(%arg1 : i64):
+///    llvm.call @foo(%val0, %arg1)
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (auto &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+    for (Block &block : *region) {
+      // Accumulate: a later unchanged block must not reset the flag.
+      anyChanged |= succeeded(dropRedundantArguments(rewriter, block));
+      // Add any nested regions to the worklist.
+      for (auto &op : block)
+        for (auto &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  if (mergeBlocks)
+  bool droppedRedundantArguments = false;
+  if (mergeBlocks) {
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+    droppedRedundantArguments =
+        succeeded(dropRedundantArguments(rewriter, regions));
+  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks);
+                 mergedIdenticalBlocks || droppedRedundantArguments);
 }
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>

>From 88a11539bb33b3c56c6202648edc51d1a7f3f364 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Fri, 28 Jun 2024 09:25:13 +0100
Subject: [PATCH 2/2] Add test case

---
 .../Transforms/canonicalize-block-merge.mlir  |  6 +-
 .../test-canonicalize-merge-large-blocks.mlir | 76 +++++++++++++++++++
 2 files changed, 79 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir

diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index 3b8b1fce0575a..92cfde817cf7f 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -87,7 +87,7 @@ func.func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 :
 
 // CHECK-LABEL: func @mismatch_argument_uses(
 func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: return {{.*}}, {{.*}}
 
   cf.cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
 
@@ -101,7 +101,7 @@ func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32,
 
 // CHECK-LABEL: func @mismatch_argument_types(
 func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16)
 
@@ -115,7 +115,7 @@ func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
 
 // CHECK-LABEL: func @mismatch_argument_count(
 func.func @mismatch_argument_count(%cond : i1, %arg0 : i32) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2
 
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
new file mode 100644
index 0000000000000..570ff6905a04d
--- /dev/null
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -0,0 +1,76 @@
+ // RUN: mlir-opt -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' %s | FileCheck %s
+
+llvm.func @foo(%arg0: i64)
+
+llvm.func @rand() -> i1
+
+// CHECK-LABEL: func @large_merge_block(
+llvm.func @large_merge_block(%arg0: i64) {
+  // CHECK:  %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK:  %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK:  %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  // CHECK:  %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+  // CHECK:  %[[C4:.*]] = llvm.mlir.constant(4 : i64) : i64
+  // Both arms of the entry branch differ only in constants and merge into one.
+  // CHECK:  llvm.cond_br %{{.*}}, ^bb1(%[[C1]], %[[C3]], %[[C4]], %[[C2]] : i64, i64, i64, i64), ^bb1(%[[C4]], %[[C2]], %[[C1]], %[[C3]] : i64, i64, i64, i64)
+  // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64, %[[arg3:.*]]: i64):
+  // CHECK:    llvm.cond_br %{{.*}}, ^bb2(%[[arg0]] : i64), ^bb2(%[[arg3]] : i64)
+  // CHECK: ^bb{{.*}}(%{{.*}}: i64):
+  // CHECK:    llvm.br ^bb{{.*}}
+  // CHECK: ^bb{{.*}}:
+  // CHECK:   llvm.call
+  // CHECK:   llvm.cond_br {{.*}}, ^bb{{.*}}(%[[arg1]] : i64), ^bb{{.*}}(%[[arg2]] : i64)
+  // CHECK: ^bb{{.*}}:
+  // CHECK:   llvm.call
+  // CHECK:   llvm.br ^bb{{.*}}
+
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.mlir.constant(2 : i64) : i64
+  %3 = llvm.mlir.constant(3 : i64) : i64
+  %4 = llvm.mlir.constant(4 : i64) : i64
+  %10 = llvm.icmp "eq" %arg0, %0 : i64
+  llvm.cond_br %10, ^bb1, ^bb14
+^bb1:  // pred: ^bb0
+  %11 = llvm.call @rand() : () -> i1
+  llvm.cond_br %11, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.br ^bb4
+^bb3:  // pred: ^bb1
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb4
+^bb4:  // 2 preds: ^bb2, ^bb3
+  %14 = llvm.call @rand() : () -> i1
+  llvm.cond_br %14, ^bb5, ^bb6
+^bb5:  // pred: ^bb4
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb13
+^bb6:  // pred: ^bb4
+  llvm.call @foo(%4) : (i64) -> ()
+  llvm.br ^bb13
+^bb13:  // 2 preds: ^bb5, ^bb6
+  llvm.br ^bb27
+^bb14:  // pred: ^bb0
+  %23 = llvm.call @rand() : () -> i1
+  llvm.cond_br %23, ^bb15, ^bb16
+^bb15:  // pred: ^bb14
+  llvm.call @foo(%4) : (i64) -> ()
+  llvm.br ^bb17
+^bb16:  // pred: ^bb14
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb17
+^bb17:  // 2 preds: ^bb15, ^bb16
+  %26 = llvm.call @rand() : () -> i1
+  llvm.cond_br %26, ^bb18, ^bb19
+^bb18:  // pred: ^bb17
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb26
+^bb19:  // pred: ^bb17
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.br ^bb26
+^bb26:  // 2 preds: ^bb18, ^bb19
+  llvm.br ^bb27
+^bb27:  // 2 preds: ^bb13, ^bb26
+  llvm.return
+}



More information about the Mlir-commits mailing list