[Mlir-commits] [mlir] Revert "Refactor LoopFuseSiblingOp and support parallel fusion (#94391)" (PR #97523)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 2 23:25:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: None (srcarroll)

<details>
<summary>Changes</summary>

This reverts commit 6820b0871807abff07df118659e0de2ca741cb0b.

---

Patch is 47.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97523.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+1-2) 
- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (-20) 
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.h (-20) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (-38) 
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+119-21) 
- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+74-6) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+88-191) 
- (modified) mlir/lib/Interfaces/LoopLikeInterface.cpp (-55) 
- (modified) mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir (+1-233) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index bf95fbe6721cf..f35ea962bea16 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -303,8 +303,7 @@ def ForallOp : SCF_Op<"forall", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
           ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
            "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
-           "replaceWithAdditionalYields", "promoteIfSingleIteration",
-           "yieldTiledValuesAndReplace"]>,
+           "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 6a40304e2eeba..de807c3e4e1f8 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -181,16 +181,6 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
-//===----------------------------------------------------------------------===//
-// Fusion related helpers
-//===----------------------------------------------------------------------===//
-
-/// Check structural compatibility between two loops such as iteration space
-/// and dominance.
-bool checkFusionStructuralLegality(LoopLikeOpInterface target,
-                                   LoopLikeOpInterface source,
-                                   Diagnostic &diag);
-
 /// Given two scf.forall loops, `target` and `source`, fuses `target` into
 /// `source`. Assumes that the given loops are siblings and are independent of
 /// each other.
@@ -212,16 +202,6 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
                                           RewriterBase &rewriter);
 
-/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
-/// `source`. Assumes that the given loops are siblings and are independent of
-/// each other.
-///
-/// This function does not perform any legality checks and simply fuses the
-/// loops. The caller is responsible for ensuring that the loops are legal to
-/// fuse.
-scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
-                                                    scf::ParallelOp source,
-                                                    RewriterBase &rewriter);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index d08e097a9b4af..9925fc6ce6ca9 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -90,24 +90,4 @@ struct JamBlockGatherer {
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/LoopLikeInterface.h.inc"
 
-namespace mlir {
-/// A function that rewrites `target`'s terminator as a teminator obtained by
-/// fusing `source` into `target`.
-using FuseTerminatorFn =
-    function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
-                      LoopLikeOpInterface &target, IRMapping mapping)>;
-
-/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
-/// `target`.  The `NewYieldValuesFn` callback is used to pass to the
-/// `replaceWithAdditionalYields` interface method to replace the loop with a
-/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
-/// callback is repsonsible for updating the fused loop terminator.
-LoopLikeOpInterface createFused(LoopLikeOpInterface target,
-                                LoopLikeOpInterface source,
-                                RewriterBase &rewriter,
-                                NewYieldValuesFn newYieldValuesFn,
-                                FuseTerminatorFn fuseTerminatorFn);
-
-} // namespace mlir
-
 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index cb15e0ecebf05..907d7f794593d 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -618,44 +618,6 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
 
 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 
-FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
-    RewriterBase &rewriter, ValueRange newInitOperands,
-    bool replaceInitOperandUsesInLoop,
-    const NewYieldValuesFn &newYieldValuesFn) {
-  // Create a new loop before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(getOperation());
-  SmallVector<Value> inits(getOutputs());
-  llvm::append_range(inits, newInitOperands);
-  scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
-      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
-      inits, getMapping(),
-      /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
-  // Move the loop body to the new op.
-  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
-                       newLoop.getBody()->getArguments().take_front(
-                           getBody()->getNumArguments()));
-
-  if (replaceInitOperandUsesInLoop) {
-    // Replace all uses of `newInitOperands` with the corresponding basic block
-    // arguments.
-    for (auto &&[newOperand, oldOperand] :
-         llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
-                                        newInitOperands.size()))) {
-      rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
-        Operation *user = use.getOwner();
-        return newLoop->isProperAncestor(user);
-      });
-    }
-  }
-
-  // Replace the old loop.
-  rewriter.replaceOp(getOperation(),
-                     newLoop->getResults().take_front(getNumResults()));
-  return cast<LoopLikeOpInterface>(newLoop.getOperation());
-}
-
 /// Promotes the loop body of a forallOp to its containing block if it can be
 /// determined that the loop has a single iteration.
 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 41834fea3bb84..56ff2709a589e 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -261,10 +261,8 @@ loopScheduling(scf::ForOp forOp,
     return 1;
   };
 
