[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 5 12:22:16 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/94391

>From 5020e498440b0016adef7e99806aa55c4837b441 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 13:08:22 -0500
Subject: [PATCH 1/4] Add getters for multi dim loop variables in
 LoopLikeOpInterface

---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  4 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 37 ++--------
 .../mlir/Interfaces/LoopLikeInterface.td      | 65 +++++++++++------
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 20 +++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 70 +++++++------------
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        |  8 +++
 6 files changed, 97 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 3640055ea8da8..bb2c29b5733b8 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
     [AttrSizedOperandSegments, AutomaticAllocationScope,
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-     ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValuesMutable",
+     ["getInductionVars", "getMixedLowerBound", "getMixedStep",
+      "getMixedUpperBound", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b063aa772bab..3b28ca8b21d0f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
-        "getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-        "getSingleUpperBound", "getYieldedValuesMutable",
+        "getInductionVars", "getMixedLowerBound", "getMixedStep",
+        "getMixedUpperBound", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
-           "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
+          ["getInitsMutable", "getRegionIterArgs", "getInductionVars", 
+           "getMixedLowerBound", "getMixedUpperBound", "getMixedStep",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,24 +510,6 @@ def ForallOp : SCF_Op<"forall", [
   ];
 
   let extraClassDeclaration = [{
-    // Get lower bounds as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedLowerBound() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
-    }
-
-    // Get upper bounds as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedUpperBound() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
-    }
-
-    // Get steps as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedStep() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticStep(), getDynamicStep(), b);
-    }
-
     /// Get lower bounds as values.
     SmallVector<Value> getLowerBound(OpBuilder &b) {
       return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
@@ -584,10 +566,6 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
-    ::mlir::ValueRange getInductionVars() {
-      return getBody()->getArguments().take_front(getRank());
-    }
-
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
@@ -765,8 +743,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
 def ParallelOp : SCF_Op<"parallel",
     [AutomaticAllocationScope,
      AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar",
-          "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
+          "getMixedLowerBound", "getMixedUpperBound", "getMixedStep"]>,
      RecursiveMemoryEffects,
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -846,9 +824,6 @@ def ParallelOp : SCF_Op<"parallel",
   ];
 
   let extraClassDeclaration = [{
-    ValueRange getInductionVars() {
-      return getBody()->getArguments();
-    }
     unsigned getNumLoops() { return getStep().size(); }
     unsigned getNumReductions() { return getInitVals().size(); }
   }];
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index f0dc6e60eba58..813779c852027 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -93,51 +93,47 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        If there is a single induction variable return it, otherwise return
-        std::nullopt.
+        Return all induction variables.
       }],
-      /*retTy=*/"::std::optional<::mlir::Value>",
-      /*methodName=*/"getSingleInductionVar",
+      /*retTy=*/"::mlir::ValueRange",
+      /*methodName=*/"getInductionVars",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single lower bound value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all lower bounds.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleLowerBound",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedLowerBound",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single step value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all steps.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleStep",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedStep",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single upper bound value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all upper bounds.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleUpperBound",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedUpperBound",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
@@ -235,6 +231,35 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
   }];
 
   let extraSharedClassDeclaration = [{
+    /// If there is a single induction variable return it, otherwise return
+    /// std::nullopt.
+    ::std::optional<::mlir::Value> getSingleInductionVar() {
+      if (this->getInductionVars().size() == 1)
+          return this->getInductionVars()[0];
+        return std::nullopt;
+    }
+    /// Return the single lower bound value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
+      if (this->getMixedLowerBound().size() == 1)
+          return this->getMixedLowerBound()[0];
+        return std::nullopt;
+    }
+    /// Return the single step value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleStep() {
+      if (this->getMixedStep().size() == 1)
+          return this->getMixedStep()[0];
+        return std::nullopt;
+    }
+    /// Return the single upper bound value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
+      if (this->getMixedUpperBound().size() == 1)
+          return this->getMixedUpperBound()[0];
+        return std::nullopt;
+    }
+
     /// Append the specified additional "init" operands: replace this loop with
     /// a new loop that has the additional init operands. The loop body of this
     /// loop is moved over to the new loop.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2e31487bd55a0..746a9c919560c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,27 +2454,25 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> AffineForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; }
 
