[Mlir-commits] [mlir] edbc0e3 - [mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 (#97607)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 3 12:03:58 PDT 2024


Author: srcarroll
Date: 2024-07-03T14:03:54-05:00
New Revision: edbc0e30a9e587cee1189be023b9385adc2f239a

URL: https://github.com/llvm/llvm-project/commit/edbc0e30a9e587cee1189be023b9385adc2f239a
DIFF: https://github.com/llvm/llvm-project/commit/edbc0e30a9e587cee1189be023b9385adc2f239a.diff

LOG: [mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 (#97607)

The refactor had a bug where the fused loop was inserted in an incorrect
location. This patch fixes the bug and relands the original PR
https://github.com/llvm/llvm-project/pull/94391.

This patch refactors code related to the LoopFuseSiblingOp transform in an
attempt to reduce duplicated code. The aim is to refactor as much as
possible into functions on LoopLikeOpInterface, but this is still a
work in progress. A full refactor will require more additions to the
LoopLikeOpInterface.

In addition, scf.parallel fusion support has been added.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/include/mlir/Interfaces/LoopLikeInterface.h
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/lib/Interfaces/LoopLikeInterface.cpp
    mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index f35ea962bea16..bf95fbe6721cf 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -303,7 +303,8 @@ def ForallOp : SCF_Op<"forall", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
           ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
            "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
-           "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
+           "replaceWithAdditionalYields", "promoteIfSingleIteration",
+           "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index de807c3e4e1f8..6a40304e2eeba 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -181,6 +181,16 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+/// Check structural compatibility between two loops such as iteration space
+/// and dominance.
+bool checkFusionStructuralLegality(LoopLikeOpInterface target,
+                                   LoopLikeOpInterface source,
+                                   Diagnostic &diag);
+
 /// Given two scf.forall loops, `target` and `source`, fuses `target` into
 /// `source`. Assumes that the given loops are siblings and are independent of
 /// each other.
@@ -202,6 +212,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
                                           RewriterBase &rewriter);
 
+/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
+                                                    scf::ParallelOp source,
+                                                    RewriterBase &rewriter);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 9925fc6ce6ca9..d08e097a9b4af 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -90,4 +90,24 @@ struct JamBlockGatherer {
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/LoopLikeInterface.h.inc"
 
+namespace mlir {
+/// A function that rewrites `target`'s terminator as a terminator obtained by
+/// fusing `source` into `target`.
+using FuseTerminatorFn =
+    function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
+                      LoopLikeOpInterface &target, IRMapping mapping)>;
+
+/// Returns a fused `LoopLikeOpInterface` created by fusing `source` into
+/// `target`. The `NewYieldValuesFn` callback is passed to the
+/// `replaceWithAdditionalYields` interface method to replace the loop with a
+/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
+/// callback is responsible for updating the fused loop terminator.
+LoopLikeOpInterface createFused(LoopLikeOpInterface target,
+                                LoopLikeOpInterface source,
+                                RewriterBase &rewriter,
+                                NewYieldValuesFn newYieldValuesFn,
+                                FuseTerminatorFn fuseTerminatorFn);
+
+} // namespace mlir
+
 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 907d7f794593d..cb15e0ecebf05 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
 
 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 
+FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
+    RewriterBase &rewriter, ValueRange newInitOperands,
+    bool replaceInitOperandUsesInLoop,
+    const NewYieldValuesFn &newYieldValuesFn) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  SmallVector<Value> inits(getOutputs());
+  llvm::append_range(inits, newInitOperands);
+  scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
+      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
+      inits, getMapping(),
+      /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           getBody()->getNumArguments()));
+
+  if (replaceInitOperandUsesInLoop) {
+    // Replace all uses of `newInitOperands` with the corresponding basic block
+    // arguments.
+    for (auto &&[newOperand, oldOperand] :
+         llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
+                                        newInitOperands.size()))) {
+      rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
+        Operation *user = use.getOwner();
+        return newLoop->isProperAncestor(user);
+      });
+    }
+  }
+
+  // Replace the old loop.
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 /// Promotes the loop body of a forallOp to its containing block if it can be
 /// determined that the loop has a single iteration.
 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 56ff2709a589e..41834fea3bb84 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
     return 1;
   };
 