-  std::optional<int64_t> ubConstant =
-      getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> lbConstant =
-      getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
   DenseMap<Operation *, unsigned> opCycles;
   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
   for (Operation &op : forOp.getBody()->getOperations()) {
@@ -449,6 +447,113 @@ void transform::TakeAssumedBranchOp::getEffects(
 // LoopFuseSiblingOp
 //===----------------------------------------------------------------------===//
 
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// This is a simple check that just checks if both operations are in the same
+/// block and some checks to ensure that the fused IR does not violate
+/// dominance.
+static DiagnosedSilenceableFailure isOpSibling(Operation *target,
+                                               Operation *source) {
+  // Check if both operations are same.
+  if (target == source)
+    return emitSilenceableFailure(source)
+           << "target and source need to be different loops";
+
+  // Check if both operations are in the same block.
+  if (target->getBlock() != source->getBlock())
+    return emitSilenceableFailure(source)
+           << "target and source are not in the same block";
+
+  // Check if fusion will violate dominance.
+  DominanceInfo domInfo(source);
+  if (target->isBeforeInBlock(source)) {
+    // Since `target` is before `source`, all users of results of `target`
+    // need to be dominated by `source`.
+    for (Operation *user : target->getUsers()) {
+      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+        return emitSilenceableFailure(target)
+               << "user of results of target should be properly dominated by "
+                  "source";
+      }
+    }
+  } else {
+    // Since `target` is after `source`, all values used by `target` need
+    // to dominate `source`.
+
+    // Check if operands of `target` are dominated by `source`.
+    for (Value operand : target->getOperands()) {
+      Operation *operandOp = operand.getDefiningOp();
+      // Operands without defining operations are block arguments. When `target`
+      // and `source` occur in the same block, these operands dominate `source`.
+      if (!operandOp)
+        continue;
+
+      // Operand's defining operation should properly dominate `source`.
+      if (!domInfo.properlyDominates(operandOp, source,
+                                     /*enclosingOpOk=*/false))
+        return emitSilenceableFailure(target)
+               << "operands of target should be properly dominated by source";
+    }
+
+    // Check if values used by `target` are dominated by `source`.
+    bool failed = false;
+    OpOperand *failedValue = nullptr;
+    visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not an argument of an enclosing block and the defining
+        // op of `operand` is outside `target` but does not dominate `source`.
+        failed = true;
+        failedValue = operand;
+      }
+    });
+
+    if (failed)
+      return emitSilenceableFailure(failedValue->getOwner())
+             << "values used inside regions of target should be properly "
+                "dominated by source";
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Check if `target` scf.forall can be fused into `source` scf.forall.
+///
+/// This simply checks if both loops have the same bounds, steps and mapping.
+/// No attempt is made at checking that the side effects of `target` and
+/// `source` are independent of each other.
+static bool isForallWithIdenticalConfiguration(Operation *target,
+                                               Operation *source) {
+  auto targetOp = dyn_cast<scf::ForallOp>(target);
+  auto sourceOp = dyn_cast<scf::ForallOp>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+         targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+         targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+         targetOp.getMapping() == sourceOp.getMapping();
+}
+
+/// Check if `target` scf.for can be fused into `source` scf.for.
+///
+/// This simply checks if both loops have the same bounds and steps. No attempt
+/// is made at checking that the side effects of `target` and `source` are
+/// independent of each other.
+static bool isForWithIdenticalConfiguration(Operation *target,
+                                            Operation *source) {
+  auto targetOp = dyn_cast<scf::ForOp>(target);
+  auto sourceOp = dyn_cast<scf::ForOp>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+         targetOp.getStep() == sourceOp.getStep();
+}
+
 DiagnosedSilenceableFailure
 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
@@ -464,32 +569,25 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
            << "source handle (got " << llvm::range_size(sourceOps) << ")";
   }
 