-std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedLowerBound() {
   if (!hasConstantLowerBound())
-    return std::nullopt;
+    return {};
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
+  return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
 }
 
-std::optional<OpFoldResult> AffineForOp::getSingleStep() {
+SmallVector<OpFoldResult> AffineForOp::getMixedStep() {
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
+  return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
 }
 
-std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedUpperBound() {
   if (!hasConstantUpperBound())
-    return std::nullopt;
+    return {};
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
+  return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
 }
 
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 107fd0690f193..e275ff1849c10 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,20 +378,18 @@ LogicalResult ForOp::verifyRegions() {
   return success();
 }
 
-std::optional<Value> ForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+ValueRange ForOp::getInductionVars() { return {getInductionVar()}; }
 
-std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
-  return OpFoldResult(getLowerBound());
+SmallVector<OpFoldResult> ForOp::getMixedLowerBound() {
+  return {OpFoldResult(getLowerBound())};
 }
 
-std::optional<OpFoldResult> ForOp::getSingleStep() {
-  return OpFoldResult(getStep());
+SmallVector<OpFoldResult> ForOp::getMixedStep() {
+  return {OpFoldResult(getStep())};
 }
 
-std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
-  return OpFoldResult(getUpperBound());
+SmallVector<OpFoldResult> ForOp::getMixedUpperBound() {
+  return {OpFoldResult(getUpperBound())};
 }
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
@@ -1428,28 +1426,26 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
-std::optional<Value> ForallOp::getSingleInductionVar() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getInductionVar(0);
+ValueRange ForallOp::getInductionVars() {
+  return getBody()->getArguments().take_front(getRank());
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedLowerBound()[0];
+// Get lower bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedLowerBound() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedUpperBound()[0];
+// Get upper bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedUpperBound() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleStep() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedStep()[0];
+// Get steps as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedStep() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticStep(), getDynamicStep(), b);
 }
 
 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
@@ -3008,29 +3004,17 @@ void ParallelOp::print(OpAsmPrinter &p) {
 
 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> ParallelOp::getSingleInductionVar() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getBody()->getArgument(0);
-}
+ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
 
-std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getLowerBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedLowerBound() {
+  return getLowerBound();
 }
 
-std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getUpperBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedUpperBound() {
+  return getUpperBound();
 }
 
-std::optional<OpFoldResult> ParallelOp::getSingleStep() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getStep()[0];
-}
+SmallVector<OpFoldResult> ParallelOp::getMixedStep() { return getStep(); }
 
 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 6bc0fd6113b9b..d8cdb213070da 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -36,6 +36,10 @@ class SCFLoopLikeTest : public ::testing::Test {
     std::optional<OpFoldResult> maybeIndVar =
         loopLikeOp.getSingleInductionVar();
     EXPECT_TRUE(maybeIndVar.has_value());
+    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 1u);
   }
 
   void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -48,6 +52,10 @@ class SCFLoopLikeTest : public ::testing::Test {
     std::optional<OpFoldResult> maybeIndVar =
         loopLikeOp.getSingleInductionVar();
     EXPECT_FALSE(maybeIndVar.has_value());
+    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 2u);
   }
 
   MLIRContext context;

>From 50852d570440e0041c8b2b38925c4af05fac0636 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 4 Jun 2024 14:44:53 -0500
Subject: [PATCH 2/4] Refactor LoopFuseSiblingOp and support parallel fusion

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  16 ++
 .../SCF/TransformOps/SCFTransformOps.cpp      |  53 +++--
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 204 +----------------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 208 ++++++++++++++++++
 .../SCF/transform-loop-fuse-sibling.mlir      |  53 +++++
 5 files changed, 304 insertions(+), 230 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..2944d8ffac022 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -156,6 +156,12 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+/// Prepends operations of firstPloop's body into secondPloop's body.
+/// Updates secondPloop with new loop.
+void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+                 OpBuilder builder,
+                 llvm::function_ref<bool(Value, Value)> mayAlias);
+
 /// Given two scf.forall loops, `target` and `source`, fuses `target` into
 /// `source`. Assumes that the given loops are siblings and are independent of
 /// each other.
@@ -177,6 +183,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
                                           RewriterBase &rewriter);
 
