[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jun 21 04:47:02 PDT 2024


================
@@ -1070,104 +1070,182 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
-scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
-                                                      scf::ForallOp source,
-                                                      RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
 
-  // Create fused shared_outs.
-  SmallVector<Value> fusedOuts;
-  llvm::append_range(fusedOuts, target.getOutputs());
-  llvm::append_range(fusedOuts, source.getOutputs());
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+                                         LoopLikeOpInterface &source) {
+  auto iterSpaceEq =
+      target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+      target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+      target.getLoopSteps() == source.getLoopSteps();
+  auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+  auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+  if (forAllTarget && forAllSource)
+    return iterSpaceEq &&
+           forAllTarget.getMapping() == forAllSource.getMapping();
+  return iterSpaceEq;
+}
 
-  // Create a new scf.forall op after the source loop.
-  rewriter.setInsertionPointAfter(source);
-  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
-      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
-      source.getMixedStep(), fusedOuts, source.getMapping());
+template <typename LoopTy>
+void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
+                    IRMapping &mapping) {}
 
-  // Map control operands.
-  IRMapping mapping;
-  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
+                    scf::ForallOp &fused, IRMapping &mapping) {
+  // Fuse the old terminator in_parallel ops into the new one.
+  scf::InParallelOp fusedTerm = fused.getTerminator();
+  rewriter.setInsertionPointToEnd(fusedTerm.getBody());
+  for (Operation &op : source.getTerminator().getYieldingOps())
+    rewriter.clone(op, mapping);
+}
 
-  // Map shared outs.
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
+                    scf::ForOp &fused, IRMapping &mapping) {
+  // Build fused yield results by appropriately mapping original yield operands.
+  auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping);
+  rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
+}
+
+// TODO: We should maybe add a method to LoopLikeOpInterface that will
+// facilitate this transformation. For now, this acts as a placeholder.
+template <>
+void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
+                    LoopLikeOpInterface &fused, IRMapping &mapping) {
+  if (isa<scf::ForOp>(source) && isa<scf::ForOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForOp>(source), cast<scf::ForOp>(fused),
+                   mapping);
+  } else if (isa<scf::ForallOp>(source) && isa<scf::ForallOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForallOp>(source),
+                   cast<scf::ForallOp>(fused), mapping);
+  } else if (isa<scf::ParallelOp>(source) && isa<scf::ParallelOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
+                   cast<scf::ParallelOp>(fused), mapping);
+  } else {
+    return;
----------------
ftynse wrote:

Use `llvm_unreachable("unsupported")` here instead of silently returning, or return a `LogicalResult` and propagate the failure to the caller when fusion is impossible.

https://github.com/llvm/llvm-project/pull/94391


More information about the Mlir-commits mailing list