[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)
Ivan Butygin
llvmlistbot at llvm.org
Thu Jun 27 10:07:49 PDT 2024
================
@@ -1070,104 +1071,202 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// This is a simple check that just checks if both operations are in the same
+/// block and some checks to ensure that the fused IR does not violate
+/// dominance.
+static bool isOpSibling(Operation *target, Operation *source,
+ Diagnostic &diag) {
+ // Check if both operations are the same.
+ if (target == source) {
+ diag << "target and source need to be different loops";
+ return false;
+ }
+
+ // Check if both operations are in the same block.
+ if (target->getBlock() != source->getBlock()) {
+ diag << "target and source are not in the same block";
+ return false;
+ }
+
+ // Check if fusion will violate dominance.
+ DominanceInfo domInfo(source);
+ if (target->isBeforeInBlock(source)) {
+ // Since `target` is before `source`, all users of results of `target`
+ // need to be dominated by `source`.
+ for (Operation *user : target->getUsers()) {
+ if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+ diag << "user of results of target should "
+ "be properly dominated by source";
+ return false;
+ }
+ }
+ } else {
+ // Since `target` is after `source`, all values used by `target` need
+ // to dominate `source`.
+
+ // Check if the operands of `target` properly dominate `source`.
+ for (Value operand : target->getOperands()) {
+ Operation *operandOp = operand.getDefiningOp();
+ // Operands without defining operations are block arguments. When `target`
+ // and `source` occur in the same block, these operands dominate `source`.
+ if (!operandOp)
+ continue;
+
+ // Operand's defining operation should properly dominate `source`.
+ if (!domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false)) {
+ diag << "operands of target should be properly dominated by source";
+ return false;
+ }
+ }
+
+ // Check if values used inside the regions of `target` but defined outside
+ // of it properly dominate `source`.
+ bool failed = false;
+ OpOperand *failedValue = nullptr;
+ visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+ Operation *operandOp = operand->get().getDefiningOp();
+ if (operandOp && !domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false)) {
+ // `operand` is not an argument of an enclosing block and the defining
+ // op of `operand` is outside `target` but does not dominate `source`.
+ failed = true;
+ failedValue = operand;
+ }
+ });
+
+ if (failed) {
+ diag << "values used inside regions of target should be properly "
+ "dominated by source";
+ diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
+ LoopLikeOpInterface source,
+ Diagnostic &diag) {
+ bool iterSpaceEq =
+ target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+ target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+ target.getLoopSteps() == source.getLoopSteps();
+ auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+ auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+ // TODO: Decouple checks on concrete loop types and move this function
+ // somewhere for general utility for `LoopLikeOpInterface`
+ if (forAllTarget && forAllSource)
+ iterSpaceEq =
+ iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping();
+ if (!iterSpaceEq) {
+ diag << "target and source iteration spaces must be equal";
+ return false;
+ }
+ return isOpSibling(target, source, diag);
+}
+
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter) {
- unsigned numTargetOuts = target.getNumResults();
- unsigned numSourceOuts = source.getNumResults();
-
- // Create fused shared_outs.
- SmallVector<Value> fusedOuts;
- llvm::append_range(fusedOuts, target.getOutputs());
- llvm::append_range(fusedOuts, source.getOutputs());
-
- // Create a new scf.forall op after the source loop.
- rewriter.setInsertionPointAfter(source);
- scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
- source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
- source.getMixedStep(), fusedOuts, source.getMapping());
-
- // Map control operands.
- IRMapping mapping;
- mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
- mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
-
- // Map shared outs.
- mapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
- mapping.map(source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
- // Append everything except the terminator into the fused operation.
- rewriter.setInsertionPointToStart(fusedLoop.getBody());
- for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, mapping);
- for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, mapping);
-
- // Fuse the old terminator in_parallel ops into the new one.
- scf::InParallelOp targetTerm = target.getTerminator();
- scf::InParallelOp sourceTerm = source.getTerminator();
- scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
- rewriter.setInsertionPointToStart(fusedTerm.getBody());
- for (Operation &op : targetTerm.getYieldingOps())
- rewriter.clone(op, mapping);
- for (Operation &op : sourceTerm.getYieldingOps())
- rewriter.clone(op, mapping);
-
- // Replace old loops by substituting their uses by results of the fused loop.
- rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
- rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+ scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
+ target, source, rewriter,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
+ return ValueRange{};
+ },
+ [&](RewriterBase &b, LoopLikeOpInterface source,
+ LoopLikeOpInterface &target, IRMapping mapping) {
+ auto sourceForall = cast<scf::ForallOp>(source);
+ auto targetForall = cast<scf::ForallOp>(target);
+ scf::InParallelOp fusedTerm = targetForall.getTerminator();
+ b.setInsertionPointToEnd(fusedTerm.getBody());
+ for (Operation &op : sourceForall.getTerminator().getYieldingOps())
+ b.clone(op, mapping);
+ }));
+ rewriter.replaceOp(source,
+ fusedLoop.getResults().take_back(source.getNumResults()));
return fusedLoop;
}
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
- unsigned numTargetOuts = target.getNumResults();
- unsigned numSourceOuts = source.getNumResults();
-
- // Create fused init_args, with target's init_args before source's init_args.
- SmallVector<Value> fusedInitArgs;
- llvm::append_range(fusedInitArgs, target.getInitArgs());
- llvm::append_range(fusedInitArgs, source.getInitArgs());
-
- // Create a new scf.for op after the source loop (with scf.yield terminator
- // (without arguments) only in case its init_args is empty).
- rewriter.setInsertionPointAfter(source);
- scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
- source.getLoc(), source.getLowerBound(), source.getUpperBound(),
- source.getStep(), fusedInitArgs);
-
- // Map original induction variables and operands to those of the fused loop.
- IRMapping mapping;
- mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
- mapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
- mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
- mapping.map(source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
- // Merge target's body into the new (fused) for loop and then source's body.
- rewriter.setInsertionPointToStart(fusedLoop.getBody());
- for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, mapping);
- for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, mapping);
-
- // Build fused yield results by appropriately mapping original yield operands.
- SmallVector<Value> yieldResults;
- for (Value operand : target.getBody()->getTerminator()->getOperands())
- yieldResults.push_back(mapping.lookupOrDefault(operand));
- for (Value operand : source.getBody()->getTerminator()->getOperands())
- yieldResults.push_back(mapping.lookupOrDefault(operand));
- if (!yieldResults.empty())
- rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
-
- // Replace old loops by substituting their uses by results of the fused loop.
- rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
- rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+ scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
+ target, source, rewriter,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+ return source.getYieldedValues();
+ },
+ [&](RewriterBase &b, LoopLikeOpInterface source,
+ LoopLikeOpInterface &target, IRMapping mapping) {
+ auto targetFor = cast<scf::ForOp>(target);
+ auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
+ b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
+ }));
+ rewriter.replaceOp(source,
+ fusedLoop.getResults().take_back(source.getNumResults()));
+ return fusedLoop;
+}
+
+// TODO: Finish refactoring this a la the above, but likely requires additional
+// interface methods.
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+ scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+ Block *block1 = target.getBody();
+ Block *block2 = source.getBody();
+ auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+ auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+ ValueRange inits1 = target.getInitVals();
+ ValueRange inits2 = source.getInitVals();
+
+ SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+ newInitVars.append(inits2.begin(), inits2.end());
+
+ rewriter.setInsertionPoint(source);
----------------
Hardcode84 wrote:
Should we add an insertion guard here, since we are changing the insertion point?
https://github.com/llvm/llvm-project/pull/94391
More information about the Mlir-commits
mailing list