+/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
+                                                    scf::ParallelOp source,
+                                                    RewriterBase &rewriter);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..1c53e89d69040 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -442,39 +442,32 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Check if `target` scf.forall can be fused into `source` scf.forall.
+/// Check if `target` scf loop can be fused into `source` scf loop.
+/// Applies for scf.for, scf.forall, and scf.parallel.
 ///
 /// This simply checks if both loops have the same bounds, steps and mapping.
 /// No attempt is made at checking that the side effects of `target` and
 /// `source` are independent of each other.
-static bool isForallWithIdenticalConfiguration(Operation *target,
-                                               Operation *source) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
-         targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
-         targetOp.getMixedStep() == sourceOp.getMixedStep() &&
-         targetOp.getMapping() == sourceOp.getMapping();
-}
-
-/// Check if `target` scf.for can be fused into `source` scf.for.
-///
-/// This simply checks if both loops have the same bounds and steps. No attempt
-/// is made at checking that the side effects of `target` and `source` are
-/// independent of each other.
-static bool isForWithIdenticalConfiguration(Operation *target,
-                                            Operation *source) {
-  auto targetOp = dyn_cast<scf::ForOp>(target);
-  auto sourceOp = dyn_cast<scf::ForOp>(source);
+template <typename LoopTy>
+static bool isLoopWithIdenticalConfiguration(Operation *target,
+                                             Operation *source) {
+  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
+                                scf::ParallelOp>::value,
+                "applies to only `forall`, `for` and `parallel`");
+  auto targetOp = dyn_cast<LoopTy>(target);
+  auto sourceOp = dyn_cast<LoopTy>(source);
   if (!targetOp || !sourceOp)
     return false;
 
-  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
-         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
-         targetOp.getStep() == sourceOp.getStep();
+  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
+    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+           targetOp.getMapping() == sourceOp.getMapping();
+  else
+    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+           targetOp.getStep() == sourceOp.getStep();
 }
 
 DiagnosedSilenceableFailure
@@ -502,12 +495,16 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
 
   Operation *fusedLoop;
   /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
