[Mlir-commits] [mlir] [mlir] Fix block merging (PR #102038)
Giuseppe Rossini
llvmlistbot at llvm.org
Mon Aug 5 11:59:19 PDT 2024
https://github.com/giuseros created https://github.com/llvm/llvm-project/pull/102038
With this PR I am trying to address: https://github.com/llvm/llvm-project/issues/63230.
What changed:
- While merging identical blocks, don't add a block argument if it is "identical" to another block argument, i.e., if the two block arguments refer to the same `Value`. The operands of the operations in the block will instead point to the argument we already inserted. This needs to hold for all the arguments we pass to the different successors of the parent block.
- After merging the blocks, get rid of "unnecessary" arguments, i.e., if all the predecessors pass the same value for a block argument, there is no need to pass it as an argument.
- This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that `BufferDeallocationSimplification` relies on an analysis (`BufferOriginAnalysis`) that is based on the block structure: if we simplify the block structure (by merging blocks and/or dropping block arguments) while the pass runs, the analysis becomes invalid. The solution I found is to use a more conservative (`Normal`) level of region simplification when running that pass.
**Note-1**: I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.
**Note-2**: I fixed a bug found by @Dinistro in #97697. The issue was that, when looking for redundant arguments, I was not considering that the block might already have some arguments. So the index (in the block argument list) of the i-th `newArgument` is `i + numOfOldArguments`.
>From 982f92820ad10702c4236c862da8f3b817f20409 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 24 Jun 2024 09:36:32 +0100
Subject: [PATCH 1/8] Fix block merging
---
.../BufferDeallocationSimplification.cpp | 9 +-
mlir/lib/Transforms/Utils/RegionUtils.cpp | 139 ++++++++++++++++--
.../Linalg/detensorize_entry_block.mlir | 6 +-
3 files changed, 137 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
+ // We don't want that the block structure changes invalidating the
+ // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+ // region simplification
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
- if (failed(
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..31c4bfe1f6b2f 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include <deque>
+#include <iterator>
using namespace mlir;
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());
// Update each of the predecessor terminators with the new arguments.
- SmallVector<SmallVector<Value, 8>, 2> newArguments(
- 1 + blocksToMerge.size(),
- SmallVector<Value, 8>(operandsToMerge.size()));
+ SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+ SmallVector<Value, 8>());
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
- newArguments[i][it.index()] = operand.get();
-
- // Update the operand and insert an argument if this is the leader.
- if (i == 0) {
- Value operandVal = operand.get();
- operand.set(leaderBlock->addArgument(operandVal.getType(),
- operandVal.getLoc()));
+ Value operandVal = operand.get();
+ Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
+ operandVal);
+ if (it == newArguments[i].end()) {
+ newArguments[i].push_back(operandVal);
+ // Update the operand and insert an argument if this is the leader.
+ if (i == 0) {
+ operand.set(leaderBlock->addArgument(operandVal.getType(),
+ operandVal.getLoc()));
+ }
+ } else if (i == 0) {
+ // If this is the leader, update the operand but do not insert a new
+ // argument. Instead, the opearand should point to one of the
+ // arguments we already passed (and that contained `operandVal`)
+ operand.set(leaderBlock->getArgument(
+ std::distance(newArguments[i].begin(), it)));
}
}
}
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}
+/// Drops the arguments of `block` that receive the same value from every
+/// predecessor. Each such argument has its uses replaced by that common value
+/// and is then erased from `block` and from the successor operands of every
+/// predecessor terminator. Returns success if at least one argument was
+/// dropped.
+///
+/// An argument is only considered when every predecessor terminator
+/// implements `BranchOpInterface` and explicitly forwards a value for it;
+/// operands produced by the terminator itself are left alone.
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  // Go through the arguments of the block.
+  for (size_t argIdx = 0, e = block.getNumArguments(); argIdx < e; ++argIdx) {
+    BlockArgument blockArg = block.getArgument(argIdx);
+    bool sameArg = true;
+    Value commonValue;
+
+    // Go through the block predecessors and flag if they pass different
+    // values for the same argument.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      // Index by successor-argument number, the same index space used by
+      // `SuccessorOperands::erase` below. `operator[]` skips the operands
+      // produced by the terminator and returns a null Value for them, while
+      // indexing `getForwardedOperands()` directly would be off by the number
+      // of produced operands.
+      Value passed = succOperands[argIdx];
+      if (!passed || (commonValue && passed != commonValue)) {
+        sameArg = false;
+        break;
+      }
+      commonValue = passed;
+    }
+
+    // A value defined in `block` itself (including one of its own arguments)
+    // can only reach every predecessor if the block is unreachable; replacing
+    // the argument with it would redirect uses onto something being erased or
+    // create a self-reference, so skip that case.
+    if (commonValue && sameArg && commonValue.getParentBlock() != &block) {
+      argsToErase.push_back(argIdx);
+      rewriter.replaceAllUsesWith(blockArg, commonValue);
+    }
+  }
+
+  // Remove the arguments back to front so earlier indices stay valid.
+  for (size_t argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+
+    // Remove the corresponding operand from every predecessor branch.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant block arguments. I.e., if a given
+/// argument of a block receives the same value from each of the block's
+/// predecessors, we can remove the argument from the block and use the
+/// original value directly. For example:
+///
+///   %cond = llvm.call @rand() : () -> i1
+///   %val0 = llvm.mlir.constant(1 : i64) : i64
+///   %val1 = llvm.mlir.constant(2 : i64) : i64
+///   %val2 = llvm.mlir.constant(3 : i64) : i64
+///   llvm.cond_br %cond, ^bb1(%val0, %val1 : i64, i64),
+///                       ^bb1(%val0, %val2 : i64, i64)
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+///   llvm.call @foo(%arg0, %arg1) : (i64, i64) -> ()
+///
+/// can be rewritten as (constant definitions unchanged):
+///
+///   llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb1(%val2 : i64)
+/// ^bb1(%arg1 : i64):
+///   llvm.call @foo(%val0, %arg1) : (i64, i64) -> ()
+///
+/// Every block of `regions` is processed, as well as the blocks of all
+/// regions nested in their operations. Returns success if any block argument
+/// was dropped anywhere.
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (Region &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+
+    for (Block &block : *region) {
+      // Accumulate rather than overwrite: an unchanged block processed later
+      // must not hide a change made to an earlier one.
+      if (succeeded(dropRedundantArguments(rewriter, block)))
+        anyChanged = true;
+
+      // Add any nested regions to the worklist.
+      for (Operation &op : block)
+        for (Region &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
- if (mergeBlocks)
+ bool droppedRedundantArguments = false;
+ if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+ droppedRedundantArguments =
+ succeeded(dropRedundantArguments(rewriter, regions));
+ }
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
- mergedIdenticalBlocks);
+ mergedIdenticalBlocks || droppedRedundantArguments);
}
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>
>From 040622bae47bf54e08d7c87efc30a7d00e6e6826 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Fri, 28 Jun 2024 09:25:13 +0100
Subject: [PATCH 2/8] Add test case
---
.../Transforms/canonicalize-block-merge.mlir | 6 +-
.../test-canonicalize-merge-large-blocks.mlir | 76 +++++++++++++++++++
2 files changed, 79 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index 3b8b1fce0575a..92cfde817cf7f 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -87,7 +87,7 @@ func.func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 :
// CHECK-LABEL: func @mismatch_argument_uses(
func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
- // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+ // CHECK: return {{.*}}, {{.*}}
cf.cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
@@ -101,7 +101,7 @@ func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32,
// CHECK-LABEL: func @mismatch_argument_types(
func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
- // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16)
@@ -115,7 +115,7 @@ func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
// CHECK-LABEL: func @mismatch_argument_count(
func.func @mismatch_argument_count(%cond : i1, %arg0 : i32) {
- // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
new file mode 100644
index 0000000000000..570ff6905a04d
--- /dev/null
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -0,0 +1,76 @@
+ // RUN: mlir-opt -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' %s | FileCheck %s
+
+llvm.func @foo(%arg0: i64)
+
+llvm.func @rand() -> i1
+
+// Two structurally identical halves (^bb1..^bb13 and ^bb14..^bb26) call @foo
+// with the same four constants in a different order. With aggressive region
+// simplification, the checks below expect both halves to be merged into one
+// chain of blocks that receives the constants as block arguments.
+// CHECK-LABEL: func @large_merge_block(
+llvm.func @large_merge_block(%arg0: i64) {
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i64) : i64
+
+ // CHECK: llvm.cond_br %5, ^bb1(%[[C1]], %[[C3]], %[[C4]], %[[C2]] : i64, i64, i64, i64), ^bb1(%[[C4]], %[[C2]], %[[C1]], %[[C3]] : i64, i64, i64, i64)
+ // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64, %[[arg3:.*]]: i64):
+ // CHECK: llvm.cond_br %{{.*}}, ^bb2(%[[arg0]] : i64), ^bb2(%[[arg3]] : i64)
+ // CHECK: ^bb{{.*}}(%11: i64):
+ // CHECK: llvm.br ^bb{{.*}}
+ // CHECK: ^bb{{.*}}:
+ // CHECK: llvm.call
+ // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}(%[[arg1]] : i64), ^bb{{.*}}(%[[arg2]] : i64)
+ // CHECK: ^bb{{.*}}:
+ // CHECK: llvm.call
+ // NOTE(review): the `llvm.br` line below lacks the colon after its prefix,
+ // so it is not a FileCheck directive and checks nothing. Also, `%5` and
+ // `%11` above are hard-coded SSA names that break if value numbering
+ // changes; consider `%{{.*}}` or named captures instead.
+ // CHECK llvm.br ^bb{{.*}}
+
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %2 = llvm.mlir.constant(2 : i64) : i64
+ %3 = llvm.mlir.constant(3 : i64) : i64
+ %4 = llvm.mlir.constant(4 : i64) : i64
+ %10 = llvm.icmp "eq" %arg0, %0 : i64
+ llvm.cond_br %10, ^bb1, ^bb14
+^bb1: // pred: ^bb0
+ %11 = llvm.call @rand() : () -> i1
+ llvm.cond_br %11, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+ llvm.call @foo(%1) : (i64) -> ()
+ llvm.br ^bb4
+^bb3: // pred: ^bb1
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.br ^bb4
+^bb4: // 2 preds: ^bb2, ^bb3
+ %14 = llvm.call @rand() : () -> i1
+ llvm.cond_br %14, ^bb5, ^bb6
+^bb5: // pred: ^bb4
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.br ^bb13
+^bb6: // pred: ^bb4
+ llvm.call @foo(%4) : (i64) -> ()
+ llvm.br ^bb13
+^bb13: // 2 preds: ^bb5, ^bb6
+ llvm.br ^bb27
+^bb14: // pred: ^bb0
+ %23 = llvm.call @rand() : () -> i1
+ llvm.cond_br %23, ^bb15, ^bb16
+^bb15: // pred: ^bb14
+ llvm.call @foo(%4) : (i64) -> ()
+ llvm.br ^bb17
+^bb16: // pred: ^bb14
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.br ^bb17
+^bb17: // 2 preds: ^bb15, ^bb16
+ %26 = llvm.call @rand() : () -> i1
+ llvm.cond_br %26, ^bb18, ^bb19
+^bb18: // pred: ^bb17
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.br ^bb26
+^bb19: // pred: ^bb17
+ llvm.call @foo(%1) : (i64) -> ()
+ llvm.br ^bb26
+^bb26: // 2 preds: ^bb18, ^bb19
+ llvm.br ^bb27
+^bb27: // 2 preds: ^bb13, ^bb26
+ llvm.return
+}
>From 00cfd1d66a5b9059e375349977b374f81d51f7f2 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Sat, 29 Jun 2024 07:51:45 +0100
Subject: [PATCH 3/8] Correcting remaining tests
---
.../dealloc-branchop-interface.mlir | 20 +++---
mlir/test/Dialect/Linalg/detensorize_if.mlir | 67 ++++++++-----------
.../Dialect/Linalg/detensorize_while.mlir | 12 ++--
.../Linalg/detensorize_while_impure_cf.mlir | 12 ++--
.../Linalg/detensorize_while_pure_cf.mlir | 4 +-
mlir/test/Transforms/canonicalize-dce.mlir | 8 +--
.../Transforms/make-isolated-from-above.mlir | 18 ++---
7 files changed, 68 insertions(+), 73 deletions(-)
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d..8e14990502143 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: ^bb1
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
// CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
// CHECK-NEXT: ^bb3:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb4:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: test.copy
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c..c728ad21d2209 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: cf.br ^[[bb4:.*]]
+// CHECK-NEXT: ^[[bb4]]:
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a33..580a97d3a851b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb1]](%{{.*}}: i32)
// DET-ALL: arith.cmpi slt, {{.*}}
-// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL: ^[[bb2]]
// DET-ALL: arith.addi {{.*}}
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: ^[[bb3]]:
// DET-ALL: tensor.from_elements {{.*}}
// DET-ALL: return %{{.*}} : tensor<i32>
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-CF: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb1]](%{{.*}}: i32)
// DET-CF: arith.cmpi slt, {{.*}}
-// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF: ^[[bb2]](%{{.*}}: i32)
+// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF: ^[[bb2]]:
// DET-CF: arith.addi {{.*}}
// DET-CF: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF: ^[[bb3]](%{{.*}}: i32)
+// DET-CF: ^[[bb3]]:
// DET-CF: tensor.from_elements %{{.*}} : tensor<i32>
// DET-CF: return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be5ef4c8..414d9b94cbf53 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
// DET-ALL: } -> tensor<i32>
// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
// DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-ALL: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL: ^[[bb2]]:
// DET-ALL: tensor.from_elements %{{.*}} : tensor<i32>
// DET-ALL: tensor.empty() : tensor<10xi32>
// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
@@ -83,7 +83,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
// DET-ALL: linalg.yield %{{.*}} : i32
// DET-ALL: } -> tensor<10xi32>
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
-// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: ^[[bb3]]
// DET-ALL: tensor.from_elements %{{.*}} : tensor<i32>
// DET-ALL: return %{{.*}} : tensor<i32>
// DET-ALL: }
@@ -95,10 +95,10 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-CF: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
-// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF: cf.cond_br %{{.*}}, ^bb2, ^bb3
+// DET-CF: ^bb2:
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
// DET-CF: cf.br ^bb1(%{{.*}} : tensor<10xi32>)
-// DET-CF: ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF: ^bb3:
// DET-CF: return %{{.*}} : tensor<i32>
// DET-CF: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 6d8d5fe71fca5..913e78272db79 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -49,8 +49,8 @@ func.func @main() -> () attributes {} {
// CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32)
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb2]]
// CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.br ^[[bb1]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb3]]:
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index ac034d567a26a..84631947970de 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -137,10 +137,10 @@ func.func @f(%arg0: f32) {
// Test case: Test the mechanics of deleting multiple block arguments.
// CHECK: func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>, %arg3: tensor<4xf32>, %arg4: tensor<5xf32>)
-// CHECK-NEXT: "test.br"(%arg1, %arg3)[^bb1] : (tensor<2xf32>, tensor<4xf32>)
-// CHECK-NEXT: ^bb1([[VAL0:%.+]]: tensor<2xf32>, [[VAL1:%.+]]: tensor<4xf32>):
-// CHECK-NEXT: "foo.print"([[VAL0]])
-// CHECK-NEXT: "foo.print"([[VAL1]])
+// CHECK-NEXT: "test.br"()[^bb1]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: "foo.print"(%arg1)
+// CHECK-NEXT: "foo.print"(%arg3)
// CHECK-NEXT: return
diff --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir
index 58f6cfbc5dd65..a9d4325944fd9 100644
--- a/mlir/test/Transforms/make-isolated-from-above.mlir
+++ b/mlir/test/Transforms/make-isolated-from-above.mlir
@@ -78,9 +78,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
// CHECK: test.isolated_one_region_op %[[ARG2]], %[[C0]], %[[C1]], %[[D0]], %[[D1]]
// CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index, %[[B4:[a-zA-Z0-9]+]]: index)
-// CHECK-NEXT: cf.br ^bb1(%[[B0]] : index)
-// CHECK: ^bb1(%[[B5:.+]]: index)
-// CHECK: "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]])
+// CHECK-NEXT: cf.br ^bb1
+// CHECK: ^bb1:
+// CHECK: "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B0]])
// CLONE1-LABEL: func @make_isolated_from_above_multiple_blocks(
// CLONE1-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -95,9 +95,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
// CLONE1-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
// CLONE1-DAG: %[[C0_0:.+]] = arith.constant 0 : index
// CLONE1-DAG: %[[C1_0:.+]] = arith.constant 1 : index
-// CLONE1-NEXT: cf.br ^bb1(%[[B0]] : index)
-// CLONE1: ^bb1(%[[B3:.+]]: index)
-// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B3]])
+// CLONE1-NEXT: cf.br ^bb1
+// CLONE1: ^bb1:
+// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B0]])
// CLONE2-LABEL: func @make_isolated_from_above_multiple_blocks(
// CLONE2-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -110,6 +110,6 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
// CLONE2-DAG: %[[EMPTY:.+]] = tensor.empty(%[[B1]], %[[B2]])
// CLONE2-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
// CLONE2-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
-// CLONE2-NEXT: cf.br ^bb1(%[[B0]] : index)
-// CLONE2: ^bb1(%[[B3:.+]]: index)
-// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B3]])
+// CLONE2-NEXT: cf.br ^bb1
+// CLONE2: ^bb1:
+// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B0]])
>From 12755f7d18cffcee4375d4efc3603f691f4cfb40 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Sun, 30 Jun 2024 06:49:11 +0100
Subject: [PATCH 4/8] Fix integration tests
---
mlir/lib/Transforms/Utils/RegionUtils.cpp | 103 ++++++++++++++----
.../test-canonicalize-merge-large-blocks.mlir | 86 +++++++++++++++
2 files changed, 167 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 31c4bfe1f6b2f..c336bbc913634 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -679,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
return true;
}
+/// Prunes redundant entries from the argument lists passed to `block`.
+///
+/// `newArguments[i]` holds the values the i-th merged block passes for the
+/// arguments that were just appended to `block`: all lists have the same
+/// length and `block` ends with exactly `newArguments[0].size()` such
+/// arguments (it may have had other arguments before them). Entry `j` is
+/// redundant if some earlier entry `k < j` holds the same value as `j` in
+/// *every* list; e.g. the lists [x, y, z, x] and [a, b, c, a] are pruned to
+/// [x, y, z] and [a, b, c]. Uses of each redundant block argument are
+/// redirected to the block argument it duplicates, the redundant arguments
+/// are erased from `block`, and `newArguments` is updated in place.
+static void
+pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+                        RewriterBase &rewriter, Block *block) {
+  if (newArguments.empty())
+    return;
+
+  unsigned numNewArgs = newArguments[0].size();
+  assert(block->getNumArguments() >= numNewArgs &&
+         "block must end with the newly added arguments");
+  // Entry `j` of the lists corresponds to block argument `numOldArgs + j`,
+  // because the block may already have had arguments before the new ones.
+  unsigned numOldArgs = block->getNumArguments() - numNewArgs;
+
+  // Map each redundant entry to the *earliest* entry that duplicates it in
+  // every list. The earliest match is never redundant itself (anything it
+  // duplicated would also duplicate `j` and come earlier), so no use is ever
+  // redirected onto an argument that is about to be erased.
+  llvm::DenseMap<unsigned, unsigned> toReplace;
+  for (unsigned j = 0; j < numNewArgs; ++j) {
+    for (unsigned k = 0; k < j; ++k) {
+      bool redundantEverywhere = llvm::all_of(
+          newArguments, [&](const SmallVector<Value, 8> &args) {
+            return args[k] == args[j];
+          });
+      if (redundantEverywhere) {
+        toReplace[j] = k;
+        break;
+      }
+    }
+  }
+
+  if (toReplace.empty())
+    return;
+
+  // Build the pruned argument lists.
+  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(newArguments.size());
+  for (auto [i, args] : llvm::enumerate(newArguments))
+    for (unsigned j = 0; j < numNewArgs; ++j)
+      if (!toReplace.contains(j))
+        newArgumentsPruned[i].push_back(args[j]);
+
+  // Redirect the uses of every redundant block argument, then erase those
+  // arguments back to front so earlier indices stay valid.
+  SmallVector<unsigned> toErase;
+  for (unsigned j = 0; j < numNewArgs; ++j) {
+    auto it = toReplace.find(j);
+    if (it == toReplace.end())
+      continue;
+    rewriter.replaceAllUsesWith(block->getArgument(numOldArgs + j),
+                                block->getArgument(numOldArgs + it->second));
+    toErase.push_back(numOldArgs + j);
+  }
+  for (unsigned idx : llvm::reverse(toErase))
+    block->eraseArgument(idx);
+
+  newArguments = std::move(newArgumentsPruned);
+}
+
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
// Don't consider clusters that don't have blocks to merge.
if (blocksToMerge.empty())
@@ -704,8 +762,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());
// Update each of the predecessor terminators with the new arguments.
- SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
- SmallVector<Value, 8>());
+ SmallVector<SmallVector<Value, 8>, 2> newArguments(
+ 1 + blocksToMerge.size(),
+ SmallVector<Value, 8>(operandsToMerge.size()));
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -716,25 +775,20 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
- Value operandVal = operand.get();
- Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
- operandVal);
- if (it == newArguments[i].end()) {
- newArguments[i].push_back(operandVal);
- // Update the operand and insert an argument if this is the leader.
- if (i == 0) {
- operand.set(leaderBlock->addArgument(operandVal.getType(),
- operandVal.getLoc()));
- }
- } else if (i == 0) {
- // If this is the leader, update the operand but do not insert a new
- // argument. Instead, the opearand should point to one of the
- // arguments we already passed (and that contained `operandVal`)
- operand.set(leaderBlock->getArgument(
- std::distance(newArguments[i].begin(), it)));
+ newArguments[i][it.index()] = operand.get();
+
+ // Update the operand and insert an argument if this is the leader.
+ if (i == 0) {
+ Value operandVal = operand.get();
+ operand.set(leaderBlock->addArgument(operandVal.getType(),
+ operandVal.getLoc()));
}
}
}
+
+ // Prune redundant arguments and update the leader block argument list
+ pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+
// Update the predecessors for each of the blocks.
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -896,17 +950,22 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
/// %cond = llvm.call @rand() : () -> i1
/// %val0 = llvm.mlir.constant(1 : i64) : i64
/// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
-/// : i64) ^bb1(%arg0 : i64, %arg1 : i64):
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
/// llvm.call @foo(%arg0, %arg1)
///
/// The previous IR can be rewritten as:
/// %cond = llvm.call @rand() : () -> i1
-/// %val = llvm.mlir.constant(1 : i64) : i64
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
/// ^bb1(%arg0 : i64):
-/// llvm.call @foo(%val0, %arg1)
+/// llvm.call @foo(%val0, %arg0)
///
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
index 570ff6905a04d..e821dcd0c2064 100644
--- a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -74,3 +74,89 @@ llvm.func @large_merge_block(%arg0: i64) {
^bb27: // 2 preds: ^bb13, ^bb26
llvm.return
}
+
+llvm.func @redundant_args0(%cond : i1) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.mlir.constant(2 : i64) : i64
+ // CHECK %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+
+ llvm.cond_br %cond, ^bb1, ^bb2
+
+ // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64), ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64)
+ // CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
+^bb1:
+ llvm.call @foo(%0) : (i64) -> ()
+ llvm.call @foo(%0) : (i64) -> ()
+ llvm.br ^bb3
+^bb2:
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.br ^bb3
+^bb3:
+ llvm.return
+}
+
+llvm.func @redundant_args1(%cond : i1) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.mlir.constant(2 : i64) : i64
+ // CHECK %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+
+ llvm.cond_br %cond, ^bb1, ^bb2
+
+ // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64), ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64)
+ // CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
+^bb1:
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.br ^bb3
+^bb2:
+ llvm.call @foo(%0) : (i64) -> ()
+ llvm.call @foo(%0) : (i64) -> ()
+ llvm.br ^bb3
+^bb3:
+ llvm.return
+}
+
+llvm.func @redundant_args_complex(%cond : i1) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %2 = llvm.mlir.constant(2 : i64) : i64
+ %3 = llvm.mlir.constant(3 : i64) : i64
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+
+ llvm.cond_br %cond, ^bb1, ^bb2
+
+ // CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C2]], %[[C1]], %[[C3]] : i64, i64, i64), ^bb{{.*}}(%[[C0]], %[[C3]], %[[C2]] : i64, i64, i64)
+ // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64):
+ // CHECK: llvm.call @foo(%[[arg0]])
+ // CHECK: llvm.call @foo(%[[arg0]])
+ // CHECK: llvm.call @foo(%[[arg1]])
+ // CHECK: llvm.call @foo(%[[C2]])
+ // CHECK: llvm.call @foo(%[[arg2]])
+
+^bb1:
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.call @foo(%1) : (i64) -> ()
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.br ^bb3
+^bb2:
+ llvm.call @foo(%0) : (i64) -> ()
+ llvm.call @foo(%0) : (i64) -> ()
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.br ^bb3
+^bb3:
+ llvm.return
+}
>From ef733362770af4af8b386c1ebbc3fca030fec27a Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Tue, 9 Jul 2024 09:50:20 +0100
Subject: [PATCH 5/8] Address review feedbacks
---
mlir/lib/Transforms/Utils/RegionUtils.cpp | 127 ++++++++++++----------
1 file changed, 68 insertions(+), 59 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index c336bbc913634..6d672b59da0f8 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -681,60 +681,70 @@ static bool ableToUpdatePredOperands(Block *block) {
/// Prunes the redundant list of arguments. E.g., if we are passing an argument
/// list like [x, y, z, x] this would return [x, y, z] and it would update the
-/// `block` (to whom the argument are passed to) accordingly
-static void
-pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
- RewriterBase &rewriter, Block *block) {
+/// `block` (to whom the argument are passed to) accordingly.
+static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
+ const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+ RewriterBase &rewriter, Block *block) {
+
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
newArguments.size(), SmallVector<Value, 8>());
- if (!newArguments.empty()) {
- llvm::DenseMap<unsigned, unsigned> toReplace;
- // Go through the first list of arguments (list 0)
- for (unsigned j = 0; j < newArguments[0].size(); j++) {
- bool shouldReplaceJ = false;
- unsigned replacement = 0;
- // Look back to see if there are possible redundancies in
- // list 0
- for (unsigned k = 0; k < j; k++) {
- if (newArguments[0][k] == newArguments[0][j]) {
- shouldReplaceJ = true;
- replacement = k;
- // If a possible redundancy is found, then scan the other lists: we
- // can prune the arguments if and only if they are redundant in every
- // list
- for (unsigned i = 1; i < newArguments.size(); i++)
- shouldReplaceJ =
- shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
- }
+ if (newArguments.empty())
+ return newArguments;
+
+ // `newArguments` is a 2D array of size `numLists` x `numArgs`.
+ unsigned numLists = newArguments.size();
+ unsigned numArgs = newArguments[0].size();
+
+ // Map that, for each arg index, holds the index that can be used in place of
+ // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
+ // idxToReplacement[3] = 0
+ llvm::DenseMap<unsigned, unsigned> idxToReplacement;
+
+ // Go through the first list of arguments (list 0).
+ for (unsigned j = 0; j < numArgs; ++j) {
+ bool shouldReplaceJ = false;
+ unsigned replacement = 0;
+ // Look back to see if there are possible redundancies in
+ // list 0.
+ for (unsigned k = 0; k < j; k++) {
+ if (newArguments[0][k] == newArguments[0][j]) {
+ shouldReplaceJ = true;
+ replacement = k;
+ // If a possible redundancy is found, then scan the other lists: we
+ // can prune the arguments if and only if they are redundant in every
+ // list.
+ for (unsigned i = 1; i < numLists; ++i)
+ shouldReplaceJ =
+ shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
}
- // Save the replacement
- if (shouldReplaceJ)
- toReplace[j] = replacement;
}
+ // Save the replacement.
+ if (shouldReplaceJ)
+ idxToReplacement[j] = replacement;
+ }
- // Populate the pruned argument list
- for (unsigned i = 0; i < newArguments.size(); i++)
- for (unsigned j = 0; j < newArguments[i].size(); j++)
- if (!toReplace.contains(j))
- newArgumentsPruned[i].push_back(newArguments[i][j]);
-
- // Replace the block's redundant arguments
- SmallVector<unsigned> toErase;
- for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
- if (toReplace.contains(idx)) {
- Value oldArg = block->getArgument(idx);
- Value newArg = block->getArgument(toReplace[idx]);
- rewriter.replaceAllUsesWith(oldArg, newArg);
- toErase.push_back(idx);
- }
+ // Populate the pruned argument list.
+ for (unsigned i = 0; i < numLists; ++i)
+ for (unsigned j = 0; j < numArgs; ++j)
+ if (!idxToReplacement.contains(j))
+ newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+ // Replace the block's redundant arguments.
+ SmallVector<unsigned> toErase;
+ for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+ if (idxToReplacement.contains(idx)) {
+ Value oldArg = block->getArgument(idx);
+ Value newArg = block->getArgument(idxToReplacement[idx]);
+ rewriter.replaceAllUsesWith(oldArg, newArg);
+ toErase.push_back(idx);
}
-
- // Erase the block's redundant arguments
- for (auto idxToErase : llvm::reverse(toErase))
- block->eraseArgument(idxToErase);
- newArguments = newArgumentsPruned;
}
+
+ // Erase the block's redundant arguments.
+ for (unsigned idxToErase : llvm::reverse(toErase))
+ block->eraseArgument(idxToErase);
+ return newArgumentsPruned;
}
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
@@ -787,7 +797,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
}
// Prune redundant arguments and update the leader block argument list
- pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+ newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
// Update the predecessors for each of the blocks.
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
@@ -889,13 +899,13 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
Block &block) {
SmallVector<size_t> argsToErase;
- // Go through the arguments of the block
- for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+ // Go through the arguments of the block.
+ for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
bool sameArg = true;
Value commonValue;
// Go through the block predecessor and flag if they pass to the block
- // different values for the same argument
+ // different values for the same argument.
for (auto predIt = block.pred_begin(), predE = block.pred_end();
predIt != predE; ++predIt) {
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
@@ -905,32 +915,31 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
}
unsigned succIndex = predIt.getSuccessorIndex();
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
- auto operands = succOperands.getForwardedOperands();
+ auto branchOperands = succOperands.getForwardedOperands();
if (!commonValue) {
- commonValue = operands[argIdx];
+ commonValue = branchOperands[argIdx];
} else {
- if (operands[argIdx] != commonValue) {
+ if (branchOperands[argIdx] != commonValue) {
sameArg = false;
break;
}
}
}
- // If they are passing the same value, drop the argument
+ // If they are passing the same value, drop the argument.
if (commonValue && sameArg) {
argsToErase.push_back(argIdx);
- // Remove the argument from the block
- Value argVal = block.getArgument(argIdx);
- rewriter.replaceAllUsesWith(argVal, commonValue);
+ // Redirect all uses of the argument to the common value.
+ rewriter.replaceAllUsesWith(blockOperand, commonValue);
}
}
- // Remove the arguments
+ // Remove the arguments.
for (auto argIdx : llvm::reverse(argsToErase)) {
block.eraseArgument(argIdx);
- // Remove the argument from the branch ops
+ // Remove the argument from the branch ops.
for (auto predIt = block.pred_begin(), predE = block.pred_end();
predIt != predE; ++predIt) {
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
>From d6513e1b11627eb5ee1e50e2939aeee68fd8f684 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Wed, 10 Jul 2024 07:29:20 +0100
Subject: [PATCH 6/8] Address review feedbacks - 2
---
mlir/lib/Transforms/Utils/RegionUtils.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 6d672b59da0f8..16fcfa7807535 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -979,7 +979,7 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
llvm::SmallSetVector<Region *, 1> worklist;
- for (auto &region : regions)
+ for (Region &region : regions)
worklist.insert(&region);
bool anyChanged = false;
while (!worklist.empty()) {
@@ -989,8 +989,8 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
for (Block &block : *region) {
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
- for (auto &op : block)
- for (auto &nestedRegion : op.getRegions())
+ for (Operation &op : block)
+ for (Region &nestedRegion : op.getRegions())
worklist.insert(&nestedRegion);
}
}
>From 746f507beba9f8ee0a92811706e42072ca98c0bf Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 15 Jul 2024 11:23:17 +0100
Subject: [PATCH 7/8] Use a O(N) algorithm with a conservative tradeoff
---
mlir/lib/Transforms/Utils/RegionUtils.cpp | 43 ++++++++++++++++-------
1 file changed, 30 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 16fcfa7807535..946d65cef4186 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -701,23 +701,40 @@ static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
// idxToReplacement[3] = 0
llvm::DenseMap<unsigned, unsigned> idxToReplacement;
+ // Map each Value to the index of its first occurrence in list 0
+ // (newArguments[0]).
+ DenseMap<Value, unsigned> firstValueToIdx;
+ for (unsigned j = 0; j < numArgs; ++j) {
+ Value newArg = newArguments[0][j];
+ if (!firstValueToIdx.contains(newArg))
+ firstValueToIdx[newArg] = j;
+ }
+
// Go through the first list of arguments (list 0).
for (unsigned j = 0; j < numArgs; ++j) {
bool shouldReplaceJ = false;
unsigned replacement = 0;
- // Look back to see if there are possible redundancies in
- // list 0.
- for (unsigned k = 0; k < j; k++) {
- if (newArguments[0][k] == newArguments[0][j]) {
- shouldReplaceJ = true;
- replacement = k;
- // If a possible redundancy is found, then scan the other lists: we
- // can prune the arguments if and only if they are redundant in every
- // list.
- for (unsigned i = 1; i < numLists; ++i)
- shouldReplaceJ =
- shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
- }
+ // Look back to see if there are possible redundancies in list 0. Please
+ // note that we are using a map to annotate when an argument was seen first
+ // to avoid an O(N^2) algorithm. This has the drawback that if we have two
+ // lists like:
+ // list0: [%a, %a, %a]
+ // list1: [%c, %b, %b]
+ // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
+ // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
+ // the number of arguments can be potentially unbounded we cannot afford an
+ // O(N^2) algorithm (to search all the possible pairs) and we need to
+ // accept the trade-off.
+ unsigned k = firstValueToIdx[newArguments[0][j]];
+ if (k != j) {
+ shouldReplaceJ = true;
+ replacement = k;
+ // If a possible redundancy is found, then scan the other lists: we
+ // can prune the arguments if and only if they are redundant in every
+ // list.
+ for (unsigned i = 1; i < numLists; ++i)
+ shouldReplaceJ =
+ shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
}
// Save the replacement.
if (shouldReplaceJ)
>From e018af01476e174c1f03b8676ab5be749a13d7af Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Mon, 5 Aug 2024 19:08:58 +0100
Subject: [PATCH 8/8] Offset the new arguments by the number of old arguments
---
mlir/lib/Transforms/Utils/RegionUtils.cpp | 22 +++++++++-----
.../test-canonicalize-merge-large-blocks.mlir | 30 +++++++++++++++++++
2 files changed, 44 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 946d65cef4186..3e15018bdb765 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -679,12 +679,15 @@ static bool ableToUpdatePredOperands(Block *block) {
return true;
}
-/// Prunes the redundant list of arguments. E.g., if we are passing an argument
-/// list like [x, y, z, x] this would return [x, y, z] and it would update the
-/// `block` (to whom the argument are passed to) accordingly.
+/// Prunes the redundant list of new arguments. E.g., if we are passing an
+/// argument list like [x, y, z, x] this would return [x, y, z] and it would
+/// update the `block` (to which the arguments are passed) accordingly. The new
+/// arguments are appended after the block's existing arguments, hence we need
+/// `numOldArguments` (the number of arguments the block had before) in order
+/// to correctly locate and replace the new arguments in the block.
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
- RewriterBase &rewriter, Block *block) {
+ RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
newArguments.size(), SmallVector<Value, 8>());
@@ -751,10 +754,11 @@ static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
SmallVector<unsigned> toErase;
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
if (idxToReplacement.contains(idx)) {
- Value oldArg = block->getArgument(idx);
- Value newArg = block->getArgument(idxToReplacement[idx]);
+ Value oldArg = block->getArgument(numOldArguments + idx);
+ Value newArg =
+ block->getArgument(numOldArguments + idxToReplacement[idx]);
rewriter.replaceAllUsesWith(oldArg, newArg);
- toErase.push_back(idx);
+ toErase.push_back(numOldArguments + idx);
}
}
@@ -793,6 +797,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
1 + blocksToMerge.size(),
SmallVector<Value, 8>(operandsToMerge.size()));
unsigned curOpIndex = 0;
+ unsigned numOldArguments = leaderBlock->getNumArguments();
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
curOpIndex = it.value().first;
@@ -814,7 +819,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
}
// Prune redundant arguments and update the leader block argument list
- newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+ newArguments = pruneRedundantArguments(newArguments, rewriter,
+ numOldArguments, leaderBlock);
// Update the predecessors for each of the blocks.
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
index e821dcd0c2064..84df83619d28a 100644
--- a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -160,3 +160,33 @@ llvm.func @redundant_args_complex(%cond : i1) {
^bb3:
llvm.return
}
+
+llvm.func @blocks_with_args() {
+ %0 = llvm.mlir.zero : !llvm.ptr
+ %1 = llvm.call @rand() : () -> i1
+ // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : i64)
+ // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64)
+ // CHECK: %[[cond:.*]] = llvm.call @rand
+ %3 = llvm.mlir.constant(0) : i64
+ %4 = llvm.mlir.constant(1) : i64
+ // CHECK: llvm.cond_br %[[cond]], ^bb1(%[[c1]] : i64), ^bb1(%[[c0]] : i64)
+ // CHECK: ^bb1(%{{.*}}: i64):
+ // CHECK ^bb2:
+ // CHECK ^bb3:
+ // CHECK llvm.return
+ llvm.cond_br %1, ^bb7(%0 : !llvm.ptr), ^bb1(%0 : !llvm.ptr)
+^bb1(%5: !llvm.ptr):
+ llvm.store %5, %0 : !llvm.ptr, !llvm.ptr
+ llvm.cond_br %1, ^bb2(%3 : i64), ^bb4(%3 : i64)
+^bb7(%6: !llvm.ptr):
+ llvm.store %6, %0 : !llvm.ptr, !llvm.ptr
+ llvm.cond_br %1, ^bb2(%4 : i64), ^bb4(%4 : i64)
+^bb2(%7: i64):
+ llvm.call @foo(%7) : (i64) -> ()
+ llvm.br ^bb8
+^bb4(%8: i64):
+ llvm.call @foo(%8) : (i64) -> ()
+ llvm.br ^bb8
+^bb8:
+ llvm.return
+}
More information about the Mlir-commits
mailing list