-  std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubConstant =
+      getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> lbConstant =
+      getConstantIntValue(forOp.getLowerBound());
   DenseMap<Operation *, unsigned> opCycles;
   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
   for (Operation &op : forOp.getBody()->getOperations()) {
@@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects(
 // LoopFuseSiblingOp
 //===----------------------------------------------------------------------===//
 
-/// Check if `target` and `source` are siblings, in the context that `target`
-/// is being fused into `source`.
-///
-/// This is a simple check that just checks if both operations are in the same
-/// block and some checks to ensure that the fused IR does not violate
-/// dominance.
-static DiagnosedSilenceableFailure isOpSibling(Operation *target,
-                                               Operation *source) {
-  // Check if both operations are same.
-  if (target == source)
-    return emitSilenceableFailure(source)
-           << "target and source need to be different loops";
-
-  // Check if both operations are in the same block.
-  if (target->getBlock() != source->getBlock())
-    return emitSilenceableFailure(source)
-           << "target and source are not in the same block";
-
-  // Check if fusion will violate dominance.
-  DominanceInfo domInfo(source);
-  if (target->isBeforeInBlock(source)) {
-    // Since `target` is before `source`, all users of results of `target`
-    // need to be dominated by `source`.
-    for (Operation *user : target->getUsers()) {
-      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
-        return emitSilenceableFailure(target)
-               << "user of results of target should be properly dominated by "
-                  "source";
-      }
-    }
-  } else {
-    // Since `target` is after `source`, all values used by `target` need
-    // to dominate `source`.
-
-    // Check if operands of `target` are dominated by `source`.
-    for (Value operand : target->getOperands()) {
-      Operation *operandOp = operand.getDefiningOp();
-      // Operands without defining operations are block arguments. When `target`
-      // and `source` occur in the same block, these operands dominate `source`.
-      if (!operandOp)
-        continue;
-
-      // Operand's defining operation should properly dominate `source`.
-      if (!domInfo.properlyDominates(operandOp, source,
-                                     /*enclosingOpOk=*/false))
-        return emitSilenceableFailure(target)
-               << "operands of target should be properly dominated by source";
-    }
-
-    // Check if values used by `target` are dominated by `source`.
-    bool failed = false;
-    OpOperand *failedValue = nullptr;
-    visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
-      Operation *operandOp = operand->get().getDefiningOp();
-      if (operandOp && !domInfo.properlyDominates(operandOp, source,
-                                                  /*enclosingOpOk=*/false)) {
-        // `operand` is not an argument of an enclosing block and the defining
-        // op of `operand` is outside `target` but does not dominate `source`.
-        failed = true;
-        failedValue = operand;
-      }
-    });
-
-    if (failed)
-      return emitSilenceableFailure(failedValue->getOwner())
-             << "values used inside regions of target should be properly "
-                "dominated by source";
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
-/// Check if `target` scf.forall can be fused into `source` scf.forall.
-///
-/// This simply checks if both loops have the same bounds, steps and mapping.
-/// No attempt is made at checking that the side effects of `target` and
-/// `source` are independent of each other.
-static bool isForallWithIdenticalConfiguration(Operation *target,
-                                               Operation *source) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
-         targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
-         targetOp.getMixedStep() == sourceOp.getMixedStep() &&
-         targetOp.getMapping() == sourceOp.getMapping();
-}
-
-/// Check if `target` scf.for can be fused into `source` scf.for.
-///
-/// This simply checks if both loops have the same bounds and steps. No attempt
-/// is made at checking that the side effects of `target` and `source` are
-/// independent of each other.
-static bool isForWithIdenticalConfiguration(Operation *target,
-                                            Operation *source) {
-  auto targetOp = dyn_cast<scf::ForOp>(target);
-  auto sourceOp = dyn_cast<scf::ForOp>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
-         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
-         targetOp.getStep() == sourceOp.getStep();
-}
-
 DiagnosedSilenceableFailure
 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