-  if (isForWithIdenticalConfiguration(target, source)) {
+  if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
     fusedLoop = fuseIndependentSiblingForLoops(
         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
-  } else if (isForallWithIdenticalConfiguration(target, source)) {
+  } else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
     fusedLoop = fuseIndependentSiblingForallLoops(
         cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+  } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
+                                                               source)) {
+    fusedLoop = fuseIndependentSiblingParallelLoops(
+        cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
   } else
     return emitSilenceableFailure(target->getLoc())
            << "operations cannot be fused";
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 5934d85373b03..abac91cfaf7d9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
@@ -30,207 +31,6 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::scf;
 
-/// Verify there are no nested ParallelOps.
-static bool hasNestedParallelOp(ParallelOp ploop) {
-  auto walkResult =
-      ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
-  return walkResult.wasInterrupted();
-}
-
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(ParallelOp firstPloop,
-                                 ParallelOp secondPloop) {
-  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
-    return false;
-
-  auto matchOperands = [&](const OperandRange &lhs,
-                           const OperandRange &rhs) -> bool {
-    // TODO: Extend this to support aliases and equal constants.
-    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
-  };
-  return matchOperands(firstPloop.getLowerBound(),
-                       secondPloop.getLowerBound()) &&
-         matchOperands(firstPloop.getUpperBound(),
-                       secondPloop.getUpperBound()) &&
-         matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
-    ParallelOp firstPloop, ParallelOp secondPloop,
-    const IRMapping &firstToSecondPloopIndices,
-    llvm::function_ref<bool(Value, Value)> mayAlias) {
-  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
-  SmallVector<Value> bufferStoresVec;
-  firstPloop.getBody()->walk([&](memref::StoreOp store) {
-    bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
-  });
-  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
-    Value loadMem = load.getMemRef();
-    // Stop if the memref is defined in secondPloop body. Careful alias analysis
-    // is needed.
-    auto *memrefDef = loadMem.getDefiningOp();
-    if (memrefDef && memrefDef->getBlock() == load->getBlock())
-      return WalkResult::interrupt();
-
-    for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
-        return WalkResult::interrupt();
-
-    auto write = bufferStores.find(loadMem);
-    if (write == bufferStores.end())
-      return WalkResult::advance();
-
-    // Check that at last one store was retrieved
-    if (!write->second.size())
-      return WalkResult::interrupt();
-
-    auto storeIndices = write->second.front();
-
-    // Multiple writes to the same memref are allowed only on the same indices
-    for (const auto &othStoreIndices : write->second) {
-      if (othStoreIndices != storeIndices)
-        return WalkResult::interrupt();
-    }
-
-    // Check that the load indices of secondPloop coincide with store indices of
-    // firstPloop for the same memrefs.
-    auto loadIndices = load.getIndices();
-    if (storeIndices.size() != loadIndices.size())
-      return WalkResult::interrupt();
-    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
-      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
-          loadIndices[i]) {
-        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
-        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
-        if (storeIndexDefOp && loadIndexDefOp) {
-          if (!isMemoryEffectFree(storeIndexDefOp))
-            return WalkResult::interrupt();
-          if (!isMemoryEffectFree(loadIndexDefOp))
-            return WalkResult::interrupt();
-          if (!OperationEquivalence::isEquivalentTo(
-                  storeIndexDefOp, loadIndexDefOp,
-                  [&](Value storeIndex, Value loadIndex) {
-                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
-                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
-                      return failure();
-                    else
-                      return success();
-                  },
-                  /*markEquivalent=*/nullptr,
-                  OperationEquivalence::Flags::IgnoreLocations)) {
-            return WalkResult::interrupt();
-          }
-        } else
-          return WalkResult::interrupt();
-      }
-    }
-    return WalkResult::advance();
-  });
-  return !walkResult.wasInterrupted();
-}
-
-/// Analyzes dependencies in the most primitive way by checking simple read and
-/// write patterns.
-static LogicalResult
-verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
-                   const IRMapping &firstToSecondPloopIndices,
-                   llvm::function_ref<bool(Value, Value)> mayAlias) {
-  if (!haveNoReadsAfterWriteExceptSameIndex(
-          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
-    return failure();
-
-  IRMapping secondToFirstPloopIndices;
-  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
-                                firstPloop.getBody()->getArguments());
-  return success(haveNoReadsAfterWriteExceptSameIndex(
-      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
-}
-
-static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                          const IRMapping &firstToSecondPloopIndices,
-                          llvm::function_ref<bool(Value, Value)> mayAlias) {
-  return !hasNestedParallelOp(firstPloop) &&
-         !hasNestedParallelOp(secondPloop) &&
-         equalIterationSpaces(firstPloop, secondPloop) &&
-         succeeded(verifyDependencies(firstPloop, secondPloop,
-                                      firstToSecondPloopIndices, mayAlias));
-}
-
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
-                        OpBuilder builder,
-                        llvm::function_ref<bool(Value, Value)> mayAlias) {
-  Block *block1 = firstPloop.getBody();
-  Block *block2 = secondPloop.getBody();
-  IRMapping firstToSecondPloopIndices;
-  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
-
-  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
-                     mayAlias))
-    return;
-
-  DominanceInfo dom;
-  // We are fusing first loop into second, make sure there are no users of the
-  // first loop results between loops.
-  for (Operation *user : firstPloop->getUsers())
-    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
-      return;
-
-  ValueRange inits1 = firstPloop.getInitVals();
-  ValueRange inits2 = secondPloop.getInitVals();
-
-  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
-  newInitVars.append(inits2.begin(), inits2.end());
-
-  IRRewriter b(builder);
-  b.setInsertionPoint(secondPloop);
-  auto newSecondPloop = b.create<ParallelOp>(
-      secondPloop.getLoc(), secondPloop.getLowerBound(),
-      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
-  Block *newBlock = newSecondPloop.getBody();
-  auto term1 = cast<ReduceOp>(block1->getTerminator());
-  auto term2 = cast<ReduceOp>(block2->getTerminator());
-
-  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-
-  ValueRange results = newSecondPloop.getResults();
-  if (!results.empty()) {
-    b.setInsertionPointToEnd(newBlock);
-
-    ValueRange reduceArgs1 = term1.getOperands();
-    ValueRange reduceArgs2 = term2.getOperands();
-    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
-    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
-    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
-    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
-             term1.getReductions(), term2.getReductions()))) {
-      Block &oldRedBlock = reg.front();
-      Block &newRedBlock = newReduceOp.getReductions()[i].front();
-      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
-                          newRedBlock.getArguments());
-    }
-
-    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
-    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
-  }
-  term1->erase();
-  term2->erase();
-  firstPloop.erase();
-  secondPloop.erase();
-  secondPloop = newSecondPloop;
-}
-
 void mlir::scf::naivelyFuseParallelOps(
     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
   OpBuilder b(region);
@@ -259,7 +59,7 @@ void mlir::scf::naivelyFuseParallelOps(
     }
     for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
-        fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
+        mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
     }
   }
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..d85339f32dbe3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
@@ -1070,6 +1071,206 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
+/// Checks that the `memref.load` ops in `secondPloop` are compatible with the
+/// `memref.store` ops in `firstPloop`. Returns `true` only if the walk over the
+/// loads of `secondPloop` is never interrupted, i.e. no load
+///   * reads a memref whose defining op is in the same block as the load,
+///   * reads a memref that `mayAlias` reports as aliasing a different memref
+///     stored to in `firstPloop`,
+///   * reads a memref that `firstPloop` stores to with differing index lists,
+///   * uses indices that are neither the store indices mapped through
+///     `firstToSecondPloopIndices` nor results of equivalent
+///     memory-effect-free ops.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+    const IRMapping &firstToSecondPloopIndices,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
+  // Collect, per memref, the index lists of every store in `firstPloop`.
+  // `bufferStoresVec` keeps the stored-to memrefs for the alias queries below.
+  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+  SmallVector<Value> bufferStoresVec;
+  firstPloop.getBody()->walk([&](memref::StoreOp store) {
+    bufferStores[store.getMemRef()].push_back(store.getIndices());
+    bufferStoresVec.emplace_back(store.getMemRef());
+  });
+  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+    Value loadMem = load.getMemRef();
+    // Stop if the memref is defined in secondPloop body. Careful alias analysis
+    // is needed.
+    auto *memrefDef = loadMem.getDefiningOp();
+    if (memrefDef && memrefDef->getBlock() == load->getBlock())
+      return WalkResult::interrupt();
+
+    // Give up if the loaded memref may alias any *other* stored-to memref.
+    for (Value store : bufferStoresVec)
+      if (store != loadMem && mayAlias(store, loadMem))
+        return WalkResult::interrupt();
+
+    // Loads from memrefs that `firstPloop` never stores to are always fine.
+    auto write = bufferStores.find(loadMem);
+    if (write == bufferStores.end())
+      return WalkResult::advance();
+
+    // Check that at least one store was retrieved.
+    if (!write->second.size())
+      return WalkResult::interrupt();
+
+    auto storeIndices = write->second.front();
+
+    // Multiple writes to the same memref are allowed only on the same indices
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+
+    // Check that the load indices of secondPloop coincide with store indices of
+    // firstPloop for the same memrefs.
+    auto loadIndices = load.getIndices();
+    if (storeIndices.size() != loadIndices.size())
+      return WalkResult::interrupt();
+    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
+          loadIndices[i]) {
+        // The indices are not directly mapped onto each other. Accept them
+        // only if both are produced by memory-effect-free ops that are
+        // equivalent; an index without a defining op is rejected.
+        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+        if (storeIndexDefOp && loadIndexDefOp) {
+          if (!isMemoryEffectFree(storeIndexDefOp))
+            return WalkResult::interrupt();
+          if (!isMemoryEffectFree(loadIndexDefOp))
+            return WalkResult::interrupt();
+          if (!OperationEquivalence::isEquivalentTo(
+                  storeIndexDefOp, loadIndexDefOp,
+                  // Operands match when both map to the same value through
+                  // `firstToSecondPloopIndices`.
+                  [&](Value storeIndex, Value loadIndex) {
+                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+                      return failure();
+                    else
+                      return success();
+                  },
+                  /*markEquivalent=*/nullptr,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
+            return WalkResult::interrupt();
+          }
+        } else
+          return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way by checking simple read and
+/// write patterns. The check is run in both directions: loads of `secondPloop`
+/// against stores of `firstPloop`, then loads of `firstPloop` against stores of
+/// `secondPloop`. Returns `success()` only if both directions pass.
+static LogicalResult
+verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+                   const IRMapping &firstToSecondPloopIndices,
+                   llvm::function_ref<bool(Value, Value)> mayAlias) {
+  if (!haveNoReadsAfterWriteExceptSameIndex(
+          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
+    return failure();
+
+  // Build the reverse mapping (second body arguments -> first body arguments)
+  // for the check in the opposite direction.
+  IRMapping secondToFirstPloopIndices;
+  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+                                firstPloop.getBody()->getArguments());
+  return success(haveNoReadsAfterWriteExceptSameIndex(
+      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+}
+
+/// Verify equal iteration spaces. Returns `false` if the loops differ in their
+/// number of loops; otherwise returns `true` only if the lower bound, upper
+/// bound and step operands of both loops compare equal element by element.
+static bool equalIterationSpaces(scf::ParallelOp firstPloop,
+                                 scf::ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  // NOTE(review): the three-iterator `std::equal` reads as many elements from
+  // `rhs` as `lhs` has; this presumably relies on each operand range having
+  // `getNumLoops()` entries, which was checked above -- confirm.
+  auto matchOperands = [&](const OperandRange &lhs,
+                           const OperandRange &rhs) -> bool {
+    // TODO: Extend this to support aliases and equal constants.
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+  };
+  return matchOperands(firstPloop.getLowerBound(),
+                       secondPloop.getLowerBound()) &&
+         matchOperands(firstPloop.getUpperBound(),
+                       secondPloop.getUpperBound()) &&
+         matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
+/// Returns `true` if the body of `ploop` contains a nested `scf.parallel` op.
+static bool hasNestedParallelOp(scf::ParallelOp ploop) {
+  // The walk is interrupted as soon as any nested ParallelOp is visited.
+  auto walkResult = ploop.getBody()->walk(
+      [](scf::ParallelOp) { return WalkResult::interrupt(); });
+  return walkResult.wasInterrupted();
+}
+
+/// Returns `true` if `firstPloop` and `secondPloop` may be fused: neither loop
+/// contains a nested `scf.parallel`, both have equal iteration spaces, and
+/// `verifyDependencies` succeeds for the given index mapping and `mayAlias`
+/// callback.
+static bool isFusionLegal(scf::ParallelOp firstPloop,
+                          scf::ParallelOp secondPloop,
+                          const IRMapping &firstToSecondPloopIndices,
+                          llvm::function_ref<bool(Value, Value)> mayAlias) {
+  return !hasNestedParallelOp(firstPloop) &&
+         !hasNestedParallelOp(secondPloop) &&
+         equalIterationSpaces(firstPloop, secondPloop) &&
+         succeeded(verifyDependencies(firstPloop, secondPloop,
+                                      firstToSecondPloopIndices, mayAlias));
+}
+
+/// Fuses `firstPloop` into `secondPloop` if `isFusionLegal` holds and every
+/// user of `firstPloop` is properly dominated by `secondPloop`; otherwise
+/// returns without changing the IR. On fusion, a new `scf.parallel` is created
+/// before `secondPloop` with `secondPloop`'s bounds and steps and the init
+/// values of both loops (those of `firstPloop` first). Both bodies are inlined
+/// into it with the ops of `firstPloop` ahead of those of `secondPloop`, both
+/// original loops are erased, and `secondPloop` is reassigned to the new loop.
+void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+                       OpBuilder builder,
+                       llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
+  // Map the block arguments of the first body onto those of the second so that
+  // indices can be compared across the two loops.
+  IRMapping firstToSecondPloopIndices;
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
+
+  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+                     mayAlias))
+    return;
+
+  DominanceInfo dom;
+  // We are fusing first loop into second, make sure there are no users of the
+  // first loop results between loops.
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop);
+  auto newSecondPloop = b.create<scf::ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+  Block *newBlock = newSecondPloop.getBody();
+  // The old terminators are looked up here and erased at the end.
+  auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+  // Inline block2 first and then block1 at the (new) beginning of the block,
+  // so that the ops of the first loop end up ahead of the ops of the second.
+  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+
+  ValueRange results = newSecondPloop.getResults();
+  if (!results.empty()) {
+    b.setInsertionPointToEnd(newBlock);
+
+    // Build a single `scf.reduce` that takes the operands of both old
+    // terminators, first loop's operands first.
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    // Move the body of each old reduction region into the region at the same
+    // position of the new reduce op.
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+                          newRedBlock.getArguments());
+    }
+
+    // The leading results replace `firstPloop`, the trailing ones
+    // `secondPloop`.
+    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+  }
+  term1->erase();
+  term2->erase();
+  firstPloop.erase();
+  secondPloop.erase();
+  // Hand the fused loop back to the caller through the reference parameter.
+  secondPloop = newSecondPloop;
+}
+
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
@@ -1171,3 +1372,10 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
 
   return fusedLoop;
 }
