[Mlir-commits] [mlir] [MLIR] Parallel loop fusion extended to interchanged loops. (PR #191245)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 9 09:56:05 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Dmitriy Smirnov (d-smirnov)
<details>
<summary>Changes</summary>
This patch extends the fusion of two parallel loops to the case where the second parallel loop consists of two interchanged loops with the same iteration space.
---
Full diff: https://github.com/llvm/llvm-project/pull/191245.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+52-2)
- (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+26)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 0b132e9109492..e0e6aad612aa8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -733,6 +733,37 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
firstToSecondPloopIndices, mayAlias, b);
}
+/// Creates a copy of the two-dimensional parallel loop `loop` with its two
+/// induction variables interchanged, inserted right before `loop`. Returns
+/// std::nullopt (and creates nothing) if `loop` does not have exactly two
+/// induction variables.
+///
+/// The new loop reuses the bounds, steps and init values of `loop` in their
+/// original order. The whole body of `loop` is cloned, including its
+/// `scf.reduce` terminator and any reduction regions, with the first induction
+/// variable of `loop` mapped to the second induction variable of the new loop
+/// and vice versa. All uses of the results of `loop` are redirected to the new
+/// loop, so the caller can safely erase `loop` afterwards. The builder's
+/// insertion point is preserved.
+///
+/// The result is only equivalent to `loop` when both dimensions share the same
+/// lower bound, upper bound and step; the caller is expected to check this.
+static std::optional<ParallelOp> interchangeLoops(OpBuilder &builder,
+                                                  ParallelOp loop) {
+  if (loop.getNumLoops() != 2)
+    return std::nullopt;
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(loop);
+  auto newOp = ParallelOp::create(builder, loop.getLoc(), loop.getLowerBound(),
+                                  loop.getUpperBound(), loop.getStep(),
+                                  loop.getInitVals(), nullptr);
+
+  // Without a body builder, the new body holds at most a default scf.reduce
+  // terminator (added only when there are no init values, so the block may be
+  // empty). Drop it so the original terminator, including any reduction
+  // regions, can be cloned verbatim below.
+  auto *newBody = newOp.getBody();
+  if (!newBody->empty())
+    newBody->back().erase();
+
+  // Interchange: old IV #0 -> new IV #1, old IV #1 -> new IV #0.
+  auto *oldBody = loop.getBody();
+  IRMapping mapping;
+  mapping.map(oldBody->getArgument(0), newBody->getArgument(1));
+  mapping.map(oldBody->getArgument(1), newBody->getArgument(0));
+
+  builder.setInsertionPointToEnd(newBody);
+  for (Operation &op : *oldBody)
+    builder.clone(op, mapping);
+
+  // Redirect users of the (reduction) results so `loop` has no remaining uses
+  // when the caller erases it.
+  loop->replaceAllUsesWith(newOp->getResults());
+  return newOp;
+}
+
/// Prepend operations of firstPloop's body into secondPloop's body.
/// Update secondPloop with new loop.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
@@ -744,8 +775,27 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
- mayAlias, builder))
- return;
+ mayAlias, builder)) {
+ // If the second parallel loop consists of two loops with the same iteration
+ // space, interchange them and re-assess whether fusion is legal.
+ if (secondPloop.getNumLoops() == 2 &&
+ secondPloop.getUpperBound()[0] == secondPloop.getUpperBound()[1] &&
+ secondPloop.getLowerBound()[0] == secondPloop.getLowerBound()[1] &&
+ secondPloop.getStep()[0] == secondPloop.getStep()[1]) {
+ firstToSecondPloopIndices.clear();
+ firstToSecondPloopIndices.map(block1->getArguments(),
+ llvm::reverse(block2->getArguments()));
+ if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+ mayAlias, builder))
+ return;
+ auto newLoop = interchangeLoops(builder, secondPloop);
+ secondPloop->erase();
+ secondPloop = *newLoop;
+ block2 = secondPloop.getBody();
+ } else {
+ return;
+ }
+ }
DominanceInfo dom;
// We are fusing first loop into second, make sure there are no users of the
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index d876062b704f2..ac6ab11963bd5 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -1223,3 +1223,29 @@ func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>)
// CHECK-LABEL: func @reductions_use_res_between
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+// The first loop stores to %alloc[%arg3, %arg2, %c0], while the second loop
+// loads from %alloc[%arg2, %arg3, %c0]: the two loops access %alloc with their
+// induction variables swapped. Both dimensions use the same bounds and steps
+// (%c0, %c8, %c1), so the second loop can be interchanged to line up the
+// accesses, after which the two loops are expected to fuse.
+func.func @test_fuse_interchanged_loops(%arg0: memref<1x64xf32>) {
+ %c8 = arith.constant 8 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %alloc_0 = memref.alloc() : memref<1x8x8xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x8x1xf32>
+ scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
+ %0 = memref.load %alloc_0[%c0, %arg2, %arg3] : memref<1x8x8xf32>
+ memref.store %0, %alloc[%arg3, %arg2, %c0] : memref<8x8x1xf32>
+ scf.reduce
+ }
+ scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
+ %0 = memref.load %alloc[%arg2, %arg3, %c0] : memref<8x8x1xf32>
+ %1 = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%arg2, %arg3)
+ memref.store %0, %arg0[%c0, %1] : memref<1x64xf32>
+ scf.reduce
+ }
+ return
+}
+
+// After fusion exactly one parallel loop must remain in the function.
+// CHECK-LABEL: func @test_fuse_interchanged_loops
+// CHECK: scf.parallel
+// CHECK-NOT: scf.parallel
``````````
</details>
https://github.com/llvm/llvm-project/pull/191245
More information about the Mlir-commits
mailing list