@@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
            << "source handle (got " << llvm::range_size(sourceOps) << ")";
   }
 
-  Operation *target = *targetOps.begin();
-  Operation *source = *sourceOps.begin();
+  auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
+  auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
+  if (!target || !source)
+    return emitSilenceableFailure(target->getLoc())
+           << "target or source is not a loop op";
 
-  // Check if the target and source are siblings.
-  DiagnosedSilenceableFailure diag = isOpSibling(target, source);
-  if (!diag.succeeded())
-    return diag;
+  // Check if loops can be fused
+  Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
+  if (!mlir::checkFusionStructuralLegality(target, source, diag))
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
 
   Operation *fusedLoop;
-  /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
-  if (isForWithIdenticalConfiguration(target, source)) {
+  // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
+  // and scf.parallel.
+  if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
     fusedLoop = fuseIndependentSiblingForLoops(
         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
-  } else if (isForallWithIdenticalConfiguration(target, source)) {
+  } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
     fusedLoop = fuseIndependentSiblingForallLoops(
         cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+  } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
+    fusedLoop = fuseIndependentSiblingParallelLoops(
+        cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
   } else
     return emitSilenceableFailure(target->getLoc())
-           << "operations cannot be fused";
+           << "unsupported loop type for fusion";
 
   assert(fusedLoop && "failed to fuse operations");
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 5934d85373b03..b775f988576e3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
@@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
   return walkResult.wasInterrupted();
 }
 
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(ParallelOp firstPloop,
-                                 ParallelOp secondPloop) {
-  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
-    return false;
-
-  auto matchOperands = [&](const OperandRange &lhs,
-                           const OperandRange &rhs) -> bool {
-    // TODO: Extend this to support aliases and equal constants.
-    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
-  };
-  return matchOperands(firstPloop.getLowerBound(),
-                       secondPloop.getLowerBound()) &&
-         matchOperands(firstPloop.getUpperBound(),
-                       secondPloop.getUpperBound()) &&
-         matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
 /// Checks if the parallel loops have mixed access to the same buffers. Returns
 /// `true` if the first parallel loop writes to the same indices that the second
 /// loop reads.
@@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                           const IRMapping &firstToSecondPloopIndices,
                           llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
   return !hasNestedParallelOp(firstPloop) &&
          !hasNestedParallelOp(secondPloop) &&
-         equalIterationSpaces(firstPloop, secondPloop) &&
+         checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
          succeeded(verifyDependencies(firstPloop, secondPloop,
                                       firstToSecondPloopIndices, mayAlias));
 }
@@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
                      mayAlias))
     return;
 