+
+/// Attempts to fuse `target` into `source` via `fuseIfLegal`, using a
+/// `mayAlias` callback that always answers `false`. Returns `source`, which
+/// `fuseIfLegal` reassigns to the newly created loop when fusion happens and
+/// leaves untouched when it returns early.
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+    scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+  auto mayAlias = [&](Value val1, Value val2) -> bool { return false; };
+  mlir::fuseIfLegal(target, source, rewriter, mayAlias);
+  return source;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..46c6be36c3271 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -47,6 +47,59 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Two sibling scf.parallel loops over the same 2x2 iteration space: the first
+// stores into %sum at [%i, %j] and the second loads %sum at the same indices.
+// The expected output is a single scf.parallel holding the ops of the first
+// loop followed by the ops of the second loop.
+// CHECK-LABEL: func @fuse_two_parallel
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+  %sum = memref.alloc()  : memref<2x2xf32>
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+// CHECK:      memref.dealloc [[SUM]]
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// Transform script: matches the scf.parallel ops in the payload, splits the
+// resulting handle in two, and fuses the first matched loop into the second.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %parallel:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
 func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
   // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index

>From b73238a9472b0682f250e37848ad504d21a57059 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 10:47:00 -0500
Subject: [PATCH 3/4] add checkFusionStructuralLegality

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h |  7 ++++++
 mlir/lib/Dialect/SCF/Utils/Utils.cpp        | 26 +++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 2944d8ffac022..834857f177cdf 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -156,6 +156,13 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+template <typename LoopTy>
