[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jun 21 04:47:02 PDT 2024


================
@@ -1070,104 +1070,182 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
-scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
-                                                      scf::ForallOp source,
-                                                      RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
 
-  // Create fused shared_outs.
-  SmallVector<Value> fusedOuts;
-  llvm::append_range(fusedOuts, target.getOutputs());
-  llvm::append_range(fusedOuts, source.getOutputs());
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+                                         LoopLikeOpInterface &source) {
+  auto iterSpaceEq =
+      target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+      target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+      target.getLoopSteps() == source.getLoopSteps();
+  auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+  auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+  if (forAllTarget && forAllSource)
+    return iterSpaceEq &&
+           forAllTarget.getMapping() == forAllSource.getMapping();
+  return iterSpaceEq;
+}
 
-  // Create a new scf.forall op after the source loop.
-  rewriter.setInsertionPointAfter(source);
-  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
-      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
-      source.getMixedStep(), fusedOuts, source.getMapping());
+template <typename LoopTy>
+void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
+                    IRMapping &mapping) {}
 
-  // Map control operands.
-  IRMapping mapping;
-  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
+                    scf::ForallOp &fused, IRMapping &mapping) {
+  // Fuse the old terminator in_parallel ops into the new one.
+  scf::InParallelOp fusedTerm = fused.getTerminator();
+  rewriter.setInsertionPointToEnd(fusedTerm.getBody());
+  for (Operation &op : source.getTerminator().getYieldingOps())
+    rewriter.clone(op, mapping);
+}
 
-  // Map shared outs.
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
+                    scf::ForOp &fused, IRMapping &mapping) {
+  // Build fused yield results by appropriately mapping original yield operands.
+  auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping);
+  rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
+}
+
+// TODO: We should maybe add a method to LoopLikeOpInterface that will
+// facilitate this transformation. For now, this acts as a placeholder.
+template <>
+void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
+                    LoopLikeOpInterface &fused, IRMapping &mapping) {
+  if (isa<scf::ForOp>(source) && isa<scf::ForOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForOp>(source), cast<scf::ForOp>(fused),
+                   mapping);
+  } else if (isa<scf::ForallOp>(source) && isa<scf::ForallOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForallOp>(source),
+                   cast<scf::ForallOp>(fused), mapping);
+  } else if (isa<scf::ParallelOp>(source) && isa<scf::ParallelOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
+                   cast<scf::ParallelOp>(fused), mapping);
+  } else {
+    return;
+  }
+}
 
+LoopLikeOpInterface createFused(LoopLikeOpInterface target,
+                                LoopLikeOpInterface source,
+                                RewriterBase &rewriter,
+                                NewYieldValuesFn newYieldValuesFn) {
+  auto targetIterArgs = target.getRegionIterArgs();
+  auto targetInductionVar = *target.getLoopInductionVars();
+  SmallVector<Value> targetYieldOperands(target.getYieldedValues());
+  auto sourceIterArgs = source.getRegionIterArgs();
+  auto sourceInductionVar = *source.getLoopInductionVars();
+  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
+  auto sourceRegion = source.getLoopRegions().front();
+  LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
+      rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
+      newYieldValuesFn);
+
+  // Map control operands.
+  IRMapping mapping;
+  mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
+  mapping.map(targetIterArgs,
+              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+  mapping.map(targetYieldOperands,
+              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
+  mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
+  mapping.map(sourceIterArgs,
+              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+  mapping.map(sourceYieldOperands,
+              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
   // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
+  rewriter.setInsertionPoint(
+      fusedLoop.getLoopRegions().front()->front().getTerminator());
+  for (Operation &op : sourceRegion->front().without_terminator())
     rewriter.clone(op, mapping);
 
-  // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp targetTerm = target.getTerminator();
-  scf::InParallelOp sourceTerm = source.getTerminator();
-  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-  rewriter.setInsertionPointToStart(fusedTerm.getBody());
-  for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
-  for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
+  // TODO: Replace with corresponding interface method if added
+  fuseTerminator(rewriter, source, fusedLoop, mapping);
 
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  return fusedLoop;
+}
+
+scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
+                                                      scf::ForallOp source,
+                                                      RewriterBase &rewriter) {
+  scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
+        return ValueRange{};
+      }));
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
 
   return fusedLoop;
 }
 
 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                 scf::ForOp source,
                                                 RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused init_args, with target's init_args before source's init_args.
-  SmallVector<Value> fusedInitArgs;
-  llvm::append_range(fusedInitArgs, target.getInitArgs());
-  llvm::append_range(fusedInitArgs, source.getInitArgs());
-
-  // Create a new scf.for op after the source loop (with scf.yield terminator
-  // (without arguments) only in case its init_args is empty).
-  rewriter.setInsertionPointAfter(source);
-  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
-      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
-      source.getStep(), fusedInitArgs);
+  scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        return source.getYieldedValues();
+      }));
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
+  return fusedLoop;
+}
 
-  // Map original induction variables and operands to those of the fused loop.
-  IRMapping mapping;
-  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
-  // Merge target's body into the new (fused) for loop and then source's body.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
----------------
ftynse wrote:

I don't remember offhand the structure of `scf.parallel`; it has a design for reduction handling that appeared suboptimal in hindsight. Having an interface method for terminator handling is okay as long as it has a default implementation (presumably one that returns `failure()` and refuses to fuse). Otherwise, we don't have to use interfaces for this; after all, there aren't that many different loop constructs, so we might as well live with a switch.

https://github.com/llvm/llvm-project/pull/94391


More information about the Mlir-commits mailing list