-  DominanceInfo dom;
-  // We are fusing first loop into second, make sure there are no users of the
-  // first loop results between loops.
-  for (Operation *user : firstPloop->getUsers())
-    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
-      return;
-
-  ValueRange inits1 = firstPloop.getInitVals();
-  ValueRange inits2 = secondPloop.getInitVals();
-
-  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
-  newInitVars.append(inits2.begin(), inits2.end());
-
-  IRRewriter b(builder);
-  b.setInsertionPoint(secondPloop);
-  auto newSecondPloop = b.create<ParallelOp>(
-      secondPloop.getLoc(), secondPloop.getLowerBound(),
-      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
-  Block *newBlock = newSecondPloop.getBody();
-  auto term1 = cast<ReduceOp>(block1->getTerminator());
-  auto term2 = cast<ReduceOp>(block2->getTerminator());
-
-  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-
-  ValueRange results = newSecondPloop.getResults();
-  if (!results.empty()) {
-    b.setInsertionPointToEnd(newBlock);
-
-    ValueRange reduceArgs1 = term1.getOperands();
-    ValueRange reduceArgs2 = term2.getOperands();
-    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
-    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
-    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
-    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
-             term1.getReductions(), term2.getReductions()))) {
-      Block &oldRedBlock = reg.front();
-      Block &newRedBlock = newReduceOp.getReductions()[i].front();
-      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
-                          newRedBlock.getArguments());
-    }
-
-    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
-    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
-  }
-  term1->erase();
-  term2->erase();
-  firstPloop.erase();
-  secondPloop.erase();
-  secondPloop = newSecondPloop;
+  IRRewriter rewriter(builder);
+  secondPloop = mlir::fuseIndependentSiblingParallelLoops(
+      firstPloop, secondPloop, rewriter);
 }
 
 void mlir::scf::naivelyFuseParallelOps(

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c0ee9d2afe91c..abfc9a1b4d444 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1262,54 +1263,131 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// This is a simple check that just checks if both operations are in the same
+/// block and some checks to ensure that the fused IR does not violate
+/// dominance.
+static bool isOpSibling(Operation *target, Operation *source,
+                        Diagnostic &diag) {
+  // Check if both operations are same.
+  if (target == source) {
+    diag << "target and source need to be different loops";
+    return false;
+  }
+
+  // Check if both operations are in the same block.
+  if (target->getBlock() != source->getBlock()) {
+    diag << "target and source are not in the same block";
+    return false;
+  }
+
+  // Check if fusion will violate dominance.
+  DominanceInfo domInfo(source);
+  if (target->isBeforeInBlock(source)) {
+    // Since `target` is before `source`, all users of results of `target`
+    // need to be dominated by `source`.
+    for (Operation *user : target->getUsers()) {
+      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+        diag << "user of results of target should "
+                "be properly dominated by source";
+        return false;
+      }
+    }
+  } else {
+    // Since `target` is after `source`, all values used by `target` need
+    // to dominate `source`.
+
+    // Check if operands of `target` are dominated by `source`.
+    for (Value operand : target->getOperands()) {
+      Operation *operandOp = operand.getDefiningOp();
+      // Operands without defining operations are block arguments. When `target`
+      // and `source` occur in the same block, these operands dominate `source`.
+      if (!operandOp)
+        continue;
+
+      // Operand's defining operation should properly dominate `source`.
+      if (!domInfo.properlyDominates(operandOp, source,
+                                     /*enclosingOpOk=*/false)) {
+        diag << "operands of target should be properly dominated by source";
+        return false;
+      }
+    }
+
+    // Check if values used by `target` are dominated by `source`.
+    bool failed = false;
+    OpOperand *failedValue = nullptr;
+    visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not an argument of an enclosing block and the defining
+        // op of `operand` is outside `target` but does not dominate `source`.
+        failed = true;
+        failedValue = operand;
+      }
+    });
+
+    if (failed) {
+      diag << "values used inside regions of target should be properly "
+              "dominated by source";
+      diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
+                                         LoopLikeOpInterface source,
+                                         Diagnostic &diag) {
+  if (target->getName() != source->getName()) {
+    diag << "target and source must be same loop type";
+    return false;
+  }
+
+  bool iterSpaceEq =
+      target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+      target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+      target.getLoopSteps() == source.getLoopSteps();
+  // TODO: Decouple checks on concrete loop types and move this function
+  // somewhere for general utility for `LoopLikeOpInterface`
+  if (auto forAllTarget = dyn_cast<scf::ForallOp>(*target))
+    iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() ==
+                                     cast<scf::ForallOp>(*source).getMapping();
+  if (!iterSpaceEq) {
+    diag << "target and source iteration spaces must be equal";
+    return false;
+  }
+  return isOpSibling(target, source, diag);
+}
+
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused shared_outs.
-  SmallVector<Value> fusedOuts;
-  llvm::append_range(fusedOuts, target.getOutputs());
-  llvm::append_range(fusedOuts, source.getOutputs());
-
-  // Create a new scf.forall op after the source loop.
-  rewriter.setInsertionPointAfter(source);
-  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
-      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
-      source.getMixedStep(), fusedOuts, source.getMapping());
-
-  // Map control operands.
-  IRMapping mapping;
-  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
-
-  // Map shared outs.
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
-  // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-
-  // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp targetTerm = target.getTerminator();
-  scf::InParallelOp sourceTerm = source.getTerminator();
-  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-  rewriter.setInsertionPointToStart(fusedTerm.getBody());
-  for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
-  for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
-
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
+        return ValueRange{};
+      },
+      [&](RewriterBase &b, LoopLikeOpInterface source,
+          LoopLikeOpInterface &target, IRMapping mapping) {
+        auto sourceForall = cast<scf::ForallOp>(source);
+        auto targetForall = cast<scf::ForallOp>(target);
+        scf::InParallelOp fusedTerm = targetForall.getTerminator();
+        b.setInsertionPointToEnd(fusedTerm.getBody());
+        for (Operation &op : sourceForall.getTerminator().getYieldingOps())
+          b.clone(op, mapping);
+      }));
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
 
   return fusedLoop;
 }