+bool checkFusionStructuralLegality(Operation *target, Operation *source);
+
 /// Prepends operations of firstPloop's body into secondPloop's body.
 /// Updates secondPloop with new loop.
 void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d85339f32dbe3..c490983335470 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1184,6 +1184,10 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
          matchOperands(firstPloop.getStep(), secondPloop.getStep());
 }
 
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
 /// Returns `true` if the body of `ploop` contains a nested `scf.parallel` op.
 static bool hasNestedParallelOp(scf::ParallelOp ploop) {
   auto walkResult = ploop.getBody()->walk(
@@ -1191,6 +1195,28 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
   return walkResult.wasInterrupted();
 }
 
+/// Returns `true` if `target` and `source` are both `LoopTy` ops with equal
+/// loop configuration. `LoopTy` must be `scf::ForallOp`, `scf::ForOp` or
+/// `scf::ParallelOp`. For `scf::ForallOp` the mixed lower bounds, upper
+/// bounds, steps and the mappings are compared; for the other two the lower
+/// bound, upper bound and step are compared. No side-effect or dependency
+/// analysis is done here.
+template <typename LoopTy>
+static bool checkFusionStructuralLegality(Operation *target,
+                                             Operation *source) {
+  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
+                                scf::ParallelOp>::value,
+                "applies to only `forall`, `for` and `parallel`");
+  // Both ops must be of the requested loop type.
+  auto targetOp = dyn_cast<LoopTy>(target);
+  auto sourceOp = dyn_cast<LoopTy>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
+    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+           targetOp.getMapping() == sourceOp.getMapping();
+  else
+    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+           targetOp.getStep() == sourceOp.getStep();
+}
+
 static bool isFusionLegal(scf::ParallelOp firstPloop,
                           scf::ParallelOp secondPloop,
                           const IRMapping &firstToSecondPloopIndices,

>From f5bbd131bb7713ae47a58587b2d9acf82dc3b12f Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 14:19:56 -0500
Subject: [PATCH 4/4] replace isLoopWithIdenticalConfiguration with
 checkFusionStructuralLegality

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  4 +-
 .../SCF/TransformOps/SCFTransformOps.cpp      | 50 ++++++-------------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 32 +++++-------
 3 files changed, 29 insertions(+), 57 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 834857f177cdf..ab9d154aa480d 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -160,8 +160,8 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
 // Fusion related helpers
 //===----------------------------------------------------------------------===//
 
-template <typename LoopTy>
-bool checkFusionStructuralLegality(Operation *target, Operation *source);
+/// Returns `true` if `target` and `source` have equal mixed lower bounds,
+/// upper bounds and steps. When both are `scf.forall` ops, their mappings must
+/// be equal as well.
+bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
+                                   LoopLikeOpInterface &source);
 
 /// Prepends operations of firstPloop's body into secondPloop's body.
 /// Updates secondPloop with new loop.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 1c53e89d69040..9f541b94af474 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Check if `target` scf loop can be fused into `source` scf loop.
