[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jun 21 04:47:01 PDT 2024


================
@@ -1070,104 +1070,182 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
-scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
-                                                      scf::ForallOp source,
-                                                      RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
 
-  // Create fused shared_outs.
-  SmallVector<Value> fusedOuts;
-  llvm::append_range(fusedOuts, target.getOutputs());
-  llvm::append_range(fusedOuts, source.getOutputs());
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+                                         LoopLikeOpInterface &source) {
+  auto iterSpaceEq =
+      target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+      target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+      target.getLoopSteps() == source.getLoopSteps();
+  auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+  auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+  if (forAllTarget && forAllSource)
+    return iterSpaceEq &&
+           forAllTarget.getMapping() == forAllSource.getMapping();
+  return iterSpaceEq;
+}
 
-  // Create a new scf.forall op after the source loop.
-  rewriter.setInsertionPointAfter(source);
-  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
-      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
-      source.getMixedStep(), fusedOuts, source.getMapping());
+template <typename LoopTy>
+void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
+                    IRMapping &mapping) {}
 
-  // Map control operands.
-  IRMapping mapping;
-  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
+                    scf::ForallOp &fused, IRMapping &mapping) {
+  // Fuse the old terminator in_parallel ops into the new one.
+  scf::InParallelOp fusedTerm = fused.getTerminator();
+  rewriter.setInsertionPointToEnd(fusedTerm.getBody());
+  for (Operation &op : source.getTerminator().getYieldingOps())
+    rewriter.clone(op, mapping);
+}
 
-  // Map shared outs.
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
+                    scf::ForOp &fused, IRMapping &mapping) {
+  // Build fused yield results by appropriately mapping original yield operands.
+  auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping);
+  rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
+}
+
+// TODO: We should maybe add a method to LoopLikeOpInterface that will
+// facilitate this transformation. For now, this acts as a placeholder.
+template <>
+void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
+                    LoopLikeOpInterface &fused, IRMapping &mapping) {
+  if (isa<scf::ForOp>(source) && isa<scf::ForOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForOp>(source), cast<scf::ForOp>(fused),
+                   mapping);
+  } else if (isa<scf::ForallOp>(source) && isa<scf::ForallOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForallOp>(source),
+                   cast<scf::ForallOp>(fused), mapping);
+  } else if (isa<scf::ParallelOp>(source) && isa<scf::ParallelOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
+                   cast<scf::ParallelOp>(fused), mapping);
+  } else {
+    return;
+  }
+}
 
+LoopLikeOpInterface createFused(LoopLikeOpInterface target,
+                                LoopLikeOpInterface source,
+                                RewriterBase &rewriter,
+                                NewYieldValuesFn newYieldValuesFn) {
+  auto targetIterArgs = target.getRegionIterArgs();
+  auto targetInductionVar = *target.getLoopInductionVars();
+  SmallVector<Value> targetYieldOperands(target.getYieldedValues());
+  auto sourceIterArgs = source.getRegionIterArgs();
+  auto sourceInductionVar = *source.getLoopInductionVars();
+  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
+  auto sourceRegion = source.getLoopRegions().front();
+  LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
+      rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
+      newYieldValuesFn);
+
+  // Map control operands.
+  IRMapping mapping;
+  mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
+  mapping.map(targetIterArgs,
+              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+  mapping.map(targetYieldOperands,
+              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
+  mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
+  mapping.map(sourceIterArgs,
+              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+  mapping.map(sourceYieldOperands,
+              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
   // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
+  rewriter.setInsertionPoint(
+      fusedLoop.getLoopRegions().front()->front().getTerminator());
+  for (Operation &op : sourceRegion->front().without_terminator())
     rewriter.clone(op, mapping);
 
-  // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp targetTerm = target.getTerminator();
-  scf::InParallelOp sourceTerm = source.getTerminator();
-  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-  rewriter.setInsertionPointToStart(fusedTerm.getBody());
-  for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
-  for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
+  // TODO: Replace with corresponding interface method if added
+  fuseTerminator(rewriter, source, fusedLoop, mapping);
 
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  return fusedLoop;
+}
+
+scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
+                                                      scf::ForallOp source,
+                                                      RewriterBase &rewriter) {
+  scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
+        return ValueRange{};
+      }));
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
 
   return fusedLoop;
 }
 
 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                 scf::ForOp source,
                                                 RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused init_args, with target's init_args before source's init_args.
-  SmallVector<Value> fusedInitArgs;
-  llvm::append_range(fusedInitArgs, target.getInitArgs());
-  llvm::append_range(fusedInitArgs, source.getInitArgs());
-
-  // Create a new scf.for op after the source loop (with scf.yield terminator
-  // (without arguments) only in case its init_args is empty).
-  rewriter.setInsertionPointAfter(source);
-  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
-      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
-      source.getStep(), fusedInitArgs);
+  scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        return source.getYieldedValues();
+      }));
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
+  return fusedLoop;
+}
 
-  // Map original induction variables and operands to those of the fused loop.
-  IRMapping mapping;
-  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
-  // Merge target's body into the new (fused) for loop and then source's body.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+    scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+  Block *block1 = target.getBody();
+  Block *block2 = source.getBody();
+  auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
 
-  // Build fused yield results by appropriately mapping original yield operands.
-  SmallVector<Value> yieldResults;
-  for (Value operand : target.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  for (Value operand : source.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  if (!yieldResults.empty())
-    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
-
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  ValueRange inits1 = target.getInitVals();
+  ValueRange inits2 = source.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  rewriter.setInsertionPoint(source);
+  auto fusedLoop = rewriter.create<scf::ParallelOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), newInitVars);
+  Block *newBlock = fusedLoop.getBody();
+  rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                             newBlock->getArguments());
+  rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+                             newBlock->getArguments());
+
+  ValueRange results = fusedLoop.getResults();
+  if (!results.empty()) {
+    rewriter.setInsertionPointToEnd(newBlock);
+
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp =
+        rewriter.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock,
+                                 newRedBlock.begin(),
+                                 newRedBlock.getArguments());
+    }
+    target.replaceAllUsesWith(results.take_front(inits1.size()));
+    source.replaceAllUsesWith(results.take_back(inits2.size()));
+  }
+  term1->erase();
+  term2->erase();
+  target.erase();
+  source.erase();
----------------
ftynse wrote:

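Use the rewriter for this: instead of calling `replaceAllUsesWith` and `erase()` directly on the ops above, go through `rewriter.replaceOp` / `rewriter.eraseOp`, so that the rewriter (and any listener attached to it) is notified of the IR changes.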

https://github.com/llvm/llvm-project/pull/94391


More information about the Mlir-commits mailing list