[Mlir-commits] [mlir] Fix block merging (PR #97697)

Giuseppe Rossini llvmlistbot at llvm.org
Thu Jul 4 01:59:28 PDT 2024


https://github.com/giuseros created https://github.com/llvm/llvm-project/pull/97697

With this PR I am trying to address: https://github.com/llvm/llvm-project/issues/63230. 

What changed:
- While merging identical blocks, don't add a block argument if it is "identical" to another block argument, i.e., if the two block arguments refer to the same `Value`. The operands of the operations in the block will instead point to the argument we already inserted. This needs to happen for all the arguments we pass to the different successors of the parent block.
- After merging the blocks, get rid of "unnecessary" arguments, i.e., if all the predecessors pass the same value for a block argument, there is no need to pass it as an argument.
- This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that `BufferDeallocationSimplification` contains an analysis based on the block structure: if we simplify the block structure (by merging blocks and/or dropping block arguments), the analysis becomes invalid. The solution I found is to use a more conservative (`Normal`) level of region simplification when running that pass.

**Note**: this is a rework of #96871. I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.

>From a75c52d65a27819222ac77116abdb0c538647d2d Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 24 Jun 2024 09:36:32 +0100
Subject: [PATCH 1/4] Fix block merging

---
 .../BufferDeallocationSimplification.cpp      |   9 +-
 mlir/lib/Transforms/Utils/RegionUtils.cpp     | 139 ++++++++++++++++--
 .../Linalg/detensorize_entry_block.mlir       |   6 +-
 3 files changed, 137 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3da..5227b22653eefc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+    // Block-structure changes would invalidate the `BufferOriginAnalysis`, so
+    // apply the rewrites with only the `Normal` level of region
+    // simplification.
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config)))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba3..31c4bfe1f6b2f1 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
+#include <iterator>
 
 using namespace mlir;
 
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(
-        1 + blocksToMerge.size(),
-        SmallVector<Value, 8>(operandsToMerge.size()));
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+                                                       SmallVector<Value, 8>());
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        newArguments[i][it.index()] = operand.get();
-
-        // Update the operand and insert an argument if this is the leader.
-        if (i == 0) {
-          Value operandVal = operand.get();
-          operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                               operandVal.getLoc()));
+        Value operandVal = operand.get();
+        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
+                              operandVal);
+        if (it == newArguments[i].end()) {
+          newArguments[i].push_back(operandVal);
+          // Update the operand and insert an argument if this is the leader.
+          if (i == 0) {
+            operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                                 operandVal.getLoc()));
+          }
+        } else if (i == 0) {
+          // If this is the leader, update the operand but do not insert a new
+          // argument. Instead, the operand should point to the argument we
+          // already added for `operandVal`.
+          operand.set(leaderBlock->getArgument(
+              std::distance(newArguments[i].begin(), it)));
         }
       }
     }
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
+/// Drops the arguments of `block` that receive the same forwarded SSA value
+/// from every predecessor (all of which must implement `BranchOpInterface`),
+/// replacing their uses with that value. Returns success if anything changed.
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+    Value argVal = block.getArgument(argIdx);
+    Value commonValue;
+    bool sameArg = true;
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      SuccessorOperands succOperands =
+          branch.getSuccessorOperands(predIt.getSuccessorIndex());
+      // Index by block argument, as `erase` below does; produced operands come
+      // first and carry no SSA value that could replace the argument.
+      if (succOperands.isOperandProduced(argIdx)) {
+        sameArg = false;
+        break;
+      }
+      Value incoming = succOperands[argIdx];
+      if (!commonValue) {
+        commonValue = incoming;
+      } else if (incoming != commonValue) {
+        sameArg = false;
+        break;
+      }
+    }
+
+    // A block forwarding its own argument back to itself cannot drop it.
+    if (!commonValue || !sameArg || commonValue == argVal)
+      continue;
+    argsToErase.push_back(argIdx);
+    rewriter.replaceAllUsesWith(argVal, commonValue);
+  }
+
+  // Erase back-to-front so pending indices stay valid, and drop the matching
+  // operand from every predecessor terminator.
+  for (size_t argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      SuccessorOperands succOperands =
+          branch.getSuccessorOperands(predIt.getSuccessorIndex());
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// Drops redundant block arguments in `regions` and, transitively, in every
+/// region nested inside them. A block argument is redundant when every
+/// predecessor forwards the same value to it: the argument is then removed and
+/// its uses are replaced by that value directly. For example, in
+///
+///   %cond = llvm.call @rand() : () -> i1
+///   %val0 = llvm.mlir.constant(1 : i64) : i64
+///   %val1 = llvm.mlir.constant(2 : i64) : i64
+///   %val2 = llvm.mlir.constant(3 : i64) : i64
+///   llvm.cond_br %cond, ^bb1(%val0, %val1 : i64, i64),
+///                       ^bb1(%val0, %val2 : i64, i64)
+/// ^bb1(%arg0: i64, %arg1: i64):
+///   llvm.call @foo(%arg0, %arg1) : (i64, i64) -> ()
+///
+/// `%arg0` always receives `%val0`, so the IR is rewritten as:
+///   llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb1(%val2 : i64)
+/// ^bb1(%arg1: i64):
+///   llvm.call @foo(%val0, %arg1) : (i64, i64) -> ()
+///
+/// Returns success if at least one argument was dropped.
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (Region &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+    for (Block &block : *region) {
+      // Use `|=`: a later unchanged block must not mask earlier changes.
+      anyChanged |= succeeded(dropRedundantArguments(rewriter, block));
+
+      // Queue the regions nested in this block's operations.
+      for (Operation &op : block)
+        for (Region &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  if (mergeBlocks)
+  bool droppedRedundantArguments = false;
+  if (mergeBlocks) {
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+    droppedRedundantArguments =
+        succeeded(dropRedundantArguments(rewriter, regions));
+  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks);
+                 mergedIdenticalBlocks || droppedRedundantArguments);
 }
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58f..50a2d6bf532aa3 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>

>From 4b1f9fcd517e9de4171ead6941de1de39c4776c0 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Fri, 28 Jun 2024 09:25:13 +0100
Subject: [PATCH 2/4] Add test case

---
 .../Transforms/canonicalize-block-merge.mlir  |  6 +-
 .../test-canonicalize-merge-large-blocks.mlir | 76 +++++++++++++++++++
 2 files changed, 79 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir

diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index 3b8b1fce0575a4..92cfde817cf7ff 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -87,7 +87,7 @@ func.func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 :
 
 // CHECK-LABEL: func @mismatch_argument_uses(
 func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: return {{.*}}, {{.*}}
 
   cf.cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
 
@@ -101,7 +101,7 @@ func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32,
 
 // CHECK-LABEL: func @mismatch_argument_types(
 func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16)
 
@@ -115,7 +115,7 @@ func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
 
 // CHECK-LABEL: func @mismatch_argument_count(
 func.func @mismatch_argument_count(%cond : i1, %arg0 : i32) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2
 
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
new file mode 100644
index 00000000000000..570ff6905a04d2
--- /dev/null
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' %s | FileCheck %s
+
+llvm.func @foo(%arg0: i64)
+
+llvm.func @rand() -> i1
+
+// CHECK-LABEL: func @large_merge_block(
+llvm.func @large_merge_block(%arg0: i64) {
+  // CHECK:  %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK:  %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK:  %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  // CHECK:  %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+  // CHECK:  %[[C4:.*]] = llvm.mlir.constant(4 : i64) : i64
+
+  // CHECK:  llvm.cond_br %{{.*}}, ^bb1(%[[C1]], %[[C3]], %[[C4]], %[[C2]] : i64, i64, i64, i64), ^bb1(%[[C4]], %[[C2]], %[[C1]], %[[C3]] : i64, i64, i64, i64)
+  // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64, %[[arg3:.*]]: i64):
+  // CHECK:    llvm.cond_br %{{.*}}, ^bb2(%[[arg0]] : i64), ^bb2(%[[arg3]] : i64)
+  // CHECK: ^bb{{.*}}(%{{.*}}: i64):
+  // CHECK:    llvm.br ^bb{{.*}}
+  // CHECK: ^bb{{.*}}:
+  // CHECK:   llvm.call
+  // CHECK:   llvm.cond_br {{.*}}, ^bb{{.*}}(%[[arg1]] : i64), ^bb{{.*}}(%[[arg2]] : i64)
+  // CHECK: ^bb{{.*}}:
+  // CHECK:   llvm.call
+  // CHECK:   llvm.br ^bb{{.*}}
+
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.mlir.constant(2 : i64) : i64
+  %3 = llvm.mlir.constant(3 : i64) : i64
+  %4 = llvm.mlir.constant(4 : i64) : i64
+  %10 = llvm.icmp "eq" %arg0, %0 : i64
+  llvm.cond_br %10, ^bb1, ^bb14
+^bb1:  // pred: ^bb0
+  %11 = llvm.call @rand() : () -> i1
+  llvm.cond_br %11, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.br ^bb4
+^bb3:  // pred: ^bb1
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb4
+^bb4:  // 2 preds: ^bb2, ^bb3
+  %14 = llvm.call @rand() : () -> i1
+  llvm.cond_br %14, ^bb5, ^bb6
+^bb5:  // pred: ^bb4
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb13
+^bb6:  // pred: ^bb4
+  llvm.call @foo(%4) : (i64) -> ()
+  llvm.br ^bb13
+^bb13:  // 2 preds: ^bb5, ^bb6
+  llvm.br ^bb27
+^bb14:  // pred: ^bb0
+  %23 = llvm.call @rand() : () -> i1
+  llvm.cond_br %23, ^bb15, ^bb16
+^bb15:  // pred: ^bb14
+  llvm.call @foo(%4) : (i64) -> ()
+  llvm.br ^bb17
+^bb16:  // pred: ^bb14
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb17
+^bb17:  // 2 preds: ^bb15, ^bb16
+  %26 = llvm.call @rand() : () -> i1
+  llvm.cond_br %26, ^bb18, ^bb19
+^bb18:  // pred: ^bb17
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb26
+^bb19:  // pred: ^bb17
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.br ^bb26
+^bb26:  // 2 preds: ^bb18, ^bb19
+  llvm.br ^bb27
+^bb27:  // 2 preds: ^bb13, ^bb26
+  llvm.return
+}

>From 8d9fe794698cdffd3fb06beef65064091b0a9825 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Sat, 29 Jun 2024 07:51:45 +0100
Subject: [PATCH 3/4] Correcting remaining tests

---
 .../dealloc-branchop-interface.mlir           | 20 +++---
 mlir/test/Dialect/Linalg/detensorize_if.mlir  | 67 ++++++++-----------
 .../Dialect/Linalg/detensorize_while.mlir     | 12 ++--
 .../Linalg/detensorize_while_impure_cf.mlir   | 12 ++--
 .../Linalg/detensorize_while_pure_cf.mlir     |  4 +-
 mlir/test/Transforms/canonicalize-dce.mlir    |  8 +--
 .../Transforms/make-isolated-from-above.mlir  | 18 ++---
 7 files changed, 68 insertions(+), 73 deletions(-)

diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d4..8e14990502143e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//   CHECK-NOT: bufferization.dealloc
+//   CHECK-NOT: bufferization.clone
+//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c4..c728ad21d2209b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]:
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]:
+// CHECK-NEXT:     cf.br ^[[bb4:.*]]
+// CHECK-NEXT:   ^[[bb4]]:
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a334..580a97d3a851ba 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL:       ^[[bb2]]
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:       ^[[bb3]]:
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF:       ^[[bb2]](%{{.*}}: i32)
+// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF:       ^[[bb2]]:
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]](%{{.*}}: i32)
+// DET-CF:       ^[[bb3]]:
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be5ef4c89..414d9b94cbf530 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:         } -> tensor<i32>
 // DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL:       ^[[bb2]]:
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         tensor.empty() : tensor<10xi32>
 // DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