-/// Applies for scf.for, scf.forall, and scf.parallel.
-///
-/// This simply checks if both loops have the same bounds, steps and mapping.
-/// No attempt is made at checking that the side effects of `target` and
-/// `source` are independent of each other.
-template <typename LoopTy>
-static bool isLoopWithIdenticalConfiguration(Operation *target,
-                                             Operation *source) {
-  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
-                                scf::ParallelOp>::value,
-                "applies to only `forall`, `for` and `parallel`");
-  auto targetOp = dyn_cast<LoopTy>(target);
-  auto sourceOp = dyn_cast<LoopTy>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
-    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
-           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
-           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
-           targetOp.getMapping() == sourceOp.getMapping();
-  else
-    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
-           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
-           targetOp.getStep() == sourceOp.getStep();
-}
-
 DiagnosedSilenceableFailure
 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
@@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
            << "source handle (got " << llvm::range_size(sourceOps) << ")";
   }
 
-  Operation *target = *targetOps.begin();
-  Operation *source = *sourceOps.begin();
+  LoopLikeOpInterface target =
+      dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
+  LoopLikeOpInterface source =
+      dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
+  if (!target || !source)
+    return emitSilenceableFailure(target->getLoc())
+           << "target or source is not a loop op";
 
   // Check if the target and source are siblings.
   DiagnosedSilenceableFailure diag = isOpSibling(target, source);
   if (!diag.succeeded())
     return diag;
 
