[Mlir-commits] [mlir] [mlir] Fix block merging (PR #97697)

Giuseppe Rossini llvmlistbot at llvm.org
Mon Jul 15 03:25:15 PDT 2024


https://github.com/giuseros updated https://github.com/llvm/llvm-project/pull/97697

>From a75c52d65a27819222ac77116abdb0c538647d2d Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 24 Jun 2024 09:36:32 +0100
Subject: [PATCH 1/7] Fix block merging

---
 .../BufferDeallocationSimplification.cpp      |   9 +-
 mlir/lib/Transforms/Utils/RegionUtils.cpp     | 139 ++++++++++++++++--
 .../Linalg/detensorize_entry_block.mlir       |   6 +-
 3 files changed, 137 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+    // We don't want the block structure to change, as that would invalidate
+    // the `BufferOriginAnalysis`, so we apply the rewrites with a `Normal`
+    // level of region simplification.
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config)))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..31c4bfe1f6b2f 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
+#include <iterator>
 
 using namespace mlir;
 
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(
-        1 + blocksToMerge.size(),
-        SmallVector<Value, 8>(operandsToMerge.size()));
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+                                                       SmallVector<Value, 8>());
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        newArguments[i][it.index()] = operand.get();
-
-        // Update the operand and insert an argument if this is the leader.
-        if (i == 0) {
-          Value operandVal = operand.get();
-          operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                               operandVal.getLoc()));
+        Value operandVal = operand.get();
+        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
+                              operandVal);
+        if (it == newArguments[i].end()) {
+          newArguments[i].push_back(operandVal);
+          // Update the operand and insert an argument if this is the leader.
+          if (i == 0) {
+            operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                                 operandVal.getLoc()));
+          }
+        } else if (i == 0) {
+          // If this is the leader, update the operand but do not insert a new
+          // argument. Instead, the operand should point to one of the
+          // arguments we already passed (and that contained `operandVal`).
+          operand.set(leaderBlock->getArgument(
+              std::distance(newArguments[i].begin(), it)));
         }
       }
     }
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
+/// Drops the arguments of `block` that receive the same value from every
+/// predecessor, replacing their uses with that value and erasing the matching
+/// operands from the predecessors' branch ops. Returns success if at least
+/// one argument was dropped.
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  // Go through the arguments of the block.
+  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+    bool sameArg = true;
+    Value commonValue;
+
+    // Go through the block predecessors and flag if they pass different
+    // values to the block for the same argument.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      // Index by block argument: this yields a null value for arguments that
+      // the terminator produces itself instead of forwarding an operand.
+      Value operand =
+          branch.getSuccessorOperands(predIt.getSuccessorIndex())[argIdx];
+      if (!operand || (commonValue && operand != commonValue)) {
+        sameArg = false;
+        break;
+      }
+      commonValue = operand;
+    }
+
+    // If they are passing the same value, drop the argument: redirect its uses
+    // now and erase it afterwards.
+    if (commonValue && sameArg) {
+      argsToErase.push_back(argIdx);
+      rewriter.replaceAllUsesWith(block.getArgument(argIdx), commonValue);
+    }
+  }
+
+  // Remove the arguments, last to first so pending indices stay valid.
+  for (auto argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+
+    // Remove the argument from the branch ops.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// Drops redundant block arguments in `regions` and in all regions nested in
+/// them. An argument is redundant if each predecessor of its block passes the
+/// same value for it; the argument is then removed and its uses are replaced
+/// by that value. Returns success if at least one argument was dropped and
+/// failure otherwise. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0, %val1 : i64, i64),
+///                     ^bb1(%val0, %val2 : i64, i64)
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+///    llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as (value definitions are unchanged):
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb1(%val2 : i64)
+/// ^bb1(%arg1 : i64):
+///    llvm.call @foo(%val0, %arg1)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (auto &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+    for (Block &block : *region) {
+      // Accumulate instead of overwriting, so that a later unchanged block
+      // does not hide a change made to an earlier one.
+      anyChanged |= succeeded(dropRedundantArguments(rewriter, block));
+      // Add any nested regions to the worklist.
+      for (auto &op : block)
+        for (auto &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  if (mergeBlocks)
+  bool droppedRedundantArguments = false;
+  if (mergeBlocks) {
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+    droppedRedundantArguments =
+        succeeded(dropRedundantArguments(rewriter, regions));
+  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks);
+                 mergedIdenticalBlocks || droppedRedundantArguments);
 }
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>