@@ -1317,49 +1395,74 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                 scf::ForOp source,
                                                 RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused init_args, with target's init_args before source's init_args.
-  SmallVector<Value> fusedInitArgs;
-  llvm::append_range(fusedInitArgs, target.getInitArgs());
-  llvm::append_range(fusedInitArgs, source.getInitArgs());
-
-  // Create a new scf.for op after the source loop (with scf.yield terminator
-  // (without arguments) only in case its init_args is empty).
-  rewriter.setInsertionPointAfter(source);
-  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
-      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
-      source.getStep(), fusedInitArgs);
-
-  // Map original induction variables and operands to those of the fused loop.
-  IRMapping mapping;
-  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
-  // Merge target's body into the new (fused) for loop and then source's body.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-
-  // Build fused yield results by appropriately mapping original yield operands.
-  SmallVector<Value> yieldResults;
-  for (Value operand : target.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  for (Value operand : source.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  if (!yieldResults.empty())
-    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
-
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  // The fused loop additionally yields whatever `source` yielded.
+  auto sourceYields = [&](OpBuilder &, Location, ArrayRef<BlockArgument>) {
+    return source.getYieldedValues();
+  };
+  auto remapYield = [&](RewriterBase &b, LoopLikeOpInterface,
+                        LoopLikeOpInterface &fused, IRMapping mapping) {
+    Operation *oldYield = cast<scf::ForOp>(fused).getBody()->getTerminator();
+    b.replaceOp(oldYield, b.clone(*oldYield, mapping)); // use cloned values
+  };
+  auto fusedLoop = cast<scf::ForOp>(
+      createFused(target, source, rewriter, sourceYields, remapYield));
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
+  return fusedLoop;
+}
+
+// TODO: Finish refactoring this a la the above, but likely requires additional
+// interface methods.
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+    scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBody = target.getBody();
+  Block *sourceBody = source.getBody();
+  auto targetReduce = cast<scf::ReduceOp>(targetBody->getTerminator());
+  auto sourceReduce = cast<scf::ReduceOp>(sourceBody->getTerminator());
+
+  // Reduction inits: those of `target` first, followed by those of `source`.
+  ValueRange targetInits = target.getInitVals();
+  ValueRange sourceInits = source.getInitVals();
+  SmallVector<Value> fusedInits;
+  llvm::append_range(fusedInits, targetInits);
+  llvm::append_range(fusedInits, sourceInits);
+
+  // The fused loop reuses `source`'s bounds and is placed just before it.
+  rewriter.setInsertionPoint(source);
+  auto fusedLoop = rewriter.create<scf::ParallelOp>(
+      rewriter.getFusedLoc(target.getLoc(), source.getLoc()),
+      source.getLowerBound(), source.getUpperBound(), source.getStep(),
+      fusedInits);
+  // Splice both bodies in at the front; `source` goes first so that `target`'s
+  // operations end up ahead of it.
+  Block *fusedBody = fusedLoop.getBody();
+  for (Block *body : {sourceBody, targetBody})
+    rewriter.inlineBlockBefore(body, fusedBody, fusedBody->begin(),
+                               fusedBody->getArguments());
+
+  ValueRange results = fusedLoop.getResults();
+  if (!results.empty()) {
+    // Merge both `scf.reduce` terminators into a single one at the end.
+    rewriter.setInsertionPointToEnd(fusedBody);
+    SmallVector<Value> reduceOperands;
+    llvm::append_range(reduceOperands, targetReduce.getOperands());
+    llvm::append_range(reduceOperands, sourceReduce.getOperands());
+    auto fusedReduce = rewriter.create<scf::ReduceOp>(
+        rewriter.getFusedLoc(targetReduce.getLoc(), sourceReduce.getLoc()),
+        reduceOperands);
+    for (auto &&[idx, oldRegion] : llvm::enumerate(llvm::concat<Region>(
+             targetReduce.getReductions(), sourceReduce.getReductions()))) {
+      Block &dstBlock = fusedReduce.getReductions()[idx].front();
+      rewriter.inlineBlockBefore(&oldRegion.front(), &dstBlock,
+                                 dstBlock.begin(), dstBlock.getArguments());
+    }
+  }
+  // Replace the original loops, then drop their spliced-in `scf.reduce` ops.
+  rewriter.replaceOp(target, results.take_front(targetInits.size()));
+  rewriter.replaceOp(source, results.take_back(sourceInits.size()));
+  rewriter.eraseOp(targetReduce);
+  rewriter.eraseOp(sourceReduce);
 
   return fusedLoop;
 }