@@ -83,7 +83,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:           linalg.yield %{{.*}} : i32
 // DET-ALL:         } -> tensor<10xi32>
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
-// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:       ^[[bb3]]
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         return %{{.*}} : tensor<i32>
 // DET-ALL:       }
@@ -95,10 +95,10 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
 // DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-CF:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-CF:         cf.cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
-// DET-CF:       ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF:         cf.cond_br %{{.*}}, ^bb2, ^bb3
+// DET-CF:       ^bb2:
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
 // DET-CF:         cf.br ^bb1(%{{.*}} : tensor<10xi32>)
-// DET-CF:       ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF:       ^bb3:
 // DET-CF:         return %{{.*}} : tensor<i32>
 // DET-CF:       }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 6d8d5fe71fca5c..913e78272db796 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -49,8 +49,8 @@ func.func @main() -> () attributes {} {
 // CHECK-NEXT:    cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // CHECK-NEXT:  ^[[bb1]](%{{.*}}: i32)
 // CHECK-NEXT:    %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:    cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
-// CHECK-NEXT:  ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:    cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// CHECK-NEXT:  ^[[bb2]]
 // CHECK-NEXT:    %{{.*}} = arith.addi %{{.*}}, %{{.*}}
 // CHECK-NEXT:    cf.br ^[[bb1]](%{{.*}} : i32)
 // CHECK-NEXT:  ^[[bb3]]:
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index ac034d567a26a9..84631947970ded 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -137,10 +137,10 @@ func.func @f(%arg0: f32) {
 // Test case: Test the mechanics of deleting multiple block arguments.
 
 // CHECK:      func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>, %arg3: tensor<4xf32>, %arg4: tensor<5xf32>)
-// CHECK-NEXT:   "test.br"(%arg1, %arg3)[^bb1] : (tensor<2xf32>, tensor<4xf32>)
-// CHECK-NEXT: ^bb1([[VAL0:%.+]]: tensor<2xf32>, [[VAL1:%.+]]: tensor<4xf32>):
-// CHECK-NEXT:   "foo.print"([[VAL0]])
-// CHECK-NEXT:   "foo.print"([[VAL1]])
+// CHECK-NEXT:   "test.br"()[^bb1]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   "foo.print"(%arg1)
+// CHECK-NEXT:   "foo.print"(%arg3)
 // CHECK-NEXT:   return
 
 
diff --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir
index 58f6cfbc5dd65f..a9d4325944fd9d 100644
--- a/mlir/test/Transforms/make-isolated-from-above.mlir
+++ b/mlir/test/Transforms/make-isolated-from-above.mlir
@@ -78,9 +78,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
 //       CHECK:   test.isolated_one_region_op %[[ARG2]], %[[C0]], %[[C1]], %[[D0]], %[[D1]]
 //  CHECK-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index, %[[B4:[a-zA-Z0-9]+]]: index)
-//  CHECK-NEXT:       cf.br ^bb1(%[[B0]] : index)
-//       CHECK:     ^bb1(%[[B5:.+]]: index)
-//       CHECK:       "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]])
+//  CHECK-NEXT:       cf.br ^bb1
+//       CHECK:     ^bb1:
+//       CHECK:       "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B0]])
 
 // CLONE1-LABEL: func @make_isolated_from_above_multiple_blocks(
 //  CLONE1-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -95,9 +95,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
 //  CLONE1-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
 //   CLONE1-DAG:       %[[C0_0:.+]] = arith.constant 0 : index
 //   CLONE1-DAG:       %[[C1_0:.+]] = arith.constant 1 : index
-//  CLONE1-NEXT:       cf.br ^bb1(%[[B0]] : index)
-//       CLONE1:     ^bb1(%[[B3:.+]]: index)
-//       CLONE1:       "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B3]])
+//  CLONE1-NEXT:       cf.br ^bb1
+//       CLONE1:     ^bb1:
+//       CLONE1:       "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B0]])
 
 // CLONE2-LABEL: func @make_isolated_from_above_multiple_blocks(
 //  CLONE2-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -110,6 +110,6 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
 //   CLONE2-DAG:       %[[EMPTY:.+]] = tensor.empty(%[[B1]], %[[B2]])
 //   CLONE2-DAG:       %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
 //   CLONE2-DAG:       %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
-//  CLONE2-NEXT:       cf.br ^bb1(%[[B0]] : index)
-//       CLONE2:     ^bb1(%[[B3:.+]]: index)
-//       CLONE2:       "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B3]])
+//  CLONE2-NEXT:       cf.br ^bb1
+//       CLONE2:     ^bb1:
+//       CLONE2:       "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B0]])

>From 1f153c395385ba34376b90e742646b4161fcebc9 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Sun, 30 Jun 2024 06:49:11 +0100
Subject: [PATCH 4/4] Fix integration tests

---
 mlir/lib/Transforms/Utils/RegionUtils.cpp     | 95 +++++++++++++++----
 .../test-canonicalize-merge-large-blocks.mlir | 86 +++++++++++++++++
 2 files changed, 162 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 31c4bfe1f6b2f1..8b742bada67b97 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -679,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
+/// Prunes redundant entries from the per-block lists of new arguments. An
+/// entry is redundant if, in every list, it equals the entry at one common
+/// earlier position; e.g. [x, y, z, x] and [a, b, c, a] become [x, y, z] and
+/// [a, b, c]. `block` is updated accordingly: uses of each dropped argument
+/// are redirected to the argument it duplicates, and the dropped argument is
+/// erased. `newArguments[i]` holds the values passed by the i-th merged block;
+/// all lists are expected to have the same length and map, in order, to the
+/// arguments that were appended to `block` after its pre-existing ones.
+static void
+pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+                        RewriterBase &rewriter, Block *block) {
+  if (newArguments.empty())
+    return;
+  unsigned numLists = newArguments.size();
+  unsigned numArgs = newArguments[0].size();
+  // Offset of the first new argument: the block may already have arguments.
+  unsigned numOldArguments = block->getNumArguments() - numArgs;
+
+  // toReplace[j] == k: new argument j duplicates new argument k (k < j).
+  llvm::DenseMap<unsigned, unsigned> toReplace;
+  for (unsigned j = 0; j < numArgs; ++j) {
+    // Pick the first k that matches j in every list. Since equality is
+    // transitive, that k is never replaced itself (no replacement chains).
+    for (unsigned k = 0; k < j; ++k) {
+      bool sameInAllLists = llvm::all_of(newArguments, [&](const auto &list) {
+        return list[k] == list[j];
+      });
+      if (sameInAllLists) {
+        toReplace[j] = k;
+        break;
+      }
+    }
+  }
+
+  // Populate the pruned argument lists.
+  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(numLists);
+  for (unsigned i = 0; i < numLists; ++i)
+    for (unsigned j = 0; j < numArgs; ++j)
+      if (!toReplace.contains(j))
+        newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+  // Redirect the uses of each redundant argument to the one it duplicates.
+  SmallVector<unsigned> toErase;
+  for (unsigned j = 0; j < numArgs; ++j) {
+    auto it = toReplace.find(j);
+    if (it == toReplace.end())
+      continue;
+    Value oldArg = block->getArgument(numOldArguments + j);
+    Value newArg = block->getArgument(numOldArguments + it->second);
+    rewriter.replaceAllUsesWith(oldArg, newArg);
+    toErase.push_back(numOldArguments + j);
+  }
+  // Erase back-to-front so that the indices still to be erased stay valid.
+  for (unsigned idxToErase : llvm::reverse(toErase))
+    block->eraseArgument(idxToErase);
+  newArguments = std::move(newArgumentsPruned);
+}
+
 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
