[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Mar 27 09:52:58 PDT 2024
================
@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();
- OperandRange targetOuts = target.getOutputs();
- OperandRange sourceOuts = source.getOutputs();
-
// Create fused shared_outs.
SmallVector<Value> fusedOuts;
- fusedOuts.reserve(numTargetOuts + numSourceOuts);
- fusedOuts.append(targetOuts.begin(), targetOuts.end());
- fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+ llvm::append_range(fusedOuts, target.getOutputs());
+ llvm::append_range(fusedOuts, source.getOutputs());
- // Create a new scf::forall op after the source loop.
+ // Create a new scf.forall op after the source loop.
rewriter.setInsertionPointAfter(source);
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
source.getMixedStep(), fusedOuts, source.getMapping());
// Map control operands.
- IRMapping fusedMapping;
- fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
- fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+ IRMapping mapping;
+ mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+ mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
// Map shared outs.
- fusedMapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
- fusedMapping.map(
- source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
// Append everything except the terminator into the fused operation.
rewriter.setInsertionPointToStart(fusedLoop.getBody());
for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
// Fuse the old terminator in_parallel ops into the new one.
scf::InParallelOp targetTerm = target.getTerminator();
scf::InParallelOp sourceTerm = source.getTerminator();
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-
rewriter.setInsertionPointToStart(fusedTerm.getBody());
for (Operation &op : targetTerm.getYieldingOps())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
for (Operation &op : sourceTerm.getYieldingOps())
- rewriter.clone(op, fusedMapping);
-
- // Replace all uses of the old loops with the fused loop.
- rewriter.replaceAllUsesWith(target.getResults(),
- fusedLoop.getResults().slice(0, numTargetOuts));
- rewriter.replaceAllUsesWith(
- source.getResults(),
- fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
-
- // Erase the old loops.
- rewriter.eraseOp(target);
- rewriter.eraseOp(source);
+ rewriter.clone(op, mapping);
+
+ // Replace old loops by substituting their uses by results of the fused loop.
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+ rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+ return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+ scf::ForOp source,
+ RewriterBase &rewriter) {
+ unsigned numTargetOuts = target.getNumResults();
+ unsigned numSourceOuts = source.getNumResults();
+
+ // Create fused init_args, with target's init_args before source's init_args.
+ SmallVector<Value> fusedInitArgs;
+ llvm::append_range(fusedInitArgs, target.getInitArgs());
+ llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+ // Create a new scf.for op after the source loop (with scf.yield terminator
+ // (without arguments) only in case its init_args is empty).
+ rewriter.setInsertionPointAfter(source);
+ scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+ source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+ source.getStep(), fusedInitArgs);
+
+ // Map original induction variables and operands to those of the fused loop.
+ IRMapping mapping;
+ mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+ // Merge target's body into the new (fused) for loop and then source's body.
+ rewriter.setInsertionPointToStart(fusedLoop.getBody());
+ for (Operation &op : target.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+ for (Operation &op : source.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+
+ // Build fused yield results by appropriately mapping original yield operands.
+ SmallVector<Value> yieldResults;
+ for (Value operand : target.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ for (Value operand : source.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ if (yieldResults.size())
----------------
ftynse wrote:
clang-tidy: prefer `.empty()` to check for emptiness.
https://github.com/llvm/llvm-project/pull/81495
More information about the Mlir-commits
mailing list