diff  --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 1e0e87b64e811..5a119a7cf2659 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Interfaces/LoopLikeInterface.h"
 
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/DenseSet.h"
 
@@ -113,3 +115,60 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
 
   return success();
 }
+
+/// Fuses `source` into `target`: `target` is rebuilt with `source`'s inits
+/// appended, moved before `source`, and `source`'s body is cloned into it.
+LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
+                                      LoopLikeOpInterface source,
+                                      RewriterBase &rewriter,
+                                      NewYieldValuesFn newYieldValuesFn,
+                                      FuseTerminatorFn fuseTerminatorFn) {
+  // Copy the iter args: `target` is replaced by `replaceWithAdditionalYields`
+  // below, so a view into its block arguments would dangle.
+  SmallVector<Value> targetIterArgs(target.getRegionIterArgs());
+  std::optional<SmallVector<Value>> targetInductionVar =
+      target.getLoopInductionVars();
+  SmallVector<Value> sourceIterArgs(source.getRegionIterArgs());
+  std::optional<SmallVector<Value>> sourceInductionVar =
+      source.getLoopInductionVars();
+  auto sourceRegion = source.getLoopRegions().front();
+
+  FailureOr<LoopLikeOpInterface> maybeFusedLoop =
+      target.replaceWithAdditionalYields(rewriter, source.getInits(),
+                                         /*replaceInitOperandUsesInLoop=*/false,
+                                         newYieldValuesFn);
+  if (failed(maybeFusedLoop))
+    llvm_unreachable("failed to replace loop");
+  LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
+  // Since the target op is rewritten at the original's location, we move it to
+  // the source op's location.
+  rewriter.moveOpBefore(fusedLoop, source);
+
+  // Map control operands. Yielded values are deliberately not mapped: those
+  // defined in `source`'s body are remapped when cloned below, and mapping
+  // them here could overwrite the block-argument mappings (e.g. yield %arg).
+  IRMapping mapping;
+  std::optional<SmallVector<Value>> fusedInductionVar =
+      fusedLoop.getLoopInductionVars();
+  if (fusedInductionVar) {
+    if (!targetInductionVar || !sourceInductionVar)
+      llvm_unreachable(
+          "expected target and source loops to have induction vars");
+    mapping.map(*targetInductionVar, *fusedInductionVar);
+    mapping.map(*sourceInductionVar, *fusedInductionVar);
+  }
+  mapping.map(targetIterArgs,
+              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+  mapping.map(sourceIterArgs,
+              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+  // Clone `source`'s body (minus its terminator) before the fused terminator.
+  rewriter.setInsertionPoint(
+      fusedLoop.getLoopRegions().front()->front().getTerminator());
+  for (Operation &op : sourceRegion->front().without_terminator())
+    rewriter.clone(op, mapping);
+
+  // TODO: Replace with corresponding interface method if added
+  fuseTerminatorFn(rewriter, source, fusedLoop, mapping);
+
+  return fusedLoop;
+}

diff  --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 54dd2bdf953ca..f8246b74a5744 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -47,6 +47,169 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @fuse_two_parallel
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
+  %two = arith.constant 2 : index
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %one_f32 = arith.constant 1.0 : f32
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+  %tmp = memref.alloc() : memref<2x2xf32>
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
+    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
+    %incremented = arith.addf %b_val, %one_f32 : f32
+    memref.store %incremented, %tmp[%x, %y] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
+    %t_val = memref.load %tmp[%x, %y] : memref<2x2xf32>
+    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
+    %prod = arith.mulf %t_val, %a_val : f32
+    memref.store %prod, %B[%x, %y] : memref<2x2xf32>
+    scf.reduce
+  }
+// CHECK:      memref.dealloc [[SUM]]
+  memref.dealloc %tmp : memref<2x2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["scf.parallel"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loops:2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %loops#0 into %loops#1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @fuse_two_parallel_reverse
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
+  %two = arith.constant 2 : index
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %one_f32 = arith.constant 1.0 : f32
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+  %tmp = memref.alloc() : memref<2x2xf32>
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
+    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
+    %incremented = arith.addf %b_val, %one_f32 : f32
+    memref.store %incremented, %tmp[%x, %y] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
+    %t_val = memref.load %tmp[%x, %y] : memref<2x2xf32>
+    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
+    %prod = arith.mulf %t_val, %a_val : f32
+    memref.store %prod, %B[%x, %y] : memref<2x2xf32>
+    scf.reduce
+  }
+// CHECK:      memref.dealloc [[SUM]]
+  memref.dealloc %tmp : memref<2x2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["scf.parallel"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loops:2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %loops#1 into %loops#0 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @fuse_reductions_two
+//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
+func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:   %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+//  CHECK-SAME:   to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+//  CHECK-SAME:   init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
+//       CHECK:   %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+//       CHECK:   %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   return %[[RES]]#0, %[[RES]]#1 : f32, f32
+  %two = arith.constant 2 : index
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %sum_init = arith.constant 1.0 : f32
+  %prod_init = arith.constant 2.0 : f32
+  %sum = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%sum_init) -> f32 {
+    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
+    scf.reduce(%a_val : f32) {
+    ^bb0(%acc: f32, %elem: f32):
+      %next = arith.addf %acc, %elem : f32
+      scf.reduce.return %next : f32
+    }
+  }
+  %prod = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%prod_init) -> f32 {
+    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
+    scf.reduce(%b_val : f32) {
+    ^bb0(%acc: f32, %elem: f32):
+      %next = arith.mulf %acc, %elem : f32
+      scf.reduce.return %next : f32
+    }
+  }
+  return %sum, %prod : f32, f32
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["scf.parallel"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loops:2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %loops#0 into %loops#1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
 func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
   // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
@@ -208,6 +371,62 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32)
+#tile_offset = affine_map<(d0) -> (d0 * 32)>
+#identity_2d = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  // CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}
+  func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) {
+  // CHECK:      %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16>
+  // CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) {
+  // CHECK-NEXT:  %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]])
+  // CHECK-NEXT:  %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
+  // CHECK-NEXT:  %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
+  // CHECK-NEXT:  %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
+  // CHECK-NEXT:  %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}})
+  // CHECK:       scf.forall.in_parallel {
+  // CHECK-NEXT:    tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32>
+  // CHECK-NEXT:    tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
+  // CHECK-NEXT:  }
+  // CHECK-NEXT: } {mapping = [#gpu.warp<linear_dim_0>]}
+  // CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1
+    %first = scf.forall (%iv) in (4) shared_outs(%dest = %arg0) -> (tensor<128xf32>) {
+      %offset = affine.apply #tile_offset(%iv)
+      %column = tensor.extract_slice %arg3[%offset, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %column into %dest[%offset] [32] [1] : tensor<32xf32> into tensor<128xf32>
+      }
+    } {mapping = [#gpu.warp<linear_dim_0>]}
+    %empty = tensor.empty() : tensor<128x128xf16>
+    %second = scf.forall (%iv) in (4) shared_outs(%dest = %arg1) -> (tensor<128x128xf16>) {
+      %offset = affine.apply #tile_offset(%iv)
+      %rows = tensor.extract_slice %arg3[%offset, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
+      %init_rows = tensor.extract_slice %empty[%offset, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
+      %converted = linalg.generic {indexing_maps = [#identity_2d, #identity_2d], iterator_types = ["parallel", "parallel"]} ins(%rows : tensor<32x128xf32>) outs(%init_rows : tensor<32x128xf16>) {
+      ^bb0(%in_elem: f32, %out_elem: f16):
+        %narrowed = arith.truncf %in_elem : f32 to f16
+        linalg.yield %narrowed : f16
+      } -> tensor<32x128xf16>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %converted into %dest[%offset, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
+      }
+    } {mapping = [#gpu.warp<linear_dim_0>]}
+    return %first, %second : tensor<128xf32>, tensor<128x128xf16>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%payload: !transform.any_op) {
+    %all_loops = transform.structured.match ops{["scf.forall"]} in %payload : (!transform.any_op) -> !transform.any_op
+    %first_loop, %second_loop = transform.split_handle %all_loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %first_loop into %second_loop : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
 // -----
 
 func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
@@ -282,8 +501,9 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>,
     %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
     scf.yield %6 : tensor<128xf32>
   }
-  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
   // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+    // expected-note @below {{see operation}}
     %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
     %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
     %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
@@ -328,6 +548,74 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %two = arith.constant 2 : index
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %one_f32 = arith.constant 1.0 : f32
+  %tmp = memref.alloc() : memref<2x2xf32>
+  // expected-error @below {{target and source iteration spaces must be equal}}
+  scf.parallel (%x) = (%zero) to (%two) step (%one) {
+    %b_val = memref.load %B[%x, %zero] : memref<2x2xf32>
+    %incremented = arith.addf %b_val, %one_f32 : f32
+    memref.store %incremented, %tmp[%x, %zero] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
+    %t_val = memref.load %tmp[%x, %y] : memref<2x2xf32>
+    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
+    %prod = arith.mulf %t_val, %a_val : f32
+    memref.store %prod, %B[%x, %y] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %tmp : memref<2x2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["scf.parallel"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loops:2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %loops#0 into %loops#1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) {
+  %two = arith.constant 2 : index
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %one_f32 = arith.constant 1.0 : f32
+  %tmp = memref.alloc() : memref<2xf32>
+  // expected-error @below {{target and source must be same loop type}}
+  scf.for %x = %zero to %two step %one {
+    %b_val = memref.load %B[%x] : memref<2xf32>
+    %incremented = arith.addf %b_val, %one_f32 : f32
+    memref.store %incremented, %tmp[%x] : memref<2xf32>
+  }
+  scf.parallel (%x) = (%zero) to (%two) step (%one) {
+    %t_val = memref.load %tmp[%x] : memref<2xf32>
+    %a_val = memref.load %A[%x] : memref<2xf32>
+    %prod = arith.mulf %t_val, %a_val : f32
+    memref.store %prod, %B[%x] : memref<2xf32>
+    scf.reduce
+  }
+  memref.dealloc %tmp : memref<2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %for_loop = transform.structured.match ops{["scf.for"]} in %root : (!transform.any_op) -> !transform.any_op
+    %par_loop = transform.structured.match ops{["scf.parallel"]} in %root : (!transform.any_op) -> !transform.any_op
+    %fused = transform.loop.fuse_sibling %for_loop into %par_loop : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
 // -----
 
 // CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}


        


More information about the Mlir-commits mailing list