>From 4b1f9fcd517e9de4171ead6941de1de39c4776c0 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Fri, 28 Jun 2024 09:25:13 +0100
Subject: [PATCH 2/7] Add test case

---
 .../Transforms/canonicalize-block-merge.mlir  |  6 +-
 .../test-canonicalize-merge-large-blocks.mlir | 76 +++++++++++++++++++
 2 files changed, 79 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir

diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index 3b8b1fce0575a..92cfde817cf7f 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -87,7 +87,7 @@ func.func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 :
 
 // CHECK-LABEL: func @mismatch_argument_uses(
 func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: return {{.*}}, {{.*}}
 
   cf.cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
 
@@ -101,7 +101,7 @@ func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32,
 
 // CHECK-LABEL: func @mismatch_argument_types(
 func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16)
 
@@ -115,7 +115,7 @@ func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
 
 // CHECK-LABEL: func @mismatch_argument_count(
 func.func @mismatch_argument_count(%cond : i1, %arg0 : i32) {
-  // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+  // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2
 
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
new file mode 100644
index 0000000000000..570ff6905a04d
--- /dev/null
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -0,0 +1,76 @@
+ // RUN: mlir-opt -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' %s | FileCheck %s
+
+llvm.func @foo(%arg0: i64)
+
+llvm.func @rand() -> i1
+
+// CHECK-LABEL: func @large_merge_block(
+llvm.func @large_merge_block(%arg0: i64) {
+  // CHECK:  %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK:  %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK:  %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  // CHECK:  %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+  // CHECK:  %[[C4:.*]] = llvm.mlir.constant(4 : i64) : i64
+
+  // CHECK:  llvm.cond_br %5, ^bb1(%[[C1]], %[[C3]], %[[C4]], %[[C2]] : i64, i64, i64, i64), ^bb1(%[[C4]], %[[C2]], %[[C1]], %[[C3]] : i64, i64, i64, i64)
+  // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64, %[[arg3:.*]]: i64):
+  // CHECK:    llvm.cond_br %{{.*}}, ^bb2(%[[arg0]] : i64), ^bb2(%[[arg3]] : i64)
+  // CHECK: ^bb{{.*}}(%11: i64):
+  // CHECK:    llvm.br ^bb{{.*}}
+  // CHECK: ^bb{{.*}}:
+  // CHECK:   llvm.call
+  // CHECK:   llvm.cond_br {{.*}}, ^bb{{.*}}(%[[arg1]] : i64), ^bb{{.*}}(%[[arg2]] : i64)
+  // CHECK: ^bb{{.*}}:
+  // CHECK:   llvm.call
+  // CHECK:   llvm.br ^bb{{.*}}
+
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.mlir.constant(2 : i64) : i64
+  %3 = llvm.mlir.constant(3 : i64) : i64
+  %4 = llvm.mlir.constant(4 : i64) : i64
+  %10 = llvm.icmp "eq" %arg0, %0 : i64
+  llvm.cond_br %10, ^bb1, ^bb14
+^bb1:  // pred: ^bb0
+  %11 = llvm.call @rand() : () -> i1
+  llvm.cond_br %11, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.br ^bb4
+^bb3:  // pred: ^bb1
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb4
+^bb4:  // 2 preds: ^bb2, ^bb3
+  %14 = llvm.call @rand() : () -> i1
+  llvm.cond_br %14, ^bb5, ^bb6
+^bb5:  // pred: ^bb4
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb13
+^bb6:  // pred: ^bb4
+  llvm.call @foo(%4) : (i64) -> ()
+  llvm.br ^bb13
+^bb13:  // 2 preds: ^bb5, ^bb6
+  llvm.br ^bb27
+^bb14:  // pred: ^bb0
+  %23 = llvm.call @rand() : () -> i1
+  llvm.cond_br %23, ^bb15, ^bb16
+^bb15:  // pred: ^bb14
+  llvm.call @foo(%4) : (i64) -> ()
+  llvm.br ^bb17
+^bb16:  // pred: ^bb14
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb17
+^bb17:  // 2 preds: ^bb15, ^bb16
+  %26 = llvm.call @rand() : () -> i1
+  llvm.cond_br %26, ^bb18, ^bb19
+^bb18:  // pred: ^bb17
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb26
+^bb19:  // pred: ^bb17
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.br ^bb26
+^bb26:  // 2 preds: ^bb18, ^bb19
+  llvm.br ^bb27
+^bb27:  // 2 preds: ^bb13, ^bb26
+  llvm.return
+}

>From 8d9fe794698cdffd3fb06beef65064091b0a9825 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Sat, 29 Jun 2024 07:51:45 +0100
Subject: [PATCH 3/7] Correcting remaining tests

---
 .../dealloc-branchop-interface.mlir           | 20 +++---
 mlir/test/Dialect/Linalg/detensorize_if.mlir  | 67 ++++++++-----------
 .../Dialect/Linalg/detensorize_while.mlir     | 12 ++--
 .../Linalg/detensorize_while_impure_cf.mlir   | 12 ++--
 .../Linalg/detensorize_while_pure_cf.mlir     |  4 +-
 mlir/test/Transforms/canonicalize-dce.mlir    |  8 +--
 .../Transforms/make-isolated-from-above.mlir  | 18 ++---
 7 files changed, 68 insertions(+), 73 deletions(-)

diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d..8e14990502143 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//   CHECK-NOT: bufferization.dealloc
+//   CHECK-NOT: bufferization.clone
+//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c..c728ad21d2209 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]:
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]:
+// CHECK-NEXT:     cf.br ^[[bb4:.*]]
+// CHECK-NEXT:   ^[[bb4]]:
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a33..580a97d3a851b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL:       ^[[bb2]]
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:       ^[[bb3]]:
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF:       ^[[bb2]](%{{.*}}: i32)
+// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF:       ^[[bb2]]:
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]](%{{.*}}: i32)
+// DET-CF:       ^[[bb3]]:
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be5ef4c8..414d9b94cbf53 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:         } -> tensor<i32>
 // DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL:       ^[[bb2]]:
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         tensor.empty() : tensor<10xi32>
 // DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
