[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 27 10:26:13 PDT 2024


================
@@ -1070,104 +1071,202 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `target` and `source` are siblings that can be fused, in
+/// the context that `target` is being fused into `source`. When a check
+/// fails, a reason is streamed into `diag` and false is returned.
+///
+/// The operations are accepted only when all of the following hold:
+///   1. `target` and `source` are different operations.
+///   2. Both operations are in the same block.
+///   3. Fusion does not violate dominance:
+///      - if `target` comes before `source`, every user of `target`'s results
+///        must be properly dominated by `source`;
+///      - otherwise, the defining op of every operand of `target`, and of
+///        every value defined above and used inside `target`'s regions, must
+///        properly dominate `source` (values without a defining op, i.e.
+///        block arguments, are accepted as-is).
+///
+/// NOTE(review): these conditions are consistent with the fused loop being
+/// materialized at `source`'s position in the block; confirm against the
+/// fusion utility that consumes this check.
+static bool isOpSibling(Operation *target, Operation *source,
+                        Diagnostic &diag) {
+  // Reject fusing an operation with itself.
+  if (target == source) {
+    diag << "target and source need to be different loops";
+    return false;
+  }
+
+  // Check if both operations are in the same block.
+  if (target->getBlock() != source->getBlock()) {
+    diag << "target and source are not in the same block";
+    return false;
+  }
+
+  // Check if fusion will violate dominance.
+  DominanceInfo domInfo(source);
+  if (target->isBeforeInBlock(source)) {
+    // Since `target` is before `source`, all users of results of `target`
+    // need to be dominated by `source`.
+    for (Operation *user : target->getUsers()) {
+      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+        diag << "user of results of target should "
+                "be properly dominated by source";
+        return false;
+      }
+    }
+  } else {
+    // Since `target` is after `source`, all values used by `target` need
+    // to dominate `source`.
+
+    // Check that the defining op of every operand of `target` properly
+    // dominates `source`.
+    for (Value operand : target->getOperands()) {
+      Operation *operandOp = operand.getDefiningOp();
+      // Operands without defining operations are block arguments. When `target`
+      // and `source` occur in the same block, these operands dominate `source`.
+      if (!operandOp)
+        continue;
+
+      // Operand's defining operation should properly dominate `source`.
+      if (!domInfo.properlyDominates(operandOp, source,
+                                     /*enclosingOpOk=*/false)) {
+        diag << "operands of target should be properly dominated by source";
+        return false;
+      }
+    }
+
+    // Check that values defined above and used inside `target`'s regions have
+    // defining ops that properly dominate `source`.
+    // The callback below has no way to stop the walk early, so `failedValue`
+    // records the last offending use visited; `failed` is true exactly when
+    // `failedValue` is non-null.
+    bool failed = false;
+    OpOperand *failedValue = nullptr;
+    visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not an argument of an enclosing block and the defining
+        // op of `operand` is outside `target` but does not dominate `source`.
+        failed = true;
+        failedValue = operand;
+      }
+    });
+
+    if (failed) {
+      diag << "values used inside regions of target should be properly "
+              "dominated by source";
+      // Point the note at the operation that uses the offending value.
+      diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+/// Returns true if `target` and `source` are structurally legal to fuse:
+/// their lower bounds, upper bounds and steps must compare equal (and, when
+/// both are `scf.forall`, so must their mappings), and `isOpSibling` must
+/// accept them. On failure, a reason is streamed into `diag`.
+///
+/// NOTE(review): bounds and steps are compared with `==` on whatever the
+/// `LoopLikeOpInterface` accessors return. If those accessors return an empty
+/// optional for loops that do not expose bounds, two such loops compare equal
+/// here. If they return SSA values, equal-valued but distinct values compare
+/// unequal. TODO confirm and tighten if needed.
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
+                                         LoopLikeOpInterface source,
+                                         Diagnostic &diag) {
+  bool iterSpaceEq =
+      target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+      target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+      target.getLoopSteps() == source.getLoopSteps();
+  auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+  auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+  // TODO: Decouple checks on concrete loop types and move this function
+  // somewhere for general utility for `LoopLikeOpInterface`
+  // Mappings are only compared when both loops are `scf.forall`.
+  if (forAllTarget && forAllSource)
+    iterSpaceEq =
+        iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping();
+  if (!iterSpaceEq) {
+    diag << "target and source iteration spaces must be equal";
+    return false;
+  }
+  // The iteration spaces match; the remaining checks (distinct ops, same
+  // block, dominance) are done by `isOpSibling`.
+  return isOpSibling(target, source, diag);
+}
+
+/// Fuses the independent sibling `scf.forall` loops `target` and `source`
+/// using `createFused`, then replaces `source` with the trailing
+/// `source.getNumResults()` results of the fused loop. Returns the fused loop.
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused shared_outs.
-  SmallVector<Value> fusedOuts;
-  llvm::append_range(fusedOuts, target.getOutputs());
-  llvm::append_range(fusedOuts, source.getOutputs());
-
-  // Create a new scf.forall op after the source loop.
-  rewriter.setInsertionPointAfter(source);
-  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
-      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
-      source.getMixedStep(), fusedOuts, source.getMapping());
-
-  // Map control operands.
-  IRMapping mapping;
-  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
-
-  // Map shared outs.
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
-  // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-
-  // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp targetTerm = target.getTerminator();
-  scf::InParallelOp sourceTerm = source.getTerminator();
-  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-  rewriter.setInsertionPointToStart(fusedTerm.getBody());
-  for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
-  for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
-
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  // NOTE(review): the terminator callback's `source` and `target` parameters
+  // shadow this function's parameters of the same name; the callback's
+  // `target` is presumably the fused loop built by `createFused` (confirm).
+  scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
+        return ValueRange{};
+      },
+      [&](RewriterBase &b, LoopLikeOpInterface source,
+          LoopLikeOpInterface &target, IRMapping mapping) {
+        // Append the remapped yielding ops of `source`'s `in_parallel`
+        // terminator to the end of `target`'s `in_parallel` body.
+        // NOTE(review): this moves `b`'s insertion point and this function
+        // installs no `OpBuilder::InsertionGuard`; confirm whether
+        // `createFused` restores the caller's insertion point.
+        auto sourceForall = cast<scf::ForallOp>(source);
+        auto targetForall = cast<scf::ForallOp>(target);
+        scf::InParallelOp fusedTerm = targetForall.getTerminator();
+        b.setInsertionPointToEnd(fusedTerm.getBody());
+        for (Operation &op : sourceForall.getTerminator().getYieldingOps())
+          b.clone(op, mapping);
+      }));
+  // Redirect uses of `source` to the trailing results of the fused loop.
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
 
   return fusedLoop;
 }
 