+  if (!mlir::checkFusionStructuralLegality(target, source))
+    return emitSilenceableFailure(target->getLoc())
+           << "operations cannot be fused";
+
   Operation *fusedLoop;
   /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
-  if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
+  if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
     fusedLoop = fuseIndependentSiblingForLoops(
         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
-  } else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
+  } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
     fusedLoop = fuseIndependentSiblingForallLoops(
         cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
-  } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
-                                                               source)) {
+  } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
     fusedLoop = fuseIndependentSiblingParallelLoops(
         cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
   } else
     return emitSilenceableFailure(target->getLoc())
-           << "operations cannot be fused";
+           << "unsupported loop type for fusion";
 
   assert(fusedLoop && "failed to fuse operations");
 
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c490983335470..ce20730459c2a 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1195,26 +1195,18 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
   return walkResult.wasInterrupted();
 }
 
-template <typename LoopTy>
-static bool checkFusionStructuralLegality(Operation *target,
-                                             Operation *source) {
-  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
-                                scf::ParallelOp>::value,
-                "applies to only `forall`, `for` and `parallel`");
-  auto targetOp = dyn_cast<LoopTy>(target);
-  auto sourceOp = dyn_cast<LoopTy>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
-    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
-           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
-           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
-           targetOp.getMapping() == sourceOp.getMapping();
-  else
-    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
-           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
-           targetOp.getStep() == sourceOp.getStep();
+/// Returns `true` if `target` and `source` have equal mixed lower bounds,
+/// upper bounds and steps as reported through `LoopLikeOpInterface`. When both
+/// are `scf.forall` ops, their mappings must be equal as well.
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+                                         LoopLikeOpInterface &source) {
+  auto iterSpaceEq =
+      target.getMixedLowerBound() == source.getMixedLowerBound() &&
+      target.getMixedUpperBound() == source.getMixedUpperBound() &&
+      target.getMixedStep() == source.getMixedStep();
+  // The mapping is only compared when both loops are `scf.forall`; in every
+  // other case the iteration-space comparison alone decides.
+  auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+  auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+  if (forAllTarget && forAllSource)
+    return iterSpaceEq &&
+           forAllTarget.getMapping() == forAllSource.getMapping();
+  return iterSpaceEq;
 }
 
 static bool isFusionLegal(scf::ParallelOp firstPloop,



More information about the Mlir-commits mailing list