@@ -717,24 +775,18 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
         Value operandVal = operand.get();
-        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
-                              operandVal);
-        if (it == newArguments[i].end()) {
-          newArguments[i].push_back(operandVal);
-          // Update the operand and insert an argument if this is the leader.
-          if (i == 0) {
-            operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                                 operandVal.getLoc()));
-          }
-        } else if (i == 0) {
-          // If this is the leader, update the operand but do not insert a new
-          // argument. Instead, the opearand should point to one of the
-          // arguments we already passed (and that contained `operandVal`)
-          operand.set(leaderBlock->getArgument(
-              std::distance(newArguments[i].begin(), it)));
+        newArguments[i].push_back(operandVal);
+        // Update the operand and insert an argument if this is the leader.
+        if (i == 0) {
+          operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                               operandVal.getLoc()));
         }
       }
     }
+
+    // Prune redundant arguments and update the leader block argument list
+    pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+
     // Update the predecessors for each of the blocks.
     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
       for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -896,17 +948,22 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
 /// %cond = llvm.call @rand() : () -> i1
 /// %val0 = llvm.mlir.constant(1 : i64) : i64
 /// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb1(%val0 : i64, %val2
-/// : i64) ^bb1(%arg0 : i64, %arg1 : i64):
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
 ///    llvm.call @foo(%arg0, %arg1)
 ///
 /// The previous IR can be rewritten as:
 /// %cond = llvm.call @rand() : () -> i1