@@ -83,7 +83,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:           linalg.yield %{{.*}} : i32
 // DET-ALL:         } -> tensor<10xi32>
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
-// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:       ^[[bb3]]
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         return %{{.*}} : tensor<i32>
 // DET-ALL:       }
@@ -95,10 +95,10 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
 // DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
 // DET-CF:         cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-CF:         cf.cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
-// DET-CF:       ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF:         cf.cond_br %{{.*}}, ^bb2, ^bb3
+// DET-CF:       ^bb2:
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
 // DET-CF:         cf.br ^bb1(%{{.*}} : tensor<10xi32>)
-// DET-CF:       ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF:       ^bb3:
 // DET-CF:         return %{{.*}} : tensor<i32>
 // DET-CF:       }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 6d8d5fe71fca5..913e78272db79 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -49,8 +49,8 @@ func.func @main() -> () attributes {} {
 // CHECK-NEXT:    cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // CHECK-NEXT:  ^[[bb1]](%{{.*}}: i32)
 // CHECK-NEXT:    %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:    cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
-// CHECK-NEXT:  ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:    cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// CHECK-NEXT:  ^[[bb2]]
 // CHECK-NEXT:    %{{.*}} = arith.addi %{{.*}}, %{{.*}}
 // CHECK-NEXT:    cf.br ^[[bb1]](%{{.*}} : i32)
 // CHECK-NEXT:  ^[[bb3]]:
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index ac034d567a26a..84631947970de 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -137,10 +137,10 @@ func.func @f(%arg0: f32) {
 // Test case: Test the mechanics of deleting multiple block arguments.
 
 // CHECK:      func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>, %arg3: tensor<4xf32>, %arg4: tensor<5xf32>)
-// CHECK-NEXT:   "test.br"(%arg1, %arg3)[^bb1] : (tensor<2xf32>, tensor<4xf32>)
-// CHECK-NEXT: ^bb1([[VAL0:%.+]]: tensor<2xf32>, [[VAL1:%.+]]: tensor<4xf32>):
-// CHECK-NEXT:   "foo.print"([[VAL0]])
-// CHECK-NEXT:   "foo.print"([[VAL1]])
+// CHECK-NEXT:   "test.br"()[^bb1]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   "foo.print"(%arg1)
+// CHECK-NEXT:   "foo.print"(%arg3)
 // CHECK-NEXT:   return
 
 
diff --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir
index 58f6cfbc5dd65..a9d4325944fd9 100644
--- a/mlir/test/Transforms/make-isolated-from-above.mlir
+++ b/mlir/test/Transforms/make-isolated-from-above.mlir
@@ -78,9 +78,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
 //       CHECK:   test.isolated_one_region_op %[[ARG2]], %[[C0]], %[[C1]], %[[D0]], %[[D1]]
 //  CHECK-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index, %[[B4:[a-zA-Z0-9]+]]: index)
-//  CHECK-NEXT:       cf.br ^bb1(%[[B0]] : index)
-//       CHECK:     ^bb1(%[[B5:.+]]: index)
-//       CHECK:       "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]])
+//  CHECK-NEXT:       cf.br ^bb1
+//       CHECK:     ^bb1:
+//       CHECK:       "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B0]])
 
 // CLONE1-LABEL: func @make_isolated_from_above_multiple_blocks(
 //  CLONE1-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -95,9 +95,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
 //  CLONE1-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
 //   CLONE1-DAG:       %[[C0_0:.+]] = arith.constant 0 : index
 //   CLONE1-DAG:       %[[C1_0:.+]] = arith.constant 1 : index
-//  CLONE1-NEXT:       cf.br ^bb1(%[[B0]] : index)
-//       CLONE1:     ^bb1(%[[B3:.+]]: index)
-//       CLONE1:       "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B3]])
+//  CLONE1-NEXT:       cf.br ^bb1
+//       CLONE1:     ^bb1:
+//       CLONE1:       "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B0]])
 
 // CLONE2-LABEL: func @make_isolated_from_above_multiple_blocks(
 //  CLONE2-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -110,6 +110,6 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
 //   CLONE2-DAG:       %[[EMPTY:.+]] = tensor.empty(%[[B1]], %[[B2]])
 //   CLONE2-DAG:       %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
 //   CLONE2-DAG:       %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
-//  CLONE2-NEXT:       cf.br ^bb1(%[[B0]] : index)
-//       CLONE2:     ^bb1(%[[B3:.+]]: index)
-//       CLONE2:       "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B3]])
+//  CLONE2-NEXT:       cf.br ^bb1
+//       CLONE2:     ^bb1:
+//       CLONE2:       "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B0]])

>From 249a3877880e1aeed3903c6f38f115734bf64110 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Sun, 30 Jun 2024 06:49:11 +0100
Subject: [PATCH 4/7] Fix integration tests

---
 mlir/lib/Transforms/Utils/RegionUtils.cpp     | 103 ++++++++++++++----
 .../test-canonicalize-merge-large-blocks.mlir |  86 +++++++++++++++
 2 files changed, 167 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 31c4bfe1f6b2f..c336bbc913634 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -679,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
+/// Prunes the redundant list of arguments. E.g., if we are passing an argument
+/// list like [x, y, z, x] this would return [x, y, z] and it would update the
+/// `block` (to whom the argument are passed to) accordingly
+static void
+pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+                        RewriterBase &rewriter, Block *block) {
+  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
+      newArguments.size(), SmallVector<Value, 8>());
+
+  if (!newArguments.empty()) {
+    llvm::DenseMap<unsigned, unsigned> toReplace;
+    // Go through the first list of arguments (list 0)
+    for (unsigned j = 0; j < newArguments[0].size(); j++) {
+      bool shouldReplaceJ = false;
+      unsigned replacement = 0;
+      // Look back to see if there are possible redundancies in
+      // list 0
+      for (unsigned k = 0; k < j; k++) {
+        if (newArguments[0][k] == newArguments[0][j]) {
+          shouldReplaceJ = true;
+          replacement = k;
+          // If a possible redundancy is found, then scan the other lists: we
+          // can prune the arguments if and only if they are redundant in every
+          // list
+          for (unsigned i = 1; i < newArguments.size(); i++)
+            shouldReplaceJ =
+                shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
+        }
+      }
+      // Save the replacement
+      if (shouldReplaceJ)
+        toReplace[j] = replacement;
+    }
+
+    // Populate the pruned argument list
+    for (unsigned i = 0; i < newArguments.size(); i++)
+      for (unsigned j = 0; j < newArguments[i].size(); j++)
+        if (!toReplace.contains(j))
+          newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+    // Replace the block's redundant arguments
+    SmallVector<unsigned> toErase;
+    for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+      if (toReplace.contains(idx)) {
+        Value oldArg = block->getArgument(idx);
+        Value newArg = block->getArgument(toReplace[idx]);
+        rewriter.replaceAllUsesWith(oldArg, newArg);
+        toErase.push_back(idx);
+      }
+    }
+
+    // Erase the block's redundant arguments
+    for (auto idxToErase : llvm::reverse(toErase))
+      block->eraseArgument(idxToErase);
+    newArguments = newArgumentsPruned;
+  }
+}
+
 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
@@ -704,8 +762,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
-                                                       SmallVector<Value, 8>());
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(
+        1 + blocksToMerge.size(),
+        SmallVector<Value, 8>(operandsToMerge.size()));
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -716,25 +775,20 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        Value operandVal = operand.get();
-        Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
-                              operandVal);
-        if (it == newArguments[i].end()) {
-          newArguments[i].push_back(operandVal);
-          // Update the operand and insert an argument if this is the leader.
-          if (i == 0) {
-            operand.set(leaderBlock->addArgument(operandVal.getType(),
-                                                 operandVal.getLoc()));
-          }
-        } else if (i == 0) {
-          // If this is the leader, update the operand but do not insert a new
-          // argument. Instead, the opearand should point to one of the
-          // arguments we already passed (and that contained `operandVal`)
-          operand.set(leaderBlock->getArgument(
-              std::distance(newArguments[i].begin(), it)));
+        newArguments[i][it.index()] = operand.get();
+
+        // Update the operand and insert an argument if this is the leader.
+        if (i == 0) {
+          Value operandVal = operand.get();
+          operand.set(leaderBlock->addArgument(operandVal.getType(),
+                                               operandVal.getLoc()));
         }
       }
     }
+
+    // Prune redundant arguments and update the leader block argument list
+    pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+
     // Update the predecessors for each of the blocks.
     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
       for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -896,17 +950,22 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
 /// %cond = llvm.call @rand() : () -> i1
 /// %val0 = llvm.mlir.constant(1 : i64) : i64
 /// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
 /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
-/// : i64) ^bb1(%arg0 : i64, %arg1 : i64):
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
 ///    llvm.call @foo(%arg0, %arg1)
 ///
 /// The previous IR can be rewritten as:
 /// %cond = llvm.call @rand() : () -> i1
-/// %val = llvm.mlir.constant(1 : i64) : i64
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
 /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
 /// ^bb1(%arg0 : i64):
-///    llvm.call @foo(%val0, %arg1)
+///    llvm.call @foo(%val0, %arg0)
 ///
 static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                             MutableArrayRef<Region> regions) {
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
index 570ff6905a04d..e821dcd0c2064 100644
--- a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -74,3 +74,89 @@ llvm.func @large_merge_block(%arg0: i64) {
 ^bb27:  // 2 preds: ^bb13, ^bb26
   llvm.return
 }
+
+llvm.func @redundant_args0(%cond : i1) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.mlir.constant(2 : i64) : i64
+  // CHECK  %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK  %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK  %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+
+  llvm.cond_br %cond, ^bb1, ^bb2
+
+  // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64), ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64)
+  // CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
+^bb1:
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.br ^bb3
+^bb2:
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb3
+^bb3:
+  llvm.return
+}
+
+llvm.func @redundant_args1(%cond : i1) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.mlir.constant(2 : i64) : i64
+  // CHECK  %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK  %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK  %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+
+  llvm.cond_br %cond, ^bb1, ^bb2
+
+  // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64), ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64)
+  // CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
+^bb1:
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb3
+^bb2:
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.br ^bb3
+^bb3:
+  llvm.return
+}
+
+llvm.func @redundant_args_complex(%cond : i1) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.mlir.constant(2 : i64) : i64
+  %3 = llvm.mlir.constant(3 : i64) : i64
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+
+  llvm.cond_br %cond, ^bb1, ^bb2
+
+  // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C2]], %[[C1]], %[[C3]] : i64, i64, i64), ^bb{{.*}}(%[[C0]], %[[C3]], %[[C2]] : i64, i64, i64)
+  // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64):
+  // CHECK: llvm.call @foo(%[[arg0]])
+  // CHECK: llvm.call @foo(%[[arg0]])
+  // CHECK: llvm.call @foo(%[[arg1]])
+  // CHECK: llvm.call @foo(%[[C2]])
+  // CHECK: llvm.call @foo(%[[arg2]])
+
+^bb1:
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%1) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.br ^bb3
+^bb2:
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%0) : (i64) -> ()
+  llvm.call @foo(%3) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.call @foo(%2) : (i64) -> ()
+  llvm.br ^bb3
+^bb3:
+  llvm.return
+}

>From 92d385d03bfdaa00dfe434804aedaa23b86f6872 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Tue, 9 Jul 2024 09:50:20 +0100
Subject: [PATCH 5/7] Address review feedbacks

---
 mlir/lib/Transforms/Utils/RegionUtils.cpp | 127 ++++++++++++----------
 1 file changed, 68 insertions(+), 59 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index c336bbc913634..6d672b59da0f8 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -681,60 +681,70 @@ static bool ableToUpdatePredOperands(Block *block) {
 
 /// Prunes the redundant list of arguments. E.g., if we are passing an argument
 /// list like [x, y, z, x] this would return [x, y, z] and it would update the
-/// `block` (to whom the argument are passed to) accordingly
-static void
-pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
-                        RewriterBase &rewriter, Block *block) {
+/// `block` (to which the arguments are passed) accordingly.
+static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
+    const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+    RewriterBase &rewriter, Block *block) {
+
   SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
       newArguments.size(), SmallVector<Value, 8>());
 
-  if (!newArguments.empty()) {
-    llvm::DenseMap<unsigned, unsigned> toReplace;
-    // Go through the first list of arguments (list 0)
-    for (unsigned j = 0; j < newArguments[0].size(); j++) {
-      bool shouldReplaceJ = false;
-      unsigned replacement = 0;
-      // Look back to see if there are possible redundancies in
-      // list 0
-      for (unsigned k = 0; k < j; k++) {
-        if (newArguments[0][k] == newArguments[0][j]) {
-          shouldReplaceJ = true;
-          replacement = k;
-          // If a possible redundancy is found, then scan the other lists: we
-          // can prune the arguments if and only if they are redundant in every
-          // list
-          for (unsigned i = 1; i < newArguments.size(); i++)
-            shouldReplaceJ =
-                shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
-        }
+  if (newArguments.empty())
+    return newArguments;
+
+  // `newArguments` is a 2D array of size `numLists` x `numArgs`
+  unsigned numLists = newArguments.size();
+  unsigned numArgs = newArguments[0].size();
+
+  // Map that for each arg index contains the index that we can use in place of
+  // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
+  // idxToReplacement[3] = 0
+  llvm::DenseMap<unsigned, unsigned> idxToReplacement;
+
+  // Go through the first list of arguments (list 0).
+  for (unsigned j = 0; j < numArgs; ++j) {
+    bool shouldReplaceJ = false;
+    unsigned replacement = 0;
+    // Look back to see if there are possible redundancies in
+    // list 0.
+    for (unsigned k = 0; k < j; k++) {
+      if (newArguments[0][k] == newArguments[0][j]) {
+        shouldReplaceJ = true;
+        replacement = k;
+        // If a possible redundancy is found, then scan the other lists: we
+        // can prune the arguments if and only if they are redundant in every
+        // list.
+        for (unsigned i = 1; i < numLists; ++i)
+          shouldReplaceJ =
+              shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
       }
-      // Save the replacement
-      if (shouldReplaceJ)
-        toReplace[j] = replacement;
     }
+    // Save the replacement.
+    if (shouldReplaceJ)
+      idxToReplacement[j] = replacement;
+  }
 
-    // Populate the pruned argument list
-    for (unsigned i = 0; i < newArguments.size(); i++)
-      for (unsigned j = 0; j < newArguments[i].size(); j++)
-        if (!toReplace.contains(j))
-          newArgumentsPruned[i].push_back(newArguments[i][j]);
-
-    // Replace the block's redundant arguments
-    SmallVector<unsigned> toErase;
-    for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
-      if (toReplace.contains(idx)) {
-        Value oldArg = block->getArgument(idx);
-        Value newArg = block->getArgument(toReplace[idx]);
-        rewriter.replaceAllUsesWith(oldArg, newArg);
-        toErase.push_back(idx);
-      }
+  // Populate the pruned argument list.
+  for (unsigned i = 0; i < numLists; ++i)
+    for (unsigned j = 0; j < numArgs; ++j)
+      if (!idxToReplacement.contains(j))
+        newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+  // Replace the block's redundant arguments.
+  SmallVector<unsigned> toErase;
+  for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+    if (idxToReplacement.contains(idx)) {
+      Value oldArg = block->getArgument(idx);
+      Value newArg = block->getArgument(idxToReplacement[idx]);
+      rewriter.replaceAllUsesWith(oldArg, newArg);
+      toErase.push_back(idx);
     }
-
-    // Erase the block's redundant arguments
-    for (auto idxToErase : llvm::reverse(toErase))
-      block->eraseArgument(idxToErase);
-    newArguments = newArgumentsPruned;
   }
+
+  // Erase the block's redundant arguments.
+  for (unsigned idxToErase : llvm::reverse(toErase))
+    block->eraseArgument(idxToErase);
+  return newArgumentsPruned;
 }
 
 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
@@ -787,7 +797,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
     }
 
     // Prune redundant arguments and update the leader block argument list
-    pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+    newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
 
     // Update the predecessors for each of the blocks.
     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
@@ -889,13 +899,13 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                             Block &block) {
   SmallVector<size_t> argsToErase;
 
-  // Go through the arguments of the block
-  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+  // Go through the arguments of the block.
+  for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
     bool sameArg = true;
     Value commonValue;
 
     // Go through the block predecessor and flag if they pass to the block
-    // different values for the same argument
+    // different values for the same argument.
     for (auto predIt = block.pred_begin(), predE = block.pred_end();
          predIt != predE; ++predIt) {
       auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
@@ -905,32 +915,31 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
       }
       unsigned succIndex = predIt.getSuccessorIndex();
       SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
-      auto operands = succOperands.getForwardedOperands();
+      auto branchOperands = succOperands.getForwardedOperands();
       if (!commonValue) {
-        commonValue = operands[argIdx];
+        commonValue = branchOperands[argIdx];
       } else {
-        if (operands[argIdx] != commonValue) {
+        if (branchOperands[argIdx] != commonValue) {
           sameArg = false;
           break;
         }
       }
     }
 
-    // If they are passing the same value, drop the argument
+    // If they are passing the same value, drop the argument.
     if (commonValue && sameArg) {
       argsToErase.push_back(argIdx);
 
-      // Remove the argument from the block
-      Value argVal = block.getArgument(argIdx);
-      rewriter.replaceAllUsesWith(argVal, commonValue);
+      // Remove the argument from the block.
+      rewriter.replaceAllUsesWith(blockOperand, commonValue);
     }
   }
 
