[Mlir-commits] [mlir] [MLIR] Parallel loop fusion extended to interchanged loops. (PR #191245)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 9 09:56:05 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Dmitriy Smirnov (d-smirnov)

<details>
<summary>Changes</summary>

This patch extends the fusion of two parallel loops to the case where the second parallel loop consists of two interchanged loops with the same iteration space.


---
Full diff: https://github.com/llvm/llvm-project/pull/191245.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+52-2) 
- (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 0b132e9109492..e0e6aad612aa8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -733,6 +733,52 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                                         firstToSecondPloopIndices, mayAlias, b);
 }
 
+/// Interchanges the two dimensions of a two-dimensional parallel loop.
+///
+/// Creates a new parallel loop right before `loop` with the lower bounds,
+/// upper bounds and steps of `loop` in reverse order, and clones the body of
+/// `loop` into it with the induction variables swapped, so the new loop
+/// performs the same iterations in interchanged order. Uses of the results of
+/// `loop` are redirected to the new loop; erasing `loop` is left to the
+/// caller. Returns std::nullopt, without modifying the IR, if `loop` does not
+/// have exactly two dimensions.
+static std::optional<ParallelOp> interchangeLoops(OpBuilder &builder,
+                                                  ParallelOp loop) {
+  if (loop.getNumLoops() != 2)
+    return std::nullopt;
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(loop);
+
+  // Reverse the bounds and steps together with the induction variables so
+  // that the iteration space is preserved even when the dimensions differ.
+  auto reversed = [](ValueRange values) {
+    return llvm::to_vector(llvm::reverse(values));
+  };
+  auto newOp = ParallelOp::create(
+      builder, loop.getLoc(), reversed(loop.getLowerBound()),
+      reversed(loop.getUpperBound()), reversed(loop.getStep()),
+      loop.getInitVals(), /*bodyBuilderFn=*/nullptr);
+
+  IRMapping mapping;
+  mapping.map(loop.getInductionVars(),
+              llvm::reverse(newOp.getInductionVars()));
+
+  // ParallelOp::build adds a default `scf.reduce` terminator only when there
+  // are no init values. Drop it, if present, and clone the whole original
+  // body, terminator included, so that any reductions are preserved.
+  Block *newBody = newOp.getBody();
+  if (!newBody->empty())
+    newBody->getTerminator()->erase();
+  builder.setInsertionPointToEnd(newBody);
+  for (Operation &op : loop.getBody()->getOperations())
+    builder.clone(op, mapping);
+
+  // The caller erases `loop`, so its results must no longer be used.
+  loop->replaceAllUsesWith(newOp->getResults());
+  return newOp;
+}
+
 /// Prepend operations of firstPloop's body into secondPloop's body.
 /// Update secondPloop with new loop.
 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
@@ -744,8 +790,31 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
   firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
 
   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
-                     mayAlias, builder))
-    return;
+                     mayAlias, builder)) {
+    // The loops cannot be fused as they are. If the second loop has two
+    // dimensions with identical bounds and steps, interchanging them keeps
+    // its iteration space unchanged, so re-assess fusion with the first
+    // loop's induction variables mapped onto the second's in reverse order.
+    bool canInterchange =
+        secondPloop.getNumLoops() == 2 &&
+        secondPloop.getLowerBound()[0] == secondPloop.getLowerBound()[1] &&
+        secondPloop.getUpperBound()[0] == secondPloop.getUpperBound()[1] &&
+        secondPloop.getStep()[0] == secondPloop.getStep()[1];
+    if (!canInterchange)
+      return;
+    firstToSecondPloopIndices.clear();
+    firstToSecondPloopIndices.map(block1->getArguments(),
+                                  llvm::reverse(block2->getArguments()));
+    if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+                       mayAlias, builder))
+      return;
+    std::optional<ParallelOp> newLoop = interchangeLoops(builder, secondPloop);
+    if (!newLoop)
+      return;
+    secondPloop->erase();
+    secondPloop = *newLoop;
+    block2 = secondPloop.getBody();
+  }
 
   DominanceInfo dom;
   // We are fusing first loop into second, make sure there are no users of the
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index d876062b704f2..ac6ab11963bd5 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -1223,3 +1223,37 @@ func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>)
 // CHECK-LABEL: func @reductions_use_res_between
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @test_fuse_interchanged_loops(%arg0: memref<1x64xf32>) {
+  %c8 = arith.constant 8 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %alloc_0 = memref.alloc() : memref<1x8x8xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x8x1xf32>
+  scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
+    %0 = memref.load %alloc_0[%c0, %arg2, %arg3] : memref<1x8x8xf32>
+    memref.store %0, %alloc[%arg3, %arg2, %c0] : memref<8x8x1xf32>
+    scf.reduce
+  }
+  scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
+    %0 = memref.load %alloc[%arg2, %arg3, %c0] : memref<8x8x1xf32>
+    %1 = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%arg2, %arg3)
+    memref.store %0, %arg0[%c0, %1] : memref<1x64xf32>
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-LABEL: func @test_fuse_interchanged_loops
+// CHECK-SAME:    (%[[ARG0:.*]]: memref<1x64xf32>)
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[SRC:.*]] = memref.alloc() : memref<1x8x8xf32>
+// CHECK-DAG:     %[[TMP:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x8x1xf32>
+// CHECK:         scf.parallel (%[[I:.*]], %[[J:.*]]) =
+// CHECK:           %[[V0:.*]] = memref.load %[[SRC]][%[[C0]], %[[I]], %[[J]]]
+// CHECK:           memref.store %[[V0]], %[[TMP]][%[[J]], %[[I]], %[[C0]]]
+// CHECK:           %[[V1:.*]] = memref.load %[[TMP]][%[[J]], %[[I]], %[[C0]]]
+// CHECK:           %[[IDX:.*]] = affine.apply #{{.*}}(%[[J]], %[[I]])
+// CHECK:           memref.store %[[V1]], %[[ARG0]][%[[C0]], %[[IDX]]]
+// CHECK-NOT:     scf.parallel

``````````

</details>


https://github.com/llvm/llvm-project/pull/191245


More information about the Mlir-commits mailing list