[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 21 14:01:04 PDT 2024
================
@@ -1070,104 +1070,182 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
-scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
- scf::ForallOp source,
- RewriterBase &rewriter) {
- unsigned numTargetOuts = target.getNumResults();
- unsigned numSourceOuts = source.getNumResults();
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
- // Create fused shared_outs.
- SmallVector<Value> fusedOuts;
- llvm::append_range(fusedOuts, target.getOutputs());
- llvm::append_range(fusedOuts, source.getOutputs());
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+ LoopLikeOpInterface &source) {
+ auto iterSpaceEq =
+ target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+ target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+ target.getLoopSteps() == source.getLoopSteps();
+ auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+ auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+ if (forAllTarget && forAllSource)
+ return iterSpaceEq &&
+ forAllTarget.getMapping() == forAllSource.getMapping();
+ return iterSpaceEq;
+}
- // Create a new scf.forall op after the source loop.
- rewriter.setInsertionPointAfter(source);
- scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
- source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
- source.getMixedStep(), fusedOuts, source.getMapping());
+template <typename LoopTy>
+void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
+ IRMapping &mapping) {}
- // Map control operands.
- IRMapping mapping;
- mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
- mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
+ scf::ForallOp &fused, IRMapping &mapping) {
+ // Fuse the old terminator in_parallel ops into the new one.
+ scf::InParallelOp fusedTerm = fused.getTerminator();
+ rewriter.setInsertionPointToEnd(fusedTerm.getBody());
+ for (Operation &op : source.getTerminator().getYieldingOps())
+ rewriter.clone(op, mapping);
+}
- // Map shared outs.
- mapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
- mapping.map(source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
+ scf::ForOp &fused, IRMapping &mapping) {
+ // Build fused yield results by appropriately mapping original yield operands.
+ auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping);
+ rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
+}
+
+// TODO: We should maybe add a method to LoopLikeOpInterface that will
+// facilitate this transformation. For now, this acts as a placeholder.
+template <>
+void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
+ LoopLikeOpInterface &fused, IRMapping &mapping) {
+ if (isa<scf::ForOp>(source) && isa<scf::ForOp>(fused)) {
+ fuseTerminator(rewriter, cast<scf::ForOp>(source), cast<scf::ForOp>(fused),
+ mapping);
+ } else if (isa<scf::ForallOp>(source) && isa<scf::ForallOp>(fused)) {
+ fuseTerminator(rewriter, cast<scf::ForallOp>(source),
+ cast<scf::ForallOp>(fused), mapping);
+ } else if (isa<scf::ParallelOp>(source) && isa<scf::ParallelOp>(fused)) {
+ fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
+ cast<scf::ParallelOp>(fused), mapping);
+ } else {
+ return;
+ }
+}
+LoopLikeOpInterface createFused(LoopLikeOpInterface target,
+ LoopLikeOpInterface source,
+ RewriterBase &rewriter,
+ NewYieldValuesFn newYieldValuesFn) {
+ auto targetIterArgs = target.getRegionIterArgs();
+ auto targetInductionVar = *target.getLoopInductionVars();
+ SmallVector<Value> targetYieldOperands(target.getYieldedValues());
+ auto sourceIterArgs = source.getRegionIterArgs();
+ auto sourceInductionVar = *source.getLoopInductionVars();
+ SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
+ auto sourceRegion = source.getLoopRegions().front();
+ LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
+ rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
+ newYieldValuesFn);
+
+ // Map control operands.
+ IRMapping mapping;
+ mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
+ mapping.map(targetIterArgs,
+ fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+ mapping.map(targetYieldOperands,
+ fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
+ mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
+ mapping.map(sourceIterArgs,
+ fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+ mapping.map(sourceYieldOperands,
+ fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
// Append everything except the terminator into the fused operation.
- rewriter.setInsertionPointToStart(fusedLoop.getBody());
- for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, mapping);
- for (Operation &op : source.getBody()->without_terminator())
+ rewriter.setInsertionPoint(
+ fusedLoop.getLoopRegions().front()->front().getTerminator());
+ for (Operation &op : sourceRegion->front().without_terminator())
rewriter.clone(op, mapping);
- // Fuse the old terminator in_parallel ops into the new one.
- scf::InParallelOp targetTerm = target.getTerminator();
- scf::InParallelOp sourceTerm = source.getTerminator();
- scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
- rewriter.setInsertionPointToStart(fusedTerm.getBody());
- for (Operation &op : targetTerm.getYieldingOps())
- rewriter.clone(op, mapping);
- for (Operation &op : sourceTerm.getYieldingOps())
- rewriter.clone(op, mapping);
+ // TODO: Replace with corresponding interface method if added
+ fuseTerminator(rewriter, source, fusedLoop, mapping);
- // Replace old loops by substituting their uses by results of the fused loop.
- rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
- rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+ return fusedLoop;
+}
+
+scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
+ scf::ForallOp source,
+ RewriterBase &rewriter) {
+ scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
+ target, source, rewriter,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
+ return ValueRange{};
+ }));
+ rewriter.replaceOp(source,
+ fusedLoop.getResults().take_back(source.getNumResults()));
return fusedLoop;
}
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
- unsigned numTargetOuts = target.getNumResults();
- unsigned numSourceOuts = source.getNumResults();
-
- // Create fused init_args, with target's init_args before source's init_args.
- SmallVector<Value> fusedInitArgs;
- llvm::append_range(fusedInitArgs, target.getInitArgs());
- llvm::append_range(fusedInitArgs, source.getInitArgs());
-
- // Create a new scf.for op after the source loop (with scf.yield terminator
- // (without arguments) only in case its init_args is empty).
- rewriter.setInsertionPointAfter(source);
- scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
- source.getLoc(), source.getLowerBound(), source.getUpperBound(),
- source.getStep(), fusedInitArgs);
+ scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
+ target, source, rewriter,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ return source.getYieldedValues();
+ }));
+ rewriter.replaceOp(source,
+ fusedLoop.getResults().take_back(source.getNumResults()));
+ return fusedLoop;
+}
- // Map original induction variables and operands to those of the fused loop.
- IRMapping mapping;
- mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
- mapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
- mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
- mapping.map(source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
- // Merge target's body into the new (fused) for loop and then source's body.
- rewriter.setInsertionPointToStart(fusedLoop.getBody());
- for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, mapping);
- for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, mapping);
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
----------------
srcarroll wrote:
> after all, there aren't that many different loop constructs so we might as well live with a switch.
although there may not be that many in scf, downstream users can have an arbitrary number of loop like ops that use the interface. so i think it's a better idea in general to push as much common logic as possible to loop like interface so that common patterns can be extended to new ops users come up with
https://github.com/llvm/llvm-project/pull/94391
More information about the Mlir-commits
mailing list