+/// Fuses the independent sibling `scf.for` loops `target` and `source` using
+/// `createFused`, then replaces `source` with the trailing
+/// `source.getNumResults()` results of the fused loop. Returns the fused loop.
 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                 scf::ForOp source,
                                                 RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused init_args, with target's init_args before source's init_args.
-  SmallVector<Value> fusedInitArgs;
-  llvm::append_range(fusedInitArgs, target.getInitArgs());
-  llvm::append_range(fusedInitArgs, source.getInitArgs());
-
-  // Create a new scf.for op after the source loop (with scf.yield terminator
-  // (without arguments) only in case its init_args is empty).
-  rewriter.setInsertionPointAfter(source);
-  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
-      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
-      source.getStep(), fusedInitArgs);
-
-  // Map original induction variables and operands to those of the fused loop.
-  IRMapping mapping;
-  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
-  // Merge target's body into the new (fused) for loop and then source's body.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-
-  // Build fused yield results by appropriately mapping original yield operands.
-  SmallVector<Value> yieldResults;
-  for (Value operand : target.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  for (Value operand : source.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  if (!yieldResults.empty())
-    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
-
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  // NOTE(review): the terminator callback's `source` and `target` parameters
+  // shadow this function's parameters. The yield callback has no such
+  // parameter, so `source` inside it is the outer `scf::ForOp`, captured by
+  // reference.
+  scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        // Presumably the additional values yielded by the fused loop:
+        // `source`'s own yielded values (remapping is left to `createFused`).
+        return source.getYieldedValues();
+      },
+      [&](RewriterBase &b, LoopLikeOpInterface source,
+          LoopLikeOpInterface &target, IRMapping mapping) {
+        // Clone `target`'s terminator through `mapping` and replace the
+        // original terminator with the clone.
+        auto targetFor = cast<scf::ForOp>(target);
+        auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
+        b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
+      }));
+  // Redirect uses of `source` to the trailing results of the fused loop.
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
+  return fusedLoop;
+}
+
+// TODO: Finish refactoring this a la the above, but likely requires additional
+// interface methods.
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+    scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+  Block *block1 = target.getBody();
+  Block *block2 = source.getBody();
+  auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+  ValueRange inits1 = target.getInitVals();
+  ValueRange inits2 = source.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  rewriter.setInsertionPoint(source);
----------------
srcarroll wrote:

Probably a good idea. I think it suffices to have the whole function as the scope for the guard.

https://github.com/llvm/llvm-project/pull/94391


More information about the Mlir-commits mailing list