[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 24 12:59:01 PDT 2024
================
@@ -1070,104 +1070,182 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
-scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
- scf::ForallOp source,
- RewriterBase &rewriter) {
- unsigned numTargetOuts = target.getNumResults();
- unsigned numSourceOuts = source.getNumResults();
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
- // Create fused shared_outs.
- SmallVector<Value> fusedOuts;
- llvm::append_range(fusedOuts, target.getOutputs());
- llvm::append_range(fusedOuts, source.getOutputs());
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+ LoopLikeOpInterface &source) {
+ auto iterSpaceEq =
+ target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+ target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+ target.getLoopSteps() == source.getLoopSteps();
+ auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+ auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+ if (forAllTarget && forAllSource)
+ return iterSpaceEq &&
+ forAllTarget.getMapping() == forAllSource.getMapping();
+ return iterSpaceEq;
+}
- // Create a new scf.forall op after the source loop.
- rewriter.setInsertionPointAfter(source);
- scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
- source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
- source.getMixedStep(), fusedOuts, source.getMapping());
+template <typename LoopTy>
+void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
+ IRMapping &mapping) {}
- // Map control operands.
- IRMapping mapping;
- mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
- mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
+ scf::ForallOp &fused, IRMapping &mapping) {
+ // Fuse the old terminator in_parallel ops into the new one.
+ scf::InParallelOp fusedTerm = fused.getTerminator();
+ rewriter.setInsertionPointToEnd(fusedTerm.getBody());
+ for (Operation &op : source.getTerminator().getYieldingOps())
+ rewriter.clone(op, mapping);
+}
- // Map shared outs.
- mapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
- mapping.map(source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
+ scf::ForOp &fused, IRMapping &mapping) {
+ // Build fused yield results by appropriately mapping original yield operands.
+ auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping);
+ rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
+}
+
+// TODO: We should maybe add a method to LoopLikeOpInterface that will
+// facilitate this transformation. For now, this acts as a placeholder.
+template <>
+void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
+ LoopLikeOpInterface &fused, IRMapping &mapping) {
+ if (isa<scf::ForOp>(source) && isa<scf::ForOp>(fused)) {
+ fuseTerminator(rewriter, cast<scf::ForOp>(source), cast<scf::ForOp>(fused),
+ mapping);
+ } else if (isa<scf::ForallOp>(source) && isa<scf::ForallOp>(fused)) {
+ fuseTerminator(rewriter, cast<scf::ForallOp>(source),
+ cast<scf::ForallOp>(fused), mapping);
+ } else if (isa<scf::ParallelOp>(source) && isa<scf::ParallelOp>(fused)) {
+ fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
+ cast<scf::ParallelOp>(fused), mapping);
+ } else {
+ return;
+ }
+}
+LoopLikeOpInterface createFused(LoopLikeOpInterface target,
----------------
srcarroll wrote:
I moved this to `LoopLikeInterface.h/cpp` to give it more general utility. Could you take a look at the changes to make sure I'm not doing something dumb here? https://github.com/llvm/llvm-project/pull/94391/commits/cc95d75d2cc09f8a33850f3867c8313e374a0dfd
https://github.com/llvm/llvm-project/pull/94391
More information about the Mlir-commits
mailing list