-/// %val = llvm.mlir.constant(1 : i64) : i64
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb1(%val2 : i64)
+///
 /// ^bb1(%arg0 : i64):
-///    llvm.call @foo(%val0, %arg1)
+///    llvm.call @foo(%val0, %arg0)
 ///
 static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                             MutableArrayRef<Region> regions) {
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
index 570ff6905a04d2..e821dcd0c20645 100644
--- a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -74,3 +74,89 @@ llvm.func @large_merge_block(%arg0: i64) {
 ^bb27:  // 2 preds: ^bb13, ^bb26
   llvm.return
 }
+
+llvm.func @redundant_args0(%cond : i1) {
+  // CHECK-LABEL: @redundant_args0
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.mlir.constant(2 : i64) : i64
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  llvm.cond_br %cond, ^bb1, ^bb2
+  // Both new arguments are kept: ^bb1 passes %0 twice, ^bb2 passes %2 and %3.
+  // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64), ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64)
+  // CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
+^bb1:
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.br ^bb3
+^bb2:
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb3
+^bb3:
+  llvm.return
+}
+
+llvm.func @redundant_args1(%cond : i1) {
+  // CHECK-LABEL: @redundant_args1
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.mlir.constant(2 : i64) : i64
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  llvm.cond_br %cond, ^bb1, ^bb2
+  // Both new arguments are kept: ^bb2 passes %0 twice, ^bb1 passes %2 and %3.
+  // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64), ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64)
+  // CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
+^bb1:
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb3
+^bb2:
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.br ^bb3
+^bb3:
+  llvm.return
+}
+
+llvm.func @redundant_args_complex(%cond : i1) {
+  // CHECK-LABEL: @redundant_args_complex
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.mlir.constant(2 : i64) : i64
+  %3 = llvm.mlir.constant(3 : i64) : i64
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+  llvm.cond_br %cond, ^bb1, ^bb2
+
+  // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C2]], %[[C1]], %[[C3]] : i64, i64, i64), ^bb{{.*}}(%[[C0]], %[[C3]], %[[C2]] : i64, i64, i64)
+  // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64):
+  // CHECK: llvm.call @foo(%[[arg0]])
+  // CHECK: llvm.call @foo(%[[arg0]])
+  // CHECK: llvm.call @foo(%[[arg1]])
+  // CHECK: llvm.call @foo(%[[C2]])
+  // CHECK: llvm.call @foo(%[[arg2]])
+  // Calls 1 and 2 now share one argument; call 4 passes %2 in both blocks.
+^bb1:
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb3
+^bb2:
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb3
+^bb3:
+  llvm.return
+}



More information about the Mlir-commits mailing list