-  auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
-  auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
-  if (!target || !source)
-    return emitSilenceableFailure(target->getLoc())
-           << "target or source is not a loop op";
+  Operation *target = *targetOps.begin();
+  Operation *source = *sourceOps.begin();
 
-  // Check if loops can be fused
-  Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
-  if (!mlir::checkFusionStructuralLegality(target, source, diag))
-    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  // Check if the target and source are siblings.
+  DiagnosedSilenceableFailure diag = isOpSibling(target, source);
+  if (!diag.succeeded())
+    return diag;
 
   Operation *fusedLoop;
-  // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
-  // and scf.parallel.
-  if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
+  /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
+  if (isForWithIdenticalConfiguration(target, source)) {
     fusedLoop = fuseIndependentSiblingForLoops(
         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
-  } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
+  } else if (isForallWithIdenticalConfiguration(target, source)) {
     fusedLoop = fuseIndependentSiblingForallLoops(
         cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
-  } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
-    fusedLoop = fuseIndependentSiblingParallelLoops(
-        cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
   } else
     return emitSilenceableFailure(target->getLoc())
-           << "unsupported loop type for fusion";
+           << "operations cannot be fused";
 
   assert(fusedLoop && "failed to fuse operations");
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index b775f988576e3..5934d85373b03 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
@@ -38,6 +37,24 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
   return walkResult.wasInterrupted();
 }
 
+/// Verify equal iteration spaces.
+static bool equalIterationSpaces(ParallelOp firstPloop,
+                                 ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  auto matchOperands = [&](const OperandRange &lhs,
+                           const OperandRange &rhs) -> bool {
+    // TODO: Extend this to support aliases and equal constants.
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+  };
+  return matchOperands(firstPloop.getLowerBound(),
+                       secondPloop.getLowerBound()) &&
+         matchOperands(firstPloop.getUpperBound(),
+                       secondPloop.getUpperBound()) &&
+         matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
 /// Checks if the parallel loops have mixed access to the same buffers. Returns
 /// `true` if the first parallel loop writes to the same indices that the second
 /// loop reads.
@@ -136,10 +153,9 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                           const IRMapping &firstToSecondPloopIndices,
                           llvm::function_ref<bool(Value, Value)> mayAlias) {
-  Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
   return !hasNestedParallelOp(firstPloop) &&
          !hasNestedParallelOp(secondPloop) &&
-         checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
+         equalIterationSpaces(firstPloop, secondPloop) &&
          succeeded(verifyDependencies(firstPloop, secondPloop,
                                       firstToSecondPloopIndices, mayAlias));
 }
@@ -158,9 +174,61 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
                      mayAlias))
     return;
 
-  IRRewriter rewriter(builder);
-  secondPloop = mlir::fuseIndependentSiblingParallelLoops(
-      firstPloop, secondPloop, rewriter);
+  DominanceInfo dom;
+  // We are fusing first loop into second, make sure there are no users of the
+  // first loop results between loops.
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop);
+  auto newSecondPloop = b.create<ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+  Block *newBlock = newSecondPloop.getBody();
+  auto term1 = cast<ReduceOp>(block1->getTerminator());
+  auto term2 = cast<ReduceOp>(block2->getTerminator());
+
+  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+
+  ValueRange results = newSecondPloop.getResults();
+  if (!results.empty()) {
+    b.setInsertionPointToEnd(newBlock);
+
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+                          newRedBlock.getArguments());
+    }
+
+    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+  }
+  term1->erase();
+  term2->erase();
+  firstPloop.erase();
+  secondPloop.erase();
+  secondPloop = newSecondPloop;
 }
 
 void mlir::scf::naivelyFuseParallelOps(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index abfc9a1b4d444..c0ee9d2afe91c 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1263,131 +1262,54 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
-//===----------------------------------------------------------------------===//
-// Fusion related helpers
-//===----------------------------------------------------------------------===//
-
-/// Check if `target` and `source` are siblings, in the context that `target`
-/// is being fused into `source`.
-///
-/// This is a simple check that just checks if both operatio...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/97523


More information about the Mlir-commits mailing list