-  // Remove the arguments
+  // Remove the arguments.
   for (auto argIdx : llvm::reverse(argsToErase)) {
     block.eraseArgument(argIdx);
 
-    // Remove the argument from the branch ops
+    // Remove the argument from the branch ops.
     for (auto predIt = block.pred_begin(), predE = block.pred_end();
          predIt != predE; ++predIt) {
       auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());

>From da6c33289736a5b1b8b33c3b35f9066a109e0265 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Wed, 10 Jul 2024 07:29:20 +0100
Subject: [PATCH 6/7] Address review feedbacks - 2

---
 mlir/lib/Transforms/Utils/RegionUtils.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 6d672b59da0f8..16fcfa7807535 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -979,7 +979,7 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
 static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                             MutableArrayRef<Region> regions) {
   llvm::SmallSetVector<Region *, 1> worklist;
-  for (auto &region : regions)
+  for (Region &region : regions)
     worklist.insert(&region);
   bool anyChanged = false;
   while (!worklist.empty()) {
@@ -989,8 +989,8 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
     for (Block &block : *region) {
       anyChanged = succeeded(dropRedundantArguments(rewriter, block));
 
-      for (auto &op : block)
-        for (auto &nestedRegion : op.getRegions())
+      for (Operation &op : block)
+        for (Region &nestedRegion : op.getRegions())
           worklist.insert(&nestedRegion);
     }
   }

>From f3f20a473e874fa7eae71748cecd91f17d4d6875 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 15 Jul 2024 11:23:17 +0100
Subject: [PATCH 7/7] Use a O(N) algorithm with a conservative tradeoff

---
 mlir/lib/Transforms/Utils/RegionUtils.cpp | 43 ++++++++++++++++-------
 1 file changed, 30 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 16fcfa7807535..946d65cef4186 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -701,23 +701,40 @@ static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
   // idxToReplacement[3] = 0
   llvm::DenseMap<unsigned, unsigned> idxToReplacement;
 
+  // Map each Value to the index of its first appearance in the first list of
+  // arguments (list 0).
+  DenseMap<Value, unsigned> firstValueToIdx;
+  for (unsigned j = 0; j < numArgs; ++j) {
+    Value newArg = newArguments[0][j];
+    if (!firstValueToIdx.contains(newArg))
+      firstValueToIdx[newArg] = j;
+  }
+
   // Go through the first list of arguments (list 0).
   for (unsigned j = 0; j < numArgs; ++j) {
     bool shouldReplaceJ = false;
     unsigned replacement = 0;
-    // Look back to see if there are possible redundancies in
-    // list 0.
-    for (unsigned k = 0; k < j; k++) {
-      if (newArguments[0][k] == newArguments[0][j]) {
-        shouldReplaceJ = true;
-        replacement = k;
-        // If a possible redundancy is found, then scan the other lists: we
-        // can prune the arguments if and only if they are redundant in every
-        // list.
-        for (unsigned i = 1; i < numLists; ++i)
-          shouldReplaceJ =
-              shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
-      }
+    // Look back to see if there are possible redundancies in list 0. Please
+    // note that we are using a map to annotate when an argument was seen first
+    // to avoid an O(N^2) algorithm. This has the drawback that if we have two
+    // lists like:
+    // list0: [%a, %a, %a]
+    // list1: [%c, %b, %b]
+    // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
+    // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
+    // the number of arguments can be potentially unbounded, we cannot afford
+    // an O(N^2) algorithm (to search all the possible pairs) and we need to
+    // accept the trade-off.
+    unsigned k = firstValueToIdx[newArguments[0][j]];
+    if (k != j) {
+      shouldReplaceJ = true;
+      replacement = k;
+      // If a possible redundancy is found, then scan the other lists: we
+      // can prune the arguments if and only if they are redundant in every
+      // list.
+      for (unsigned i = 1; i < numLists; ++i)
+        shouldReplaceJ =
+            shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
     }
     // Save the replacement.
     if (shouldReplaceJ)



More information about the Mlir-commits mailing list