[Mlir-commits] [mlir] [mlir] Fix block merging (PR #102038)
Ivan Butygin
llvmlistbot at llvm.org
Sun Aug 11 15:03:29 PDT 2024
================
@@ -818,6 +917,111 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}
+/// If every predecessor of `block` forwards the same value for a block argument,
+/// replace the argument's uses with that value and erase it. Succeeds iff any argument was erased.
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ Block &block) {
+ SmallVector<size_t> argsToErase;
+
+ // Check each block argument; argIdx is its position in block.getArguments().
+ for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
+ bool sameArg = true;
+ Value commonValue;
+ // NOTE(review): argIdx indexes getForwardedOperands() directly; confirm no op-produced operands precede them.
+ // Visit each predecessor edge. A terminator that is not a BranchOpInterface, or a
+ // forwarded operand at argIdx that differs from the first one seen, clears sameArg.
+ for (Block::pred_iterator predIt = block.pred_begin(),
+ predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+ if (!branch) {
+ sameArg = false;
+ break;
+ }
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ auto branchOperands = succOperands.getForwardedOperands();
+ if (!commonValue) {
+ commonValue = branchOperands[argIdx];
+ continue;
+ }
+ if (branchOperands[argIdx] != commonValue) {
+ sameArg = false;
+ break;
+ }
+ }
+
+ // All visited predecessors agreed on commonValue (it is still null if there were no predecessors).
+ if (commonValue && sameArg) {
+ argsToErase.push_back(argIdx);
+
+ // Redirect the argument's uses to the common value; the argument itself is erased below.
+ rewriter.replaceAllUsesWith(blockOperand, commonValue);
+ }
+ }
+
+ // Erase the recorded arguments from the highest index down so the remaining indices stay valid.
+ for (size_t argIdx : llvm::reverse(argsToErase)) {
+ block.eraseArgument(argIdx);
+
+ // Drop the matching operand from each predecessor's branch (all passed the dyn_cast above).
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ succOperands.erase(argIdx);
+ }
+ }
+ return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant argument to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block
+/// predecessors, we can remove the argument from the block and use directly the
+/// original value. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+/// llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
+/// ^bb1(%arg0 : i64):
+/// llvm.call @foo(%val0, %arg0)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ MutableArrayRef<Region> regions) {
+ llvm::SmallSetVector<Region *, 1> worklist;
+ for (Region &region : regions)
+ worklist.insert(&region);
+ bool anyChanged = false;
+ while (!worklist.empty()) {
+ Region *region = worklist.pop_back_val();
+
+ // Add any nested regions to the worklist.
+ for (Block &block : *region) {
+ anyChanged = succeeded(dropRedundantArguments(rewriter, block));
----------------
Hardcode84 wrote:
`anyChanged` is completely overwritten on each iteration, so a change made in an earlier block is lost whenever a later block is left unchanged.
It should be `anyChanged = succeeded(dropRedundantArguments(rewriter, block)) || anyChanged`.
https://github.com/llvm/llvm-project/pull/102038
More information about the Mlir-commits
mailing list