[Mlir-commits] [mlir] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 (PR #97607)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 3 09:56:47 PDT 2024


https://github.com/srcarroll created https://github.com/llvm/llvm-project/pull/97607

The refactor had a bug where the fused loop was inserted at an incorrect location. This patch fixes the bug and relands the original PR https://github.com/llvm/llvm-project/pull/94391.

>From 5020e498440b0016adef7e99806aa55c4837b441 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 13:08:22 -0500
Subject: [PATCH 01/34] Add getters for multi dim loop variables in
 LoopLikeOpInterface

---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  4 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 37 ++--------
 .../mlir/Interfaces/LoopLikeInterface.td      | 65 +++++++++++------
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 20 +++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 70 +++++++------------
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        |  8 +++
 6 files changed, 97 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 3640055ea8da8..bb2c29b5733b8 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
     [AttrSizedOperandSegments, AutomaticAllocationScope,
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-     ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValuesMutable",
+     ["getInductionVars", "getMixedLowerBound", "getMixedStep",
+      "getMixedUpperBound", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b063aa772bab..3b28ca8b21d0f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
-        "getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-        "getSingleUpperBound", "getYieldedValuesMutable",
+        "getInductionVars", "getMixedLowerBound", "getMixedStep",
+        "getMixedUpperBound", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
-           "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
+          ["getInitsMutable", "getRegionIterArgs", "getInductionVars", 
+           "getMixedLowerBound", "getMixedUpperBound", "getMixedStep",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,24 +510,6 @@ def ForallOp : SCF_Op<"forall", [
   ];
 
   let extraClassDeclaration = [{
-    // Get lower bounds as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedLowerBound() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
-    }
-
-    // Get upper bounds as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedUpperBound() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
-    }
-
-    // Get steps as OpFoldResult.
-    SmallVector<OpFoldResult> getMixedStep() {
-      Builder b(getOperation()->getContext());
-      return getMixedValues(getStaticStep(), getDynamicStep(), b);
-    }
-
     /// Get lower bounds as values.
     SmallVector<Value> getLowerBound(OpBuilder &b) {
       return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
@@ -584,10 +566,6 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
-    ::mlir::ValueRange getInductionVars() {
-      return getBody()->getArguments().take_front(getRank());
-    }
-
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
@@ -765,8 +743,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
 def ParallelOp : SCF_Op<"parallel",
     [AutomaticAllocationScope,
      AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar",
-          "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
+          "getMixedLowerBound", "getMixedUpperBound", "getMixedStep"]>,
      RecursiveMemoryEffects,
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -846,9 +824,6 @@ def ParallelOp : SCF_Op<"parallel",
   ];
 
   let extraClassDeclaration = [{
-    ValueRange getInductionVars() {
-      return getBody()->getArguments();
-    }
     unsigned getNumLoops() { return getStep().size(); }
     unsigned getNumReductions() { return getInitVals().size(); }
   }];
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index f0dc6e60eba58..813779c852027 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -93,51 +93,47 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        If there is a single induction variable return it, otherwise return
-        std::nullopt.
+        Return all induction variables.
       }],
-      /*retTy=*/"::std::optional<::mlir::Value>",
-      /*methodName=*/"getSingleInductionVar",
+      /*retTy=*/"::mlir::ValueRange",
+      /*methodName=*/"getInductionVars",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single lower bound value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all lower bounds.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleLowerBound",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedLowerBound",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single step value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all steps.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleStep",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedStep",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
-        Return the single upper bound value or attribute if it exists, otherwise
-        return std::nullopt.
+        Return all upper bounds.
       }],
-      /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
-      /*methodName=*/"getSingleUpperBound",
+      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedUpperBound",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return {};
       }]
     >,
     InterfaceMethod<[{
@@ -235,6 +231,35 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
   }];
 
   let extraSharedClassDeclaration = [{
+    /// If the loop has exactly one induction variable return it, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::Value> getSingleInductionVar() {
+      if (auto ivs = this->getInductionVars(); ivs.size() == 1)
+        return ivs[0];
+      return std::nullopt;
+    }
+    /// Return the single lower bound value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
+      if (auto lbs = this->getMixedLowerBound(); lbs.size() == 1)
+        return lbs[0];
+      return std::nullopt;
+    }
+    /// Return the single step value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleStep() {
+      if (auto steps = this->getMixedStep(); steps.size() == 1)
+        return steps[0];
+      return std::nullopt;
+    }
+    /// Return the single upper bound value or attribute if it exists, otherwise
+    /// return std::nullopt.
+    ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
+      if (auto ubs = this->getMixedUpperBound(); ubs.size() == 1)
+        return ubs[0];
+      return std::nullopt;
+    }
+
     /// Append the specified additional "init" operands: replace this loop with
     /// a new loop that has the additional init operands. The loop body of this
     /// loop is moved over to the new loop.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2e31487bd55a0..746a9c919560c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,27 +2454,30 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> AffineForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+/// Return the induction variable (the body's first block argument) as a slice
+/// of the block's argument list. Wrapping a temporary, as in
+/// `{getInductionVar()}`, would leave the non-owning ValueRange dangling.
+ValueRange AffineForOp::getInductionVars() {
+  return getBody()->getArguments().take_front(1);
+}
 
-std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedLowerBound() {
   if (!hasConstantLowerBound())
-    return std::nullopt;
+    return {};
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
+  return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
 }
 
-std::optional<OpFoldResult> AffineForOp::getSingleStep() {
+SmallVector<OpFoldResult> AffineForOp::getMixedStep() {
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
+  return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
 }
 
-std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedUpperBound() {
   if (!hasConstantUpperBound())
-    return std::nullopt;
+    return {};
   OpBuilder b(getContext());
-  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
+  return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
 }
 
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 107fd0690f193..e275ff1849c10 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,20 +378,23 @@ LogicalResult ForOp::verifyRegions() {
   return success();
 }
 
-std::optional<Value> ForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+/// Return the induction variable (the body's first block argument) as a slice
+/// of the block's argument list. Wrapping a temporary, as in
+/// `{getInductionVar()}`, would leave the non-owning ValueRange dangling.
+ValueRange ForOp::getInductionVars() {
+  return getBody()->getArguments().take_front(1);
+}
 
-std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
-  return OpFoldResult(getLowerBound());
+SmallVector<OpFoldResult> ForOp::getMixedLowerBound() {
+  return {OpFoldResult(getLowerBound())};
 }
 
-std::optional<OpFoldResult> ForOp::getSingleStep() {
-  return OpFoldResult(getStep());
+SmallVector<OpFoldResult> ForOp::getMixedStep() {
+  return {OpFoldResult(getStep())};
 }
 
-std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
-  return OpFoldResult(getUpperBound());
+SmallVector<OpFoldResult> ForOp::getMixedUpperBound() {
+  return {OpFoldResult(getUpperBound())};
 }
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
@@ -1428,28 +1426,26 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
-std::optional<Value> ForallOp::getSingleInductionVar() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getInductionVar(0);
+ValueRange ForallOp::getInductionVars() {
+  return getBody()->getArguments().take_front(getRank());
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedLowerBound()[0];
+// Get lower bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedLowerBound() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedUpperBound()[0];
+// Get upper bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedUpperBound() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
 }
 
-std::optional<OpFoldResult> ForallOp::getSingleStep() {
-  if (getRank() != 1)
-    return std::nullopt;
-  return getMixedStep()[0];
+// Get steps as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedStep() {
+  Builder b(getOperation()->getContext());
+  return getMixedValues(getStaticStep(), getDynamicStep(), b);
 }
 
 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
@@ -3008,29 +3004,17 @@ void ParallelOp::print(OpAsmPrinter &p) {
 
 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> ParallelOp::getSingleInductionVar() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getBody()->getArgument(0);
-}
+ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
 
-std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getLowerBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedLowerBound() {
+  return getLowerBound();
 }
 
-std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getUpperBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedUpperBound() {
+  return getUpperBound();
 }
 
-std::optional<OpFoldResult> ParallelOp::getSingleStep() {
-  if (getNumLoops() != 1)
-    return std::nullopt;
-  return getStep()[0];
-}
+SmallVector<OpFoldResult> ParallelOp::getMixedStep() { return getStep(); }
 
 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 6bc0fd6113b9b..d8cdb213070da 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -36,6 +36,10 @@ class SCFLoopLikeTest : public ::testing::Test {
     std::optional<OpFoldResult> maybeIndVar =
         loopLikeOp.getSingleInductionVar();
     EXPECT_TRUE(maybeIndVar.has_value());
+    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 1u);
   }
 
   void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -48,6 +52,10 @@ class SCFLoopLikeTest : public ::testing::Test {
     std::optional<OpFoldResult> maybeIndVar =
         loopLikeOp.getSingleInductionVar();
     EXPECT_FALSE(maybeIndVar.has_value());
+    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 2u);
   }
 
   MLIRContext context;

>From 50852d570440e0041c8b2b38925c4af05fac0636 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 4 Jun 2024 14:44:53 -0500
Subject: [PATCH 02/34] Refactor LoopFuseSiblingOp and support parallel fusion

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  16 ++
 .../SCF/TransformOps/SCFTransformOps.cpp      |  53 +++--
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 204 +----------------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 208 ++++++++++++++++++
 .../SCF/transform-loop-fuse-sibling.mlir      |  53 +++++
 5 files changed, 304 insertions(+), 230 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..2944d8ffac022 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -156,6 +156,12 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+/// Prepends operations of firstPloop's body into secondPloop's body when the
+/// fusion is legal, and updates secondPloop to refer to the new fused loop.
+void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+                 OpBuilder builder,
+                 llvm::function_ref<bool(Value, Value)> mayAlias);
+
 /// Given two scf.forall loops, `target` and `source`, fuses `target` into
 /// `source`. Assumes that the given loops are siblings and are independent of
 /// each other.
@@ -177,6 +183,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
                                           RewriterBase &rewriter);
 
+/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
+                                                    scf::ParallelOp source,
+                                                    RewriterBase &rewriter);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..1c53e89d69040 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -442,39 +442,32 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Check if `target` scf.forall can be fused into `source` scf.forall.
+/// Check if `target` scf loop can be fused into `source` scf loop.
+/// Applies to scf.for, scf.forall and scf.parallel; mapping is forall-only.
 ///
 /// This simply checks if both loops have the same bounds, steps and mapping.
 /// No attempt is made at checking that the side effects of `target` and
 /// `source` are independent of each other.
-static bool isForallWithIdenticalConfiguration(Operation *target,
-                                               Operation *source) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
-         targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
-         targetOp.getMixedStep() == sourceOp.getMixedStep() &&
-         targetOp.getMapping() == sourceOp.getMapping();
-}
-
-/// Check if `target` scf.for can be fused into `source` scf.for.
-///
-/// This simply checks if both loops have the same bounds and steps. No attempt
-/// is made at checking that the side effects of `target` and `source` are
-/// independent of each other.
-static bool isForWithIdenticalConfiguration(Operation *target,
-                                            Operation *source) {
-  auto targetOp = dyn_cast<scf::ForOp>(target);
-  auto sourceOp = dyn_cast<scf::ForOp>(source);
+template <typename LoopTy>
+static bool isLoopWithIdenticalConfiguration(Operation *target,
+                                             Operation *source) {
+  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
+                                scf::ParallelOp>::value,
+                "applies to only `forall`, `for` and `parallel`");
+  auto targetOp = dyn_cast<LoopTy>(target);
+  auto sourceOp = dyn_cast<LoopTy>(source);
   if (!targetOp || !sourceOp)
     return false;
 
-  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
-         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
-         targetOp.getStep() == sourceOp.getStep();
+  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
+    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+           targetOp.getMapping() == sourceOp.getMapping();
+  else
+    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+           targetOp.getStep() == sourceOp.getStep();
 }
 
 DiagnosedSilenceableFailure
@@ -502,12 +495,16 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
 
   Operation *fusedLoop;
   /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
-  if (isForWithIdenticalConfiguration(target, source)) {
+  if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
     fusedLoop = fuseIndependentSiblingForLoops(
         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
-  } else if (isForallWithIdenticalConfiguration(target, source)) {
+  } else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
     fusedLoop = fuseIndependentSiblingForallLoops(
         cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+  } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
+                                                               source)) {
+    fusedLoop = fuseIndependentSiblingParallelLoops(
+        cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
   } else
     return emitSilenceableFailure(target->getLoc())
            << "operations cannot be fused";
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 5934d85373b03..abac91cfaf7d9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
@@ -30,207 +31,6 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::scf;
 
-/// Verify there are no nested ParallelOps.
-static bool hasNestedParallelOp(ParallelOp ploop) {
-  auto walkResult =
-      ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
-  return walkResult.wasInterrupted();
-}
-
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(ParallelOp firstPloop,
-                                 ParallelOp secondPloop) {
-  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
-    return false;
-
-  auto matchOperands = [&](const OperandRange &lhs,
-                           const OperandRange &rhs) -> bool {
-    // TODO: Extend this to support aliases and equal constants.
-    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
-  };
-  return matchOperands(firstPloop.getLowerBound(),
-                       secondPloop.getLowerBound()) &&
-         matchOperands(firstPloop.getUpperBound(),
-                       secondPloop.getUpperBound()) &&
-         matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
-    ParallelOp firstPloop, ParallelOp secondPloop,
-    const IRMapping &firstToSecondPloopIndices,
-    llvm::function_ref<bool(Value, Value)> mayAlias) {
-  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
-  SmallVector<Value> bufferStoresVec;
-  firstPloop.getBody()->walk([&](memref::StoreOp store) {
-    bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
-  });
-  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
-    Value loadMem = load.getMemRef();
-    // Stop if the memref is defined in secondPloop body. Careful alias analysis
-    // is needed.
-    auto *memrefDef = loadMem.getDefiningOp();
-    if (memrefDef && memrefDef->getBlock() == load->getBlock())
-      return WalkResult::interrupt();
-
-    for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
-        return WalkResult::interrupt();
-
-    auto write = bufferStores.find(loadMem);
-    if (write == bufferStores.end())
-      return WalkResult::advance();
-
-    // Check that at last one store was retrieved
-    if (!write->second.size())
-      return WalkResult::interrupt();
-
-    auto storeIndices = write->second.front();
-
-    // Multiple writes to the same memref are allowed only on the same indices
-    for (const auto &othStoreIndices : write->second) {
-      if (othStoreIndices != storeIndices)
-        return WalkResult::interrupt();
-    }
-
-    // Check that the load indices of secondPloop coincide with store indices of
-    // firstPloop for the same memrefs.
-    auto loadIndices = load.getIndices();
-    if (storeIndices.size() != loadIndices.size())
-      return WalkResult::interrupt();
-    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
-      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
-          loadIndices[i]) {
-        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
-        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
-        if (storeIndexDefOp && loadIndexDefOp) {
-          if (!isMemoryEffectFree(storeIndexDefOp))
-            return WalkResult::interrupt();
-          if (!isMemoryEffectFree(loadIndexDefOp))
-            return WalkResult::interrupt();
-          if (!OperationEquivalence::isEquivalentTo(
-                  storeIndexDefOp, loadIndexDefOp,
-                  [&](Value storeIndex, Value loadIndex) {
-                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
-                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
-                      return failure();
-                    else
-                      return success();
-                  },
-                  /*markEquivalent=*/nullptr,
-                  OperationEquivalence::Flags::IgnoreLocations)) {
-            return WalkResult::interrupt();
-          }
-        } else
-          return WalkResult::interrupt();
-      }
-    }
-    return WalkResult::advance();
-  });
-  return !walkResult.wasInterrupted();
-}
-
-/// Analyzes dependencies in the most primitive way by checking simple read and
-/// write patterns.
-static LogicalResult
-verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
-                   const IRMapping &firstToSecondPloopIndices,
-                   llvm::function_ref<bool(Value, Value)> mayAlias) {
-  if (!haveNoReadsAfterWriteExceptSameIndex(
-          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
-    return failure();
-
-  IRMapping secondToFirstPloopIndices;
-  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
-                                firstPloop.getBody()->getArguments());
-  return success(haveNoReadsAfterWriteExceptSameIndex(
-      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
-}
-
-static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                          const IRMapping &firstToSecondPloopIndices,
-                          llvm::function_ref<bool(Value, Value)> mayAlias) {
-  return !hasNestedParallelOp(firstPloop) &&
-         !hasNestedParallelOp(secondPloop) &&
-         equalIterationSpaces(firstPloop, secondPloop) &&
-         succeeded(verifyDependencies(firstPloop, secondPloop,
-                                      firstToSecondPloopIndices, mayAlias));
-}
-
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
-                        OpBuilder builder,
-                        llvm::function_ref<bool(Value, Value)> mayAlias) {
-  Block *block1 = firstPloop.getBody();
-  Block *block2 = secondPloop.getBody();
-  IRMapping firstToSecondPloopIndices;
-  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
-
-  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
-                     mayAlias))
-    return;
-
-  DominanceInfo dom;
-  // We are fusing first loop into second, make sure there are no users of the
-  // first loop results between loops.
-  for (Operation *user : firstPloop->getUsers())
-    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
-      return;
-
-  ValueRange inits1 = firstPloop.getInitVals();
-  ValueRange inits2 = secondPloop.getInitVals();
-
-  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
-  newInitVars.append(inits2.begin(), inits2.end());
-
-  IRRewriter b(builder);
-  b.setInsertionPoint(secondPloop);
-  auto newSecondPloop = b.create<ParallelOp>(
-      secondPloop.getLoc(), secondPloop.getLowerBound(),
-      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
-  Block *newBlock = newSecondPloop.getBody();
-  auto term1 = cast<ReduceOp>(block1->getTerminator());
-  auto term2 = cast<ReduceOp>(block2->getTerminator());
-
-  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-
-  ValueRange results = newSecondPloop.getResults();
-  if (!results.empty()) {
-    b.setInsertionPointToEnd(newBlock);
-
-    ValueRange reduceArgs1 = term1.getOperands();
-    ValueRange reduceArgs2 = term2.getOperands();
-    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
-    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
-    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
-    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
-             term1.getReductions(), term2.getReductions()))) {
-      Block &oldRedBlock = reg.front();
-      Block &newRedBlock = newReduceOp.getReductions()[i].front();
-      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
-                          newRedBlock.getArguments());
-    }
-
-    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
-    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
-  }
-  term1->erase();
-  term2->erase();
-  firstPloop.erase();
-  secondPloop.erase();
-  secondPloop = newSecondPloop;
-}
-
 void mlir::scf::naivelyFuseParallelOps(
     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
   OpBuilder b(region);
@@ -259,7 +59,7 @@ void mlir::scf::naivelyFuseParallelOps(
     }
     for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
-        fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
+        mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
     }
   }
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..d85339f32dbe3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
@@ -1070,6 +1071,206 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
+/// Conservatively checks that every buffer written in `firstPloop` is read in
+/// `secondPloop` only at the (mapped) indices it was written. Returns `false`
+/// whenever that cannot be proven, e.g. on possible aliasing.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+    const IRMapping &firstToSecondPloopIndices,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
+  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+  SmallVector<Value> bufferStoresVec;
+  firstPloop.getBody()->walk([&](memref::StoreOp store) {
+    bufferStores[store.getMemRef()].push_back(store.getIndices());
+    bufferStoresVec.emplace_back(store.getMemRef());
+  });
+  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+    Value loadMem = load.getMemRef();
+    // Stop if the memref is defined in secondPloop body. Careful alias analysis
+    // is needed.
+    auto *memrefDef = loadMem.getDefiningOp();
+    if (memrefDef && memrefDef->getBlock() == load->getBlock())
+      return WalkResult::interrupt();
+
+    for (Value store : bufferStoresVec)
+      if (store != loadMem && mayAlias(store, loadMem))
+        return WalkResult::interrupt();
+
+    auto write = bufferStores.find(loadMem);
+    if (write == bufferStores.end())
+      return WalkResult::advance();
+
+    // Check that at least one store was retrieved.
+    if (write->second.empty())
+      return WalkResult::interrupt();
+
+    auto storeIndices = write->second.front();
+
+    // Multiple writes to the same memref are allowed only on the same indices
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+
+    // Check that the load indices of secondPloop coincide with store indices of
+    // firstPloop for the same memrefs.
+    auto loadIndices = load.getIndices();
+    if (storeIndices.size() != loadIndices.size())
+      return WalkResult::interrupt();
+    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
+          loadIndices[i]) {
+        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+        if (storeIndexDefOp && loadIndexDefOp) {
+          if (!isMemoryEffectFree(storeIndexDefOp))
+            return WalkResult::interrupt();
+          if (!isMemoryEffectFree(loadIndexDefOp))
+            return WalkResult::interrupt();
+          if (!OperationEquivalence::isEquivalentTo(
+                  storeIndexDefOp, loadIndexDefOp,
+                  [&](Value storeIndex, Value loadIndex) {
+                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+                      return failure();
+                    else
+                      return success();
+                  },
+                  /*markEquivalent=*/nullptr,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
+            return WalkResult::interrupt();
+          }
+        } else
+          return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way by checking simple read and
+/// write patterns.
+static LogicalResult
+verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+                   const IRMapping &firstToSecondPloopIndices,
+                   llvm::function_ref<bool(Value, Value)> mayAlias) {
+  if (!haveNoReadsAfterWriteExceptSameIndex(
+          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
+    return failure();
+
+  IRMapping secondToFirstPloopIndices;
+  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+                                firstPloop.getBody()->getArguments());
+  return success(haveNoReadsAfterWriteExceptSameIndex(
+      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+}
+
+/// Verify equal iteration spaces.
+static bool equalIterationSpaces(scf::ParallelOp firstPloop,
+                                 scf::ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  auto matchOperands = [&](const OperandRange &lhs,
+                           const OperandRange &rhs) -> bool {
+    // TODO: Extend this to support aliases and equal constants.
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+  };
+  return matchOperands(firstPloop.getLowerBound(),
+                       secondPloop.getLowerBound()) &&
+         matchOperands(firstPloop.getUpperBound(),
+                       secondPloop.getUpperBound()) &&
+         matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
+/// Verify there are no nested ParallelOps.
+static bool hasNestedParallelOp(scf::ParallelOp ploop) {
+  auto walkResult = ploop.getBody()->walk(
+      [](scf::ParallelOp) { return WalkResult::interrupt(); });
+  return walkResult.wasInterrupted();
+}
+
+static bool isFusionLegal(scf::ParallelOp firstPloop,
+                          scf::ParallelOp secondPloop,
+                          const IRMapping &firstToSecondPloopIndices,
+                          llvm::function_ref<bool(Value, Value)> mayAlias) {
+  return !hasNestedParallelOp(firstPloop) &&
+         !hasNestedParallelOp(secondPloop) &&
+         equalIterationSpaces(firstPloop, secondPloop) &&
+         succeeded(verifyDependencies(firstPloop, secondPloop,
+                                      firstToSecondPloopIndices, mayAlias));
+}
+
+void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+                       OpBuilder builder,
+                       llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
+  IRMapping firstToSecondPloopIndices;
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
+
+  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+                     mayAlias))
+    return;
+
+  DominanceInfo dom;
+  // We are fusing first loop into second, make sure there are no users of the
+  // first loop results between loops.
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop);
+  auto newSecondPloop = b.create<scf::ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+  Block *newBlock = newSecondPloop.getBody();
+  auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+
+  ValueRange results = newSecondPloop.getResults();
+  if (!results.empty()) {
+    b.setInsertionPointToEnd(newBlock);
+
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+                          newRedBlock.getArguments());
+    }
+
+    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+  }
+  term1->erase();
+  term2->erase();
+  firstPloop.erase();
+  secondPloop.erase();
+  secondPloop = newSecondPloop;
+}
+
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
@@ -1171,3 +1372,10 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
 
   return fusedLoop;
 }
+
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+    scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+  auto mayAlias = [](Value, Value) -> bool { return false; };
+  mlir::fuseIfLegal(target, source, rewriter, mayAlias);
+  return source;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..46c6be36c3271 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -47,6 +47,59 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @fuse_two_parallel
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+  %sum = memref.alloc()  : memref<2x2xf32>
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+// CHECK:      memref.dealloc [[SUM]]
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %parallel:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
 func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
   // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index

>From b73238a9472b0682f250e37848ad504d21a57059 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 10:47:00 -0500
Subject: [PATCH 03/34] add checkFusionStructuralLegality

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h |  7 ++++++
 mlir/lib/Dialect/SCF/Utils/Utils.cpp        | 26 +++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 2944d8ffac022..834857f177cdf 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -156,6 +156,13 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+template <typename LoopTy>
+bool checkFusionStructuralLegality(Operation *target, Operation *source);
+
 /// Prepends operations of firstPloop's body into secondPloop's body.
 /// Updates secondPloop with new loop.
 void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d85339f32dbe3..c490983335470 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1184,6 +1184,10 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
          matchOperands(firstPloop.getStep(), secondPloop.getStep());
 }
 
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
 /// Verify there are no nested ParallelOps.
 static bool hasNestedParallelOp(scf::ParallelOp ploop) {
   auto walkResult = ploop.getBody()->walk(
@@ -1191,6 +1195,28 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
   return walkResult.wasInterrupted();
 }
 
+template <typename LoopTy>
+static bool checkFusionStructuralLegality(Operation *target,
+                                             Operation *source) {
+  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
+                                scf::ParallelOp>::value,
+                "applies to only `forall`, `for` and `parallel`");
+  auto targetOp = dyn_cast<LoopTy>(target);
+  auto sourceOp = dyn_cast<LoopTy>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
+    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+           targetOp.getMapping() == sourceOp.getMapping();
+  else
+    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+           targetOp.getStep() == sourceOp.getStep();
+}
+
 static bool isFusionLegal(scf::ParallelOp firstPloop,
                           scf::ParallelOp secondPloop,
                           const IRMapping &firstToSecondPloopIndices,

>From f5bbd131bb7713ae47a58587b2d9acf82dc3b12f Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 14:19:56 -0500
Subject: [PATCH 04/34] replace isLoopWithIdenticalConfiguration with
 checkFusionStructuralLegality

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  4 +-
 .../SCF/TransformOps/SCFTransformOps.cpp      | 50 ++++++-------------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 32 +++++-------
 3 files changed, 29 insertions(+), 57 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 834857f177cdf..ab9d154aa480d 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -160,8 +160,8 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
 // Fusion related helpers
 //===----------------------------------------------------------------------===//
 
-template <typename LoopTy>
-bool checkFusionStructuralLegality(Operation *target, Operation *source);
+bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
+                                   LoopLikeOpInterface &source);
 
 /// Prepends operations of firstPloop's body into secondPloop's body.
 /// Updates secondPloop with new loop.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 1c53e89d69040..9f541b94af474 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Check if `target` scf loop can be fused into `source` scf loop.
-/// Applies for scf.for, scf.forall, and scf.parallel.
-///
-/// This simply checks if both loops have the same bounds, steps and mapping.
-/// No attempt is made at checking that the side effects of `target` and
-/// `source` are independent of each other.
-template <typename LoopTy>
-static bool isLoopWithIdenticalConfiguration(Operation *target,
-                                             Operation *source) {
-  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
-                                scf::ParallelOp>::value,
-                "applies to only `forall`, `for` and `parallel`");
-  auto targetOp = dyn_cast<LoopTy>(target);
-  auto sourceOp = dyn_cast<LoopTy>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
-    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
-           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
-           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
-           targetOp.getMapping() == sourceOp.getMapping();
-  else
-    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
-           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
-           targetOp.getStep() == sourceOp.getStep();
-}
-
 DiagnosedSilenceableFailure
 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
@@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
            << "source handle (got " << llvm::range_size(sourceOps) << ")";
   }
 
-  Operation *target = *targetOps.begin();
-  Operation *source = *sourceOps.begin();
+  LoopLikeOpInterface target =
+      dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
+  LoopLikeOpInterface source =
+      dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
+  if (!target || !source)
+    return emitSilenceableFailure(getLoc())
+           << "target or source is not a loop op";
 
   // Check if the target and source are siblings.
   DiagnosedSilenceableFailure diag = isOpSibling(target, source);
   if (!diag.succeeded())
     return diag;
 
+  if (!mlir::checkFusionStructuralLegality(target, source))
+    return emitSilenceableFailure(target->getLoc())
+           << "operations cannot be fused";
+
   Operation *fusedLoop;
   /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
-  if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
+  if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
     fusedLoop = fuseIndependentSiblingForLoops(
         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
-  } else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
+  } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
     fusedLoop = fuseIndependentSiblingForallLoops(
         cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
-  } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
-                                                               source)) {
+  } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
     fusedLoop = fuseIndependentSiblingParallelLoops(
         cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
   } else
     return emitSilenceableFailure(target->getLoc())
-           << "operations cannot be fused";
+           << "unsupported loop type for fusion";
 
   assert(fusedLoop && "failed to fuse operations");
 
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c490983335470..ce20730459c2a 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1195,26 +1195,18 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
   return walkResult.wasInterrupted();
 }
 
-template <typename LoopTy>
-static bool checkFusionStructuralLegality(Operation *target,
-                                             Operation *source) {
-  static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
-                                scf::ParallelOp>::value,
-                "applies to only `forall`, `for` and `parallel`");
-  auto targetOp = dyn_cast<LoopTy>(target);
-  auto sourceOp = dyn_cast<LoopTy>(source);
-  if (!targetOp || !sourceOp)
-    return false;
-
-  if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
-    return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
-           targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
-           targetOp.getMixedStep() == sourceOp.getMixedStep() &&
-           targetOp.getMapping() == sourceOp.getMapping();
-  else
-    return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
-           targetOp.getUpperBound() == sourceOp.getUpperBound() &&
-           targetOp.getStep() == sourceOp.getStep();
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+                                         LoopLikeOpInterface &source) {
+  auto iterSpaceEq =
+      target.getMixedLowerBound() == source.getMixedLowerBound() &&
+      target.getMixedUpperBound() == source.getMixedUpperBound() &&
+      target.getMixedStep() == source.getMixedStep();
+  auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+  auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+  if (forAllTarget && forAllSource)
+    return iterSpaceEq &&
+           forAllTarget.getMapping() == forAllSource.getMapping();
+  return iterSpaceEq;
 }
 
 static bool isFusionLegal(scf::ParallelOp firstPloop,

>From 7d995815064cb25e47ec8e400de3692fbe5fdfba Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 14:37:57 -0500
Subject: [PATCH 05/34] address review comment

---
 .../mlir/Interfaces/LoopLikeInterface.td      | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 813779c852027..5cf3eba0bd9ed 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -234,29 +234,33 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// If there is a single induction variable return it, otherwise return
     /// std::nullopt.
     ::std::optional<::mlir::Value> getSingleInductionVar() {
-      if (this->getInductionVars().size() == 1)
-          return this->getInductionVars()[0];
+      auto inductionVars = this->getInductionVars();
+      if (inductionVars.size() == 1)
+          return inductionVars[0];
         return std::nullopt;
     }
     /// Return the single lower bound value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
-      if (this->getMixedLowerBound().size() == 1)
-          return this->getMixedLowerBound()[0];
+      auto lowerBounds = this->getMixedLowerBound();
+      if (lowerBounds.size() == 1)
+          return lowerBounds[0];
         return std::nullopt;
     }
     /// Return the single step value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleStep() {
-      if (this->getMixedStep().size() == 1)
-          return this->getMixedStep()[0];
+      auto steps = this->getMixedStep(); 
+      if (steps.size() == 1)
+          return steps[0];
         return std::nullopt;
     }
     /// Return the single upper bound value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
-      if (this->getMixedUpperBound().size() == 1)
-          return this->getMixedUpperBound()[0];
+      auto upperBounds = this->getMixedUpperBound();
+      if (upperBounds.size() == 1)
+          return upperBounds[0];
         return std::nullopt;
     }
 

>From a5fa3b3c4903c344847ee544cd9812b6f0c70571 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 20:56:02 -0500
Subject: [PATCH 06/34] Make return types optional and change names

---
 .../mlir/Dialect/Affine/IR/AffineOps.td       | 10 +--
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 29 +++++++--
 .../mlir/Interfaces/LoopLikeInterface.td      | 42 ++++++-------
 .../AffineToStandard/AffineToStandard.cpp     |  4 +-
 .../SCFToControlFlow/SCFToControlFlow.cpp     |  9 +--
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 27 ++++----
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |  2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 26 ++++----
 .../Dialect/SCF/Transforms/ForallToFor.cpp    |  9 +--
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        | 62 +++++++++++++------
 10 files changed, 131 insertions(+), 89 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index bb2c29b5733b8..4c032e66f7a83 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
     [AttrSizedOperandSegments, AutomaticAllocationScope,
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-     ["getInductionVars", "getMixedLowerBound", "getMixedStep",
-      "getMixedUpperBound", "getYieldedValuesMutable",
+     ["getInductionVars", "getLowerBounds", "getSteps",
+      "getUpperBounds", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
@@ -671,7 +671,7 @@ def AffineParallelOp : Affine_Op<"parallel",
      I32ElementsAttr:$lowerBoundsGroups,
      AffineMapAttr:$upperBoundsMap,
      I32ElementsAttr:$upperBoundsGroups,
-     I64SmallVectorArrayAttr:$steps,
+     I64SmallVectorArrayAttr:$step,
      Variadic<Index>:$mapOperands);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -682,7 +682,7 @@ def AffineParallelOp : Affine_Op<"parallel",
     OpBuilder<(ins "TypeRange":$resultTypes,
       "ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
       "ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
-      "ArrayRef<int64_t>":$steps)>
+      "ArrayRef<int64_t>":$step)>
   ];
 
   let extraClassDeclaration = [{
@@ -727,7 +727,7 @@ def AffineParallelOp : Affine_Op<"parallel",
     static StringRef getUpperBoundsGroupsAttrStrName() {
       return "upperBoundsGroups";
     }
-    static StringRef getStepsAttrStrName() { return "steps"; }
+    static StringRef getStepsAttrStrName() { return "step"; }
 
     /// Returns `true` if the loop bounds have min/max expressions.
     bool hasMinMaxBounds() {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 3b28ca8b21d0f..66b478f141b32 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
-        "getInductionVars", "getMixedLowerBound", "getMixedStep",
-        "getMixedUpperBound", "getYieldedValuesMutable",
+        "getInductionVars", "getLowerBounds", "getSteps",
+        "getUpperBounds", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -302,7 +302,7 @@ def ForallOp : SCF_Op<"forall", [
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
           ["getInitsMutable", "getRegionIterArgs", "getInductionVars", 
-           "getMixedLowerBound", "getMixedUpperBound", "getMixedStep",
+           "getLowerBounds", "getUpperBounds", "getSteps",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,6 +510,27 @@ def ForallOp : SCF_Op<"forall", [
   ];
 
   let extraClassDeclaration = [{
+    // Get lower bounds as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedLowerBound() {
+      auto maybeLowerBounds = getLowerBounds();
+      assert(maybeLowerBounds.has_value() && "expected values");
+      return *maybeLowerBounds;
+    }
+
+    // Get upper bounds as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedUpperBound() {
+      auto maybeUpperBounds = getUpperBounds();
+      assert(maybeUpperBounds.has_value() && "expected values");
+      return *maybeUpperBounds;
+    }
+
+    // Get steps as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedStep() {
+      auto maybeSteps = getSteps();
+      assert(maybeSteps.has_value() && "expected values");
+      return *maybeSteps;
+    }
+
     /// Get lower bounds as values.
     SmallVector<Value> getLowerBound(OpBuilder &b) {
       return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
@@ -744,7 +765,7 @@ def ParallelOp : SCF_Op<"parallel",
     [AutomaticAllocationScope,
      AttrSizedOperandSegments,
      DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
-          "getMixedLowerBound", "getMixedUpperBound", "getMixedStep"]>,
+          "getLowerBounds", "getUpperBounds", "getSteps"]>,
      RecursiveMemoryEffects,
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      SingleBlockImplicitTerminator<"scf::ReduceOp">,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 5cf3eba0bd9ed..cc79d026c8d4e 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -106,34 +106,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     InterfaceMethod<[{
         Return all lower bounds.
       }],
-      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
-      /*methodName=*/"getMixedLowerBound",
+      /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"getLowerBounds",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return {};
+        return std::nullopt;
       }]
     >,
     InterfaceMethod<[{
         Return all steps.
       }],
-      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
-      /*methodName=*/"getMixedStep",
+      /*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"getSteps",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return {};
+        return std::nullopt;
       }]
     >,
     InterfaceMethod<[{
         Return all upper bounds.
       }],
-      /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
-      /*methodName=*/"getMixedUpperBound",
+      /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"getUpperBounds",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return {};
+        return std::nullopt;
       }]
     >,
     InterfaceMethod<[{
@@ -242,26 +242,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// Return the single lower bound value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
-      auto lowerBounds = this->getMixedLowerBound();
-      if (lowerBounds.size() == 1)
-          return lowerBounds[0];
-        return std::nullopt;
+      auto lowerBounds = this->getLowerBounds();
+      if (lowerBounds.has_value() && (*lowerBounds).size() == 1)
+          return (*lowerBounds)[0];
+      return std::nullopt;
     }
     /// Return the single step value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleStep() {
-      auto steps = this->getMixedStep(); 
-      if (steps.size() == 1)
-          return steps[0];
-        return std::nullopt;
+      auto steps = this->getSteps(); 
+      if (steps.has_value() && (*steps).size() == 1)
+          return (*steps)[0];
+      return std::nullopt;
     }
     /// Return the single upper bound value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
-      auto upperBounds = this->getMixedUpperBound();
-      if (upperBounds.size() == 1)
-          return upperBounds[0];
-        return std::nullopt;
+      auto upperBounds = this->getUpperBounds();
+      if (upperBounds.has_value() && (*upperBounds).size() == 1)
+          return (*upperBounds)[0];
+      return std::nullopt;
     }
 
     /// Append the specified additional "init" operands: replace this loop with
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 10ccd5c97783b..20487b32e3fe0 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -196,8 +196,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
         return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
       upperBoundTuple.push_back(upper);
     }
-    steps.reserve(op.getSteps().size());
-    for (int64_t step : op.getSteps())
+    steps.reserve(op.getStep().size());
+    for (int64_t step : op.getStep())
       steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
 
     // Get the terminator op.
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 9eb8a289d7d65..48e1d88c1c75e 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -695,12 +695,9 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
         "only fully bufferized scf.forall ops can be lowered to scf.parallel");
 
   // Convert mixed bounds and steps to SSA values.
-  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedLowerBound());
-  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedUpperBound());
-  SmallVector<Value> steps =
-      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+  SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
+  SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
+  SmallVector<Value> steps = forallOp.getStep(rewriter);
 
   // Create empty scf.parallel op.
   auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 746a9c919560c..d3f034a0660ba 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2456,23 +2456,26 @@ SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
 ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; }
 
-SmallVector<OpFoldResult> AffineForOp::getMixedLowerBound() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
   if (!hasConstantLowerBound())
-    return {};
+    return std::nullopt;
   OpBuilder b(getContext());
-  return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
+  return SmallVector<OpFoldResult>{
+      OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
 }
 
-SmallVector<OpFoldResult> AffineForOp::getMixedStep() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getSteps() {
   OpBuilder b(getContext());
-  return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
+  return SmallVector<OpFoldResult>{
+      OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
 }
 
-SmallVector<OpFoldResult> AffineForOp::getMixedUpperBound() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getUpperBounds() {
   if (!hasConstantUpperBound())
     return {};
   OpBuilder b(getContext());
-  return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
+  return SmallVector<OpFoldResult>{
+      OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
 }
 
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
@@ -3753,7 +3756,7 @@ SmallVector<Region *> AffineParallelOp::getLoopRegions() {
   return {&getRegion()};
 }
 
-unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
+unsigned AffineParallelOp::getNumDims() { return getStep().size(); }
 
 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
   return getOperands().take_front(getLowerBoundsMap().getNumInputs());
@@ -3838,7 +3841,7 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
 }
 
 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
-  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
+  setStepAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
 }
 
 // check whether resultType match op or not in affine.parallel
@@ -3888,14 +3891,14 @@ LogicalResult AffineParallelOp::verify() {
   auto numDims = getNumDims();
   if (getLowerBoundsGroups().getNumElements() != numDims ||
       getUpperBoundsGroups().getNumElements() != numDims ||
-      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
+      getStep().size() != numDims || getBody()->getNumArguments() != numDims) {
     return emitOpError() << "the number of region arguments ("
                          << getBody()->getNumArguments()
                          << ") and the number of map groups for lower ("
                          << getLowerBoundsGroups().getNumElements()
                          << ") and upper bound ("
                          << getUpperBoundsGroups().getNumElements()
-                         << "), and the number of steps (" << getSteps().size()
+                         << "), and the number of steps (" << getStep().size()
                          << ") must all match";
   }
 
@@ -4013,7 +4016,7 @@ void AffineParallelOp::print(OpAsmPrinter &p) {
   printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
                    getUpperBoundsOperands(), "min");
   p << ')';
-  SmallVector<int64_t, 8> steps = getSteps();
+  SmallVector<int64_t, 8> steps = getStep();
   bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
   if (!elideSteps) {
     p << " step (";
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index f46381403bc52..a652ee4a488d1 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -494,7 +494,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
     return;
 
   AffineMap lbMap = op.getLowerBoundsMap();
-  SmallVector<int64_t, 8> steps = op.getSteps();
+  SmallVector<int64_t, 8> steps = op.getStep();
   // No need to do any work if the parallel op is already normalized.
   bool isAlreadyNormalized =
       llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e275ff1849c10..281d73afee4a8 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -380,16 +380,16 @@ LogicalResult ForOp::verifyRegions() {
 
 ValueRange ForOp::getInductionVars() { return {getInductionVar()}; }
 
-SmallVector<OpFoldResult> ForOp::getMixedLowerBound() {
-  return {OpFoldResult(getLowerBound())};
+std::optional<SmallVector<OpFoldResult>> ForOp::getLowerBounds() {
+  return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
 }
 
-SmallVector<OpFoldResult> ForOp::getMixedStep() {
-  return {OpFoldResult(getStep())};
+std::optional<SmallVector<OpFoldResult>> ForOp::getSteps() {
+  return SmallVector<OpFoldResult, 1>{OpFoldResult(getStep())};
 }
 
-SmallVector<OpFoldResult> ForOp::getMixedUpperBound() {
-  return {OpFoldResult(getUpperBound())};
+std::optional<SmallVector<OpFoldResult>> ForOp::getUpperBounds() {
+  return SmallVector<OpFoldResult, 1>{OpFoldResult(getUpperBound())};
 }
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
@@ -1431,19 +1431,19 @@ ValueRange ForallOp::getInductionVars() {
 }
 
 // Get lower bounds as OpFoldResult.
-SmallVector<OpFoldResult> ForallOp::getMixedLowerBound() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLowerBounds() {
   Builder b(getOperation()->getContext());
   return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
 }
 
 // Get upper bounds as OpFoldResult.
-SmallVector<OpFoldResult> ForallOp::getMixedUpperBound() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getUpperBounds() {
   Builder b(getOperation()->getContext());
   return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
 }
 
 // Get steps as OpFoldResult.
-SmallVector<OpFoldResult> ForallOp::getMixedStep() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getSteps() {
   Builder b(getOperation()->getContext());
   return getMixedValues(getStaticStep(), getDynamicStep(), b);
 }
@@ -3006,15 +3006,17 @@ SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
 ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
 
-SmallVector<OpFoldResult> ParallelOp::getMixedLowerBound() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {
   return getLowerBound();
 }
 
-SmallVector<OpFoldResult> ParallelOp::getMixedUpperBound() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getUpperBounds() {
   return getUpperBound();
 }
 
-SmallVector<OpFoldResult> ParallelOp::getMixedStep() { return getStep(); }
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getSteps() {
+  return getStep();
+}
 
 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index 198cb2e6cc69e..5da1b76e929be 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -34,12 +34,9 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
   rewriter.setInsertionPoint(forallOp);
 
   Location loc = forallOp.getLoc();
-  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedLowerBound());
-  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedUpperBound());
-  SmallVector<Value> steps =
-      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+  SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
+  SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
+  SmallVector<Value> steps = forallOp.getStep(rewriter);
   LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
 
   SmallVector<Value> ivs = llvm::map_to_vector(
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index d8cdb213070da..07504a99fecd3 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -27,35 +27,57 @@ class SCFLoopLikeTest : public ::testing::Test {
   }
 
   void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
-    std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
+    std::optional<OpFoldResult> maybeSingleLb =
+        loopLikeOp.getSingleLowerBound();
+    EXPECT_TRUE(maybeSingleLb.has_value());
+    std::optional<OpFoldResult> maybeSingleUb =
+        loopLikeOp.getSingleUpperBound();
+    EXPECT_TRUE(maybeSingleUb.has_value());
+    std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
+    EXPECT_TRUE(maybeSingleStep.has_value());
+    std::optional<OpFoldResult> maybeSingleIndVar =
+        loopLikeOp.getSingleInductionVar();
+    EXPECT_TRUE(maybeSingleIndVar.has_value());
+
+    std::optional<SmallVector<OpFoldResult>> maybeLb =
+        loopLikeOp.getLowerBounds();
     EXPECT_TRUE(maybeLb.has_value());
-    std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
+    EXPECT_EQ((*maybeLb).size(), 1u);
+    std::optional<SmallVector<OpFoldResult>> maybeUb =
+        loopLikeOp.getUpperBounds();
     EXPECT_TRUE(maybeUb.has_value());
-    std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
+    EXPECT_EQ((*maybeUb).size(), 1u);
+    std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
     EXPECT_TRUE(maybeStep.has_value());
-    std::optional<OpFoldResult> maybeIndVar =
-        loopLikeOp.getSingleInductionVar();
-    EXPECT_TRUE(maybeIndVar.has_value());
+    EXPECT_EQ((*maybeStep).size(), 1u);
     EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
-    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
-    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u);
-    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
   }
 
   void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
-    std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
-    EXPECT_FALSE(maybeLb.has_value());
-    std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
-    EXPECT_FALSE(maybeUb.has_value());
-    std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
-    EXPECT_FALSE(maybeStep.has_value());
-    std::optional<OpFoldResult> maybeIndVar =
+    std::optional<OpFoldResult> maybeSingleLb =
+        loopLikeOp.getSingleLowerBound();
+    EXPECT_FALSE(maybeSingleLb.has_value());
+    std::optional<OpFoldResult> maybeSingleUb =
+        loopLikeOp.getSingleUpperBound();
+    EXPECT_FALSE(maybeSingleUb.has_value());
+    std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
+    EXPECT_FALSE(maybeSingleStep.has_value());
+    std::optional<OpFoldResult> maybeSingleIndVar =
         loopLikeOp.getSingleInductionVar();
-    EXPECT_FALSE(maybeIndVar.has_value());
+    EXPECT_FALSE(maybeSingleIndVar.has_value());
+
+    std::optional<SmallVector<OpFoldResult>> maybeLb =
+        loopLikeOp.getLowerBounds();
+    EXPECT_TRUE(maybeLb.has_value());
+    EXPECT_EQ((*maybeLb).size(), 2u);
+    std::optional<SmallVector<OpFoldResult>> maybeUb =
+        loopLikeOp.getUpperBounds();
+    EXPECT_TRUE(maybeUb.has_value());
+    EXPECT_EQ((*maybeUb).size(), 2u);
+    std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
+    EXPECT_TRUE(maybeStep.has_value());
+    EXPECT_EQ((*maybeStep).size(), 2u);
     EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
-    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
-    EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u);
-    EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
   }
 
   MLIRContext context;

>From 1babe681d7858a4992303c62e22684cb73d82472 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 11:31:11 -0500
Subject: [PATCH 07/34] change return type of getInductionVars to
 SmallVector<Value>

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td |  2 +-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp          |  4 +++-
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp      |  3 +--
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp     |  6 +++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp                   | 10 ++++++----
 5 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index cc79d026c8d4e..bace8f8384d44 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -95,7 +95,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     InterfaceMethod<[{
         Return all induction variables.
       }],
-      /*retTy=*/"::mlir::ValueRange",
+      /*retTy=*/"::llvm::SmallVector<::mlir::Value>",
       /*methodName=*/"getInductionVars",
       /*args=*/(ins),
       /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d3f034a0660ba..5467c60242664 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,7 +2454,9 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
-ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; }
+SmallVector<Value> AffineForOp::getInductionVars() {
+  return {getInductionVar()};
+}
 
 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
   if (!hasConstantLowerBound())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e86..8b0e04fb61b1b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -184,8 +184,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
   for (Operation *loopOp : loopOps) {
     llvm::TypeSwitch<Operation *>(loopOp)
         .Case([&](scf::ParallelOp parallelOp) {
-          allIvs.append(parallelOp.getInductionVars().begin(),
-                        parallelOp.getInductionVars().end());
+          allIvs.append(parallelOp.getInductionVars());
         })
         .Case([&](scf::ForOp forOp) {
           allIvs.push_back(forOp.getInductionVar());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index fd314ef9f8134..4eacaa8d1e327 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -243,7 +243,7 @@ static void calculateTileOffsetsAndSizes(
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(forallOp.getBody(0));
 
-  ValueRange threadIds = forallOp.getInductionVars();
+  auto threadIds = forallOp.getInductionVars();
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
         return !isConstantIntValue(ofr, 0);
@@ -746,7 +746,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
                                            b.getIndexAttr(0));
       SmallVector<OpFoldResult> sizes = tiledSizes;
       sizes[reductionDim] = b.getIndexAttr(1);
-      outOffsets[reductionDim] = forallOp.getInductionVars().front();
+      outOffsets[reductionDim] = forallOp.getInductionVars()[0];
       // TODO: use SubsetExtractOpInterface once it is available.
       tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
           loc, cast<RankedTensorType>(initOperand.getType()),
@@ -814,7 +814,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
     int64_t sizeIdx = 0;
     for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
       if (i == reductionDim) {
-        resultOffsetsRank.push_back(forallOp.getInductionVars().front());
+        resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
         resultSizesRank.push_back(b.getIndexAttr(1));
         continue;
       }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 281d73afee4a8..0ce10ebdad3e2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,7 +378,7 @@ LogicalResult ForOp::verifyRegions() {
   return success();
 }
 
-ValueRange ForOp::getInductionVars() { return {getInductionVar()}; }
+SmallVector<Value> ForOp::getInductionVars() { return {getInductionVar()}; }
 
 std::optional<SmallVector<OpFoldResult>> ForOp::getLowerBounds() {
   return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
@@ -1426,8 +1426,8 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
-ValueRange ForallOp::getInductionVars() {
-  return getBody()->getArguments().take_front(getRank());
+SmallVector<Value> ForallOp::getInductionVars() {
+  return SmallVector<Value>(getBody()->getArguments().take_front(getRank()));
 }
 
 // Get lower bounds as OpFoldResult.
@@ -3004,7 +3004,9 @@ void ParallelOp::print(OpAsmPrinter &p) {
 
 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
-ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
+SmallVector<Value> ParallelOp::getInductionVars() {
+  return SmallVector<Value>(getBody()->getArguments());
+}
 
 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {
   return getLowerBound();

>From 009fd15ab8abefd56afe6424e27f912a4166329d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 14:02:52 -0500
Subject: [PATCH 08/34] address maks's comments

---
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4eacaa8d1e327..a0a0e11a6903d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -243,7 +243,7 @@ static void calculateTileOffsetsAndSizes(
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(forallOp.getBody(0));
 
-  auto threadIds = forallOp.getInductionVars();
+  SmallVector<Value> threadIds = forallOp.getInductionVars();
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
         return !isConstantIntValue(ofr, 0);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0ce10ebdad3e2..a930f8c71454c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1427,7 +1427,7 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
 }
 
 SmallVector<Value> ForallOp::getInductionVars() {
-  return SmallVector<Value>(getBody()->getArguments().take_front(getRank()));
+  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
 }
 
 // Get lower bounds as OpFoldResult.
@@ -3005,7 +3005,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
 SmallVector<Value> ParallelOp::getInductionVars() {
-  return SmallVector<Value>(getBody()->getArguments());
+  return SmallVector<Value>{getBody()->getArguments()};
 }
 
 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {

>From d34ad95aba669b5700976f0d2ed4d68b4902e9be Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 15:26:06 -0500
Subject: [PATCH 09/34] change interface method names again and revert steps
 operand change

---
 .../mlir/Dialect/Affine/IR/AffineOps.td       | 10 +++----
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 27 ++++++++++++-------
 .../mlir/Interfaces/LoopLikeInterface.td      | 16 +++++------
 .../AffineToStandard/AffineToStandard.cpp     |  4 +--
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 18 ++++++-------
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |  2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 24 ++++++++---------
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        | 18 +++++++------
 8 files changed, 64 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 4c032e66f7a83..dbec741cf1b1f 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
     [AttrSizedOperandSegments, AutomaticAllocationScope,
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-     ["getInductionVars", "getLowerBounds", "getSteps",
-      "getUpperBounds", "getYieldedValuesMutable",
+     ["getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
+      "getLoopUpperBounds", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
@@ -671,7 +671,7 @@ def AffineParallelOp : Affine_Op<"parallel",
      I32ElementsAttr:$lowerBoundsGroups,
      AffineMapAttr:$upperBoundsMap,
      I32ElementsAttr:$upperBoundsGroups,
-     I64SmallVectorArrayAttr:$step,
+     I64SmallVectorArrayAttr:$steps,
      Variadic<Index>:$mapOperands);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -682,7 +682,7 @@ def AffineParallelOp : Affine_Op<"parallel",
     OpBuilder<(ins "TypeRange":$resultTypes,
       "ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
       "ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
-      "ArrayRef<int64_t>":$step)>
+      "ArrayRef<int64_t>":$steps)>
   ];
 
   let extraClassDeclaration = [{
@@ -727,7 +727,7 @@ def AffineParallelOp : Affine_Op<"parallel",
     static StringRef getUpperBoundsGroupsAttrStrName() {
       return "upperBoundsGroups";
     }
-    static StringRef getStepsAttrStrName() { return "step"; }
+    static StringRef getStepsAttrStrName() { return "steps"; }
 
     /// Returns `true` if the loop bounds have min/max expressions.
     bool hasMinMaxBounds() {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 66b478f141b32..3704b15972278 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
-        "getInductionVars", "getLowerBounds", "getSteps",
-        "getUpperBounds", "getYieldedValuesMutable",
+        "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
+        "getLoopUpperBounds", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getInductionVars", 
-           "getLowerBounds", "getUpperBounds", "getSteps",
+          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
+           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,23 +510,26 @@ def ForallOp : SCF_Op<"forall", [
   ];
 
   let extraClassDeclaration = [{
+    SmallVector<Value> getInductionVars() {
+      return getLoopInductionVars();
+    }
     // Get lower bounds as OpFoldResult.
     SmallVector<OpFoldResult> getMixedLowerBound() {
-      auto maybeLowerBounds = getLowerBounds();
+      auto maybeLowerBounds = getLoopLowerBounds();
       assert(maybeLowerBounds.has_value() && "expected values");
       return *maybeLowerBounds;
     }
 
     // Get upper bounds as OpFoldResult.
     SmallVector<OpFoldResult> getMixedUpperBound() {
-      auto maybeUpperBounds = getUpperBounds();
+      auto maybeUpperBounds = getLoopUpperBounds();
       assert(maybeUpperBounds.has_value() && "expected values");
       return *maybeUpperBounds;
     }
 
     // Get steps as OpFoldResult.
     SmallVector<OpFoldResult> getMixedStep() {
-      auto maybeSteps = getSteps();
+      auto maybeSteps = getLoopSteps();
       assert(maybeSteps.has_value() && "expected values");
       return *maybeSteps;
     }
@@ -588,7 +591,7 @@ def ForallOp : SCF_Op<"forall", [
     }
 
     ::mlir::Value getInductionVar(int64_t idx) {
-      return getInductionVars()[idx];
+      return getLoopInductionVars()[idx];
     }
 
     ::mlir::Block::BlockArgListType getRegionOutArgs() {
@@ -764,8 +767,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
 def ParallelOp : SCF_Op<"parallel",
     [AutomaticAllocationScope,
      AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
-          "getLowerBounds", "getUpperBounds", "getSteps"]>,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getLoopInductionVars",
+          "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps"]>,
      RecursiveMemoryEffects,
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -845,6 +848,10 @@ def ParallelOp : SCF_Op<"parallel",
   ];
 
   let extraClassDeclaration = [{
+    // Get induction variables.
+    SmallVector<Value> getInductionVars() {
+      return getLoopInductionVars();
+    }
     unsigned getNumLoops() { return getStep().size(); }
     unsigned getNumReductions() { return getInitVals().size(); }
   }];
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index bace8f8384d44..5312ace4db68e 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -96,7 +96,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         Return all induction variables.
       }],
       /*retTy=*/"::llvm::SmallVector<::mlir::Value>",
-      /*methodName=*/"getInductionVars",
+      /*methodName=*/"getLoopInductionVars",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
@@ -107,7 +107,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         Return all lower bounds.
       }],
       /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
-      /*methodName=*/"getLowerBounds",
+      /*methodName=*/"getLoopLowerBounds",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
@@ -118,7 +118,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         Return all steps.
       }],
       /*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
-      /*methodName=*/"getSteps",
+      /*methodName=*/"getLoopSteps",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
@@ -129,7 +129,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         Return all upper bounds.
       }],
       /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
-      /*methodName=*/"getUpperBounds",
+      /*methodName=*/"getLoopUpperBounds",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
@@ -234,7 +234,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// If there is a single induction variable return it, otherwise return
     /// std::nullopt.
     ::std::optional<::mlir::Value> getSingleInductionVar() {
-      auto inductionVars = this->getInductionVars();
+      auto inductionVars = this->getLoopInductionVars();
       if (inductionVars.size() == 1)
           return inductionVars[0];
         return std::nullopt;
@@ -242,7 +242,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// Return the single lower bound value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
-      auto lowerBounds = this->getLowerBounds();
+      auto lowerBounds = this->getLoopLowerBounds();
       if (lowerBounds.has_value() && (*lowerBounds).size() == 1)
           return (*lowerBounds)[0];
       return std::nullopt;
@@ -250,7 +250,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// Return the single step value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleStep() {
-      auto steps = this->getSteps(); 
+      auto steps = this->getLoopSteps(); 
       if (steps.has_value() && (*steps).size() == 1)
           return (*steps)[0];
       return std::nullopt;
@@ -258,7 +258,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// Return the single upper bound value or attribute if it exists, otherwise
     /// return std::nullopt.
     ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
-      auto upperBounds = this->getUpperBounds();
+      auto upperBounds = this->getLoopUpperBounds();
       if (upperBounds.has_value() && (*upperBounds).size() == 1)
           return (*upperBounds)[0];
       return std::nullopt;
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 20487b32e3fe0..10ccd5c97783b 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -196,8 +196,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
         return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
       upperBoundTuple.push_back(upper);
     }
-    steps.reserve(op.getStep().size());
-    for (int64_t step : op.getStep())
+    steps.reserve(op.getSteps().size());
+    for (int64_t step : op.getSteps())
       steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
 
     // Get the terminator op.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5467c60242664..d5cb04743dfb9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,11 +2454,11 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
-SmallVector<Value> AffineForOp::getInductionVars() {
+SmallVector<Value> AffineForOp::getLoopInductionVars() {
   return {getInductionVar()};
 }
 
-std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
   if (!hasConstantLowerBound())
     return std::nullopt;
   OpBuilder b(getContext());
@@ -2466,13 +2466,13 @@ std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
       OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
 }
 
-std::optional<SmallVector<OpFoldResult>> AffineForOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
   OpBuilder b(getContext());
   return SmallVector<OpFoldResult>{
       OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
 }
 
-std::optional<SmallVector<OpFoldResult>> AffineForOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
   if (!hasConstantUpperBound())
     return {};
   OpBuilder b(getContext());
@@ -3758,7 +3758,7 @@ SmallVector<Region *> AffineParallelOp::getLoopRegions() {
   return {&getRegion()};
 }
 
-unsigned AffineParallelOp::getNumDims() { return getStep().size(); }
+unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
 
 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
   return getOperands().take_front(getLowerBoundsMap().getNumInputs());
@@ -3843,7 +3843,7 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
 }
 
 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
-  setStepAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
+  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
 }
 
 // check whether resultType match op or not in affine.parallel
@@ -3893,14 +3893,14 @@ LogicalResult AffineParallelOp::verify() {
   auto numDims = getNumDims();
   if (getLowerBoundsGroups().getNumElements() != numDims ||
       getUpperBoundsGroups().getNumElements() != numDims ||
-      getStep().size() != numDims || getBody()->getNumArguments() != numDims) {
+      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
     return emitOpError() << "the number of region arguments ("
                          << getBody()->getNumArguments()
                          << ") and the number of map groups for lower ("
                          << getLowerBoundsGroups().getNumElements()
                          << ") and upper bound ("
                          << getUpperBoundsGroups().getNumElements()
-                         << "), and the number of steps (" << getStep().size()
+                         << "), and the number of steps (" << getSteps().size()
                          << ") must all match";
   }
 
@@ -4018,7 +4018,7 @@ void AffineParallelOp::print(OpAsmPrinter &p) {
   printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
                    getUpperBoundsOperands(), "min");
   p << ')';
-  SmallVector<int64_t, 8> steps = getStep();
+  SmallVector<int64_t, 8> steps = getSteps();
   bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
   if (!elideSteps) {
     p << " step (";
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index a652ee4a488d1..f46381403bc52 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -494,7 +494,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
     return;
 
   AffineMap lbMap = op.getLowerBoundsMap();
-  SmallVector<int64_t, 8> steps = op.getStep();
+  SmallVector<int64_t, 8> steps = op.getSteps();
   // No need to do any work if the parallel op is already normalized.
   bool isAlreadyNormalized =
       llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a930f8c71454c..e921177f73215 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,17 +378,17 @@ LogicalResult ForOp::verifyRegions() {
   return success();
 }
 
-SmallVector<Value> ForOp::getInductionVars() { return {getInductionVar()}; }
+SmallVector<Value> ForOp::getLoopInductionVars() { return {getInductionVar()}; }
 
-std::optional<SmallVector<OpFoldResult>> ForOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
   return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
 }
 
-std::optional<SmallVector<OpFoldResult>> ForOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
   return SmallVector<OpFoldResult, 1>{OpFoldResult(getStep())};
 }
 
-std::optional<SmallVector<OpFoldResult>> ForOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
   return SmallVector<OpFoldResult, 1>{OpFoldResult(getUpperBound())};
 }
 
@@ -1426,24 +1426,24 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
-SmallVector<Value> ForallOp::getInductionVars() {
+SmallVector<Value> ForallOp::getLoopInductionVars() {
   return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
 }
 
 // Get lower bounds as OpFoldResult.
-std::optional<SmallVector<OpFoldResult>> ForallOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
   Builder b(getOperation()->getContext());
   return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
 }
 
 // Get upper bounds as OpFoldResult.
-std::optional<SmallVector<OpFoldResult>> ForallOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
   Builder b(getOperation()->getContext());
   return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
 }
 
 // Get steps as OpFoldResult.
-std::optional<SmallVector<OpFoldResult>> ForallOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
   Builder b(getOperation()->getContext());
   return getMixedValues(getStaticStep(), getDynamicStep(), b);
 }
@@ -3004,19 +3004,19 @@ void ParallelOp::print(OpAsmPrinter &p) {
 
 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
-SmallVector<Value> ParallelOp::getInductionVars() {
+SmallVector<Value> ParallelOp::getLoopInductionVars() {
   return SmallVector<Value>{getBody()->getArguments()};
 }
 
-std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
   return getLowerBound();
 }
 
-std::optional<SmallVector<OpFoldResult>> ParallelOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
   return getUpperBound();
 }
 
-std::optional<SmallVector<OpFoldResult>> ParallelOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
   return getStep();
 }
 
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 07504a99fecd3..75cd2bfb01de0 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -40,17 +40,18 @@ class SCFLoopLikeTest : public ::testing::Test {
     EXPECT_TRUE(maybeSingleIndVar.has_value());
 
     std::optional<SmallVector<OpFoldResult>> maybeLb =
-        loopLikeOp.getLowerBounds();
+        loopLikeOp.getLoopLowerBounds();
     EXPECT_TRUE(maybeLb.has_value());
     EXPECT_EQ((*maybeLb).size(), 1u);
     std::optional<SmallVector<OpFoldResult>> maybeUb =
-        loopLikeOp.getUpperBounds();
+        loopLikeOp.getLoopUpperBounds();
     EXPECT_TRUE(maybeUb.has_value());
     EXPECT_EQ((*maybeUb).size(), 1u);
-    std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
+    std::optional<SmallVector<OpFoldResult>> maybeStep =
+        loopLikeOp.getLoopSteps();
     EXPECT_TRUE(maybeStep.has_value());
     EXPECT_EQ((*maybeStep).size(), 1u);
-    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 1u);
   }
 
   void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -67,17 +68,18 @@ class SCFLoopLikeTest : public ::testing::Test {
     EXPECT_FALSE(maybeSingleIndVar.has_value());
 
     std::optional<SmallVector<OpFoldResult>> maybeLb =
-        loopLikeOp.getLowerBounds();
+        loopLikeOp.getLoopLowerBounds();
     EXPECT_TRUE(maybeLb.has_value());
     EXPECT_EQ((*maybeLb).size(), 2u);
     std::optional<SmallVector<OpFoldResult>> maybeUb =
-        loopLikeOp.getUpperBounds();
+        loopLikeOp.getLoopUpperBounds();
     EXPECT_TRUE(maybeUb.has_value());
     EXPECT_EQ((*maybeUb).size(), 2u);
-    std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
+    std::optional<SmallVector<OpFoldResult>> maybeStep =
+        loopLikeOp.getLoopSteps();
     EXPECT_TRUE(maybeStep.has_value());
     EXPECT_EQ((*maybeStep).size(), 2u);
-    EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 2u);
   }
 
   MLIRContext context;

>From e0e526210a5ab6ce28fbc5fa5ee24f79cb1ee9a8 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 16:04:47 -0500
Subject: [PATCH 10/34] return optional induction vars

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td        | 10 +++++++---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td |  8 ++++----
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp          |  4 ++--
 mlir/lib/Dialect/SCF/IR/SCF.cpp                   |  8 +++++---
 mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp | 10 ++++++++--
 5 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 3704b15972278..d425c1c2a47b4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -511,7 +511,9 @@ def ForallOp : SCF_Op<"forall", [
 
   let extraClassDeclaration = [{
     SmallVector<Value> getInductionVars() {
-      return getLoopInductionVars();
+      auto maybeInductionVars = getLoopInductionVars();;
+      assert(maybeInductionVars.has_value() && "expected values");
+      return *maybeInductionVars;
     }
     // Get lower bounds as OpFoldResult.
     SmallVector<OpFoldResult> getMixedLowerBound() {
@@ -591,7 +593,7 @@ def ForallOp : SCF_Op<"forall", [
     }
 
     ::mlir::Value getInductionVar(int64_t idx) {
-      return getLoopInductionVars()[idx];
+      return getInductionVars()[idx];
     }
 
     ::mlir::Block::BlockArgListType getRegionOutArgs() {
@@ -850,7 +852,9 @@ def ParallelOp : SCF_Op<"parallel",
   let extraClassDeclaration = [{
     // Get induction variables.
     SmallVector<Value> getInductionVars() {
-      return getLoopInductionVars();
+      auto maybeInductionVars = getLoopInductionVars();;
+      assert(maybeInductionVars.has_value() && "expected values");
+      return *maybeInductionVars;
     }
     unsigned getNumLoops() { return getStep().size(); }
     unsigned getNumReductions() { return getInitVals().size(); }
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 5312ace4db68e..2e6aabda30b07 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -95,12 +95,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     InterfaceMethod<[{
         Return all induction variables.
       }],
-      /*retTy=*/"::llvm::SmallVector<::mlir::Value>",
+      /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::Value>>",
       /*methodName=*/"getLoopInductionVars",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return {};
+        return std::nullopt;
       }]
     >,
     InterfaceMethod<[{
@@ -235,8 +235,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// std::nullopt.
     ::std::optional<::mlir::Value> getSingleInductionVar() {
       auto inductionVars = this->getLoopInductionVars();
-      if (inductionVars.size() == 1)
-          return inductionVars[0];
+      if (inductionVars.has_value() && (*inductionVars).size() == 1)
+          return (*inductionVars)[0];
         return std::nullopt;
     }
     /// Return the single lower bound value or attribute if it exists, otherwise
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d5cb04743dfb9..0a58d2fdb02f5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,8 +2454,8 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
-SmallVector<Value> AffineForOp::getLoopInductionVars() {
-  return {getInductionVar()};
+std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
+  return SmallVector<Value>{getInductionVar()};
 }
 
 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e921177f73215..c00579443ea29 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,7 +378,9 @@ LogicalResult ForOp::verifyRegions() {
   return success();
 }
 
-SmallVector<Value> ForOp::getLoopInductionVars() { return {getInductionVar()}; }
+std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
+  return SmallVector<Value, 1>{getInductionVar()};
+}
 
 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
   return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
@@ -1426,7 +1428,7 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
-SmallVector<Value> ForallOp::getLoopInductionVars() {
+std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
   return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
 }
 
@@ -3004,7 +3006,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
 
 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
-SmallVector<Value> ParallelOp::getLoopInductionVars() {
+std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
   return SmallVector<Value>{getBody()->getArguments()};
 }
 
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 75cd2bfb01de0..20dbc8d362d27 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -51,7 +51,10 @@ class SCFLoopLikeTest : public ::testing::Test {
         loopLikeOp.getLoopSteps();
     EXPECT_TRUE(maybeStep.has_value());
     EXPECT_EQ((*maybeStep).size(), 1u);
-    EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 1u);
+    std::optional<SmallVector<Value>> maybeInductionVars =
+        loopLikeOp.getLoopInductionVars();
+    EXPECT_TRUE(maybeInductionVars.has_value());
+    EXPECT_EQ((*maybeInductionVars).size(), 1u);
   }
 
   void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -79,7 +82,10 @@ class SCFLoopLikeTest : public ::testing::Test {
         loopLikeOp.getLoopSteps();
     EXPECT_TRUE(maybeStep.has_value());
     EXPECT_EQ((*maybeStep).size(), 2u);
-    EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 2u);
+    std::optional<SmallVector<Value>> maybeInductionVars =
+        loopLikeOp.getLoopInductionVars();
+    EXPECT_TRUE(maybeInductionVars.has_value());
+    EXPECT_EQ((*maybeInductionVars).size(), 2u);
   }
 
   MLIRContext context;

>From 7115a6e08bba43fe9750f8cef5c73f6be1b373fd Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 7 Jun 2024 11:45:44 -0500
Subject: [PATCH 11/34] address review comments

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 19 ++++++------
 .../mlir/Interfaces/LoopLikeInterface.td      | 30 +++++++++++++------
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  8 ++---
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        | 16 +++++-----
 4 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d425c1c2a47b4..f35ea962bea16 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -510,28 +510,29 @@ def ForallOp : SCF_Op<"forall", [
   ];
 
   let extraClassDeclaration = [{
+    /// Get induction variables.
     SmallVector<Value> getInductionVars() {
-      auto maybeInductionVars = getLoopInductionVars();;
+      std::optional<SmallVector<Value>> maybeInductionVars = getLoopInductionVars();
       assert(maybeInductionVars.has_value() && "expected values");
       return *maybeInductionVars;
     }
-    // Get lower bounds as OpFoldResult.
+    /// Get lower bounds as OpFoldResult.
     SmallVector<OpFoldResult> getMixedLowerBound() {
-      auto maybeLowerBounds = getLoopLowerBounds();
+      std::optional<SmallVector<OpFoldResult>> maybeLowerBounds = getLoopLowerBounds();
       assert(maybeLowerBounds.has_value() && "expected values");
       return *maybeLowerBounds;
     }
 
-    // Get upper bounds as OpFoldResult.
+    /// Get upper bounds as OpFoldResult.
     SmallVector<OpFoldResult> getMixedUpperBound() {
-      auto maybeUpperBounds = getLoopUpperBounds();
+      std::optional<SmallVector<OpFoldResult>> maybeUpperBounds = getLoopUpperBounds();
       assert(maybeUpperBounds.has_value() && "expected values");
       return *maybeUpperBounds;
     }
 
-    // Get steps as OpFoldResult.
+    /// Get steps as OpFoldResult.
     SmallVector<OpFoldResult> getMixedStep() {
-      auto maybeSteps = getLoopSteps();
+      std::optional<SmallVector<OpFoldResult>> maybeSteps = getLoopSteps();
       assert(maybeSteps.has_value() && "expected values");
       return *maybeSteps;
     }
@@ -850,9 +851,9 @@ def ParallelOp : SCF_Op<"parallel",
   ];
 
   let extraClassDeclaration = [{
-    // Get induction variables.
+    /// Get induction variables.
     SmallVector<Value> getInductionVars() {
-      auto maybeInductionVars = getLoopInductionVars();;
+      std::optional<SmallVector<Value>> maybeInductionVars = getLoopInductionVars();;
       assert(maybeInductionVars.has_value() && "expected values");
       return *maybeInductionVars;
     }
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 2e6aabda30b07..b748d5e29114a 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -93,47 +93,59 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        Return all induction variables.
+        Return all induction variables, if they exist. If the op has no notion of
+        induction variable, then return std::nullopt. If it does have
+        a notion but an instance doesn't have induction variables, then
+        return empty vector.
       }],
       /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::Value>>",
       /*methodName=*/"getLoopInductionVars",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return ::std::nullopt;
       }]
     >,
     InterfaceMethod<[{
-        Return all lower bounds.
+        Return all lower bounds, if they exist. If the op has no notion of
+        lower bounds, then return std::nullopt. If it does have
+        a notion but an instance doesn't have lower bounds, then
+        return empty vector.
       }],
       /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
       /*methodName=*/"getLoopLowerBounds",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return ::std::nullopt;
       }]
     >,
     InterfaceMethod<[{
-        Return all steps.
+        Return all steps, if they exist. If the op has no notion of
+        steps, then return std::nullopt. If it does have
+        a notion but an instance doesn't have steps, then
+        return empty vector.
       }],
-      /*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
       /*methodName=*/"getLoopSteps",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return ::std::nullopt;
       }]
     >,
     InterfaceMethod<[{
-        Return all upper bounds.
+        Return all upper bounds, if they exist. If the op has no notion of
+        upper bounds, then return std::nullopt. If it does have
+        a notion but an instance doesn't have upper bounds, then
+        return empty vector.
       }],
       /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
       /*methodName=*/"getLoopUpperBounds",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return std::nullopt;
+        return ::std::nullopt;
       }]
     >,
     InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c00579443ea29..5e94f4dc612a7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -379,19 +379,19 @@ LogicalResult ForOp::verifyRegions() {
 }
 
 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
-  return SmallVector<Value, 1>{getInductionVar()};
+  return SmallVector<Value>{getInductionVar()};
 }
 
 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
-  return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
+  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
 }
 
 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
-  return SmallVector<OpFoldResult, 1>{OpFoldResult(getStep())};
+  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
 }
 
 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
-  return SmallVector<OpFoldResult, 1>{OpFoldResult(getUpperBound())};
+  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
 }
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 20dbc8d362d27..53a4af14d119a 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -41,19 +41,19 @@ class SCFLoopLikeTest : public ::testing::Test {
 
     std::optional<SmallVector<OpFoldResult>> maybeLb =
         loopLikeOp.getLoopLowerBounds();
-    EXPECT_TRUE(maybeLb.has_value());
+    ASSERT_TRUE(maybeLb.has_value());
     EXPECT_EQ((*maybeLb).size(), 1u);
     std::optional<SmallVector<OpFoldResult>> maybeUb =
         loopLikeOp.getLoopUpperBounds();
-    EXPECT_TRUE(maybeUb.has_value());
+    ASSERT_TRUE(maybeUb.has_value());
     EXPECT_EQ((*maybeUb).size(), 1u);
     std::optional<SmallVector<OpFoldResult>> maybeStep =
         loopLikeOp.getLoopSteps();
-    EXPECT_TRUE(maybeStep.has_value());
+    ASSERT_TRUE(maybeStep.has_value());
     EXPECT_EQ((*maybeStep).size(), 1u);
     std::optional<SmallVector<Value>> maybeInductionVars =
         loopLikeOp.getLoopInductionVars();
-    EXPECT_TRUE(maybeInductionVars.has_value());
+    ASSERT_TRUE(maybeInductionVars.has_value());
     EXPECT_EQ((*maybeInductionVars).size(), 1u);
   }
 
@@ -72,19 +72,19 @@ class SCFLoopLikeTest : public ::testing::Test {
 
     std::optional<SmallVector<OpFoldResult>> maybeLb =
         loopLikeOp.getLoopLowerBounds();
-    EXPECT_TRUE(maybeLb.has_value());
+    ASSERT_TRUE(maybeLb.has_value());
     EXPECT_EQ((*maybeLb).size(), 2u);
     std::optional<SmallVector<OpFoldResult>> maybeUb =
         loopLikeOp.getLoopUpperBounds();
-    EXPECT_TRUE(maybeUb.has_value());
+    ASSERT_TRUE(maybeUb.has_value());
     EXPECT_EQ((*maybeUb).size(), 2u);
     std::optional<SmallVector<OpFoldResult>> maybeStep =
         loopLikeOp.getLoopSteps();
-    EXPECT_TRUE(maybeStep.has_value());
+    ASSERT_TRUE(maybeStep.has_value());
     EXPECT_EQ((*maybeStep).size(), 2u);
     std::optional<SmallVector<Value>> maybeInductionVars =
         loopLikeOp.getLoopInductionVars();
-    EXPECT_TRUE(maybeInductionVars.has_value());
+    ASSERT_TRUE(maybeInductionVars.has_value());
     EXPECT_EQ((*maybeInductionVars).size(), 2u);
   }
 

>From 6336fdf28f06c7525bcf6822a386f8c4cabe3c2d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 7 Jun 2024 15:25:40 -0500
Subject: [PATCH 12/34] update after rebase

---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index ce20730459c2a..e3660e89fb684 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1198,9 +1198,9 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
 bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
                                          LoopLikeOpInterface &source) {
   auto iterSpaceEq =
-      target.getMixedLowerBound() == source.getMixedLowerBound() &&
-      target.getMixedUpperBound() == source.getMixedUpperBound() &&
-      target.getMixedStep() == source.getMixedStep();
+      target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+      target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+      target.getLoopSteps() == source.getLoopSteps();
   auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
   auto forAllSource = dyn_cast<scf::ForallOp>(*source);
   if (forAllTarget && forAllSource)

>From 86406c335dd216ac91a941102c060bc680af10b1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 8 Jun 2024 23:31:57 -0500
Subject: [PATCH 13/34] refactor main parallel fusion logic from fuseIfLegal to
 util func

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |   6 -
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 159 +++++++++++++++++-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 140 ++++++---------
 3 files changed, 208 insertions(+), 97 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ab9d154aa480d..ac4434b337890 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -163,12 +163,6 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
 bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
                                    LoopLikeOpInterface &source);
 
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
-void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
-                 OpBuilder builder,
-                 llvm::function_ref<bool(Value, Value)> mayAlias);
-
 /// Given two scf.forall loops, `target` and `source`, fuses `target` into
 /// `source`. Assumes that the given loops are siblings and are independent of
 /// each other.
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index abac91cfaf7d9..326a8f93162b9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -31,6 +31,163 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::scf;
 
+/// Verify there are no nested ParallelOps.
+static bool hasNestedParallelOp(ParallelOp ploop) {
+  auto walkResult =
+      ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
+  return walkResult.wasInterrupted();
+}
+
+/// Verify equal iteration spaces.
+static bool equalIterationSpaces(ParallelOp firstPloop,
+                                 ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  auto matchOperands = [&](const OperandRange &lhs,
+                           const OperandRange &rhs) -> bool {
+    // TODO: Extend this to support aliases and equal constants.
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+  };
+  return matchOperands(firstPloop.getLowerBound(),
+                       secondPloop.getLowerBound()) &&
+         matchOperands(firstPloop.getUpperBound(),
+                       secondPloop.getUpperBound()) &&
+         matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
+/// Checks if the parallel loops have mixed access to the same buffers. Returns
+/// `true` if the first parallel loop writes to the same indices that the second
+/// loop reads.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    ParallelOp firstPloop, ParallelOp secondPloop,
+    const IRMapping &firstToSecondPloopIndices,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
+  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+  SmallVector<Value> bufferStoresVec;
+  firstPloop.getBody()->walk([&](memref::StoreOp store) {
+    bufferStores[store.getMemRef()].push_back(store.getIndices());
+    bufferStoresVec.emplace_back(store.getMemRef());
+  });
+  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+    Value loadMem = load.getMemRef();
+    // Stop if the memref is defined in secondPloop body. Careful alias analysis
+    // is needed.
+    auto *memrefDef = loadMem.getDefiningOp();
+    if (memrefDef && memrefDef->getBlock() == load->getBlock())
+      return WalkResult::interrupt();
+
+    for (Value store : bufferStoresVec)
+      if (store != loadMem && mayAlias(store, loadMem))
+        return WalkResult::interrupt();
+
+    auto write = bufferStores.find(loadMem);
+    if (write == bufferStores.end())
+      return WalkResult::advance();
+
+    // Check that at last one store was retrieved
+    if (!write->second.size())
+      return WalkResult::interrupt();
+
+    auto storeIndices = write->second.front();
+
+    // Multiple writes to the same memref are allowed only on the same indices
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+
+    // Check that the load indices of secondPloop coincide with store indices of
+    // firstPloop for the same memrefs.
+    auto loadIndices = load.getIndices();
+    if (storeIndices.size() != loadIndices.size())
+      return WalkResult::interrupt();
+    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
+          loadIndices[i]) {
+        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+        if (storeIndexDefOp && loadIndexDefOp) {
+          if (!isMemoryEffectFree(storeIndexDefOp))
+            return WalkResult::interrupt();
+          if (!isMemoryEffectFree(loadIndexDefOp))
+            return WalkResult::interrupt();
+          if (!OperationEquivalence::isEquivalentTo(
+                  storeIndexDefOp, loadIndexDefOp,
+                  [&](Value storeIndex, Value loadIndex) {
+                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+                      return failure();
+                    else
+                      return success();
+                  },
+                  /*markEquivalent=*/nullptr,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
+            return WalkResult::interrupt();
+          }
+        } else
+          return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way by checking simple read and
+/// write patterns.
+static LogicalResult
+verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
+                   const IRMapping &firstToSecondPloopIndices,
+                   llvm::function_ref<bool(Value, Value)> mayAlias) {
+  if (!haveNoReadsAfterWriteExceptSameIndex(
+          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
+    return failure();
+
+  IRMapping secondToFirstPloopIndices;
+  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+                                firstPloop.getBody()->getArguments());
+  return success(haveNoReadsAfterWriteExceptSameIndex(
+      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+}
+
+static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+                          const IRMapping &firstToSecondPloopIndices,
+                          llvm::function_ref<bool(Value, Value)> mayAlias) {
+  return !hasNestedParallelOp(firstPloop) &&
+         !hasNestedParallelOp(secondPloop) &&
+         equalIterationSpaces(firstPloop, secondPloop) &&
+         succeeded(verifyDependencies(firstPloop, secondPloop,
+                                      firstToSecondPloopIndices, mayAlias));
+}
+
+/// Prepends operations of firstPloop's body into secondPloop's body.
+/// Updates secondPloop with new loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+                        OpBuilder builder,
+                        llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
+  IRMapping firstToSecondPloopIndices;
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
+
+  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+                     mayAlias))
+    return;
+
+  DominanceInfo dom;
+  // We are fusing first loop into second, make sure there are no users of the
+  // first loop results between loops.
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  IRRewriter rewriter(builder);
+  secondPloop = mlir::fuseIndependentSiblingParallelLoops(
+      firstPloop, secondPloop, rewriter);
+  ;
+}
+
 void mlir::scf::naivelyFuseParallelOps(
     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
   OpBuilder b(region);
@@ -59,7 +216,7 @@ void mlir::scf::naivelyFuseParallelOps(
     }
     for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
-        mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
+        fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
     }
   }
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e3660e89fb684..5f58767be409d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1188,13 +1188,6 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
 // Fusion related helpers
 //===----------------------------------------------------------------------===//
 
-/// Verify there are no nested ParallelOps.
-static bool hasNestedParallelOp(scf::ParallelOp ploop) {
-  auto walkResult = ploop.getBody()->walk(
-      [](scf::ParallelOp) { return WalkResult::interrupt(); });
-  return walkResult.wasInterrupted();
-}
-
 bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
                                          LoopLikeOpInterface &source) {
   auto iterSpaceEq =
@@ -1209,86 +1202,6 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
   return iterSpaceEq;
 }
 
-static bool isFusionLegal(scf::ParallelOp firstPloop,
-                          scf::ParallelOp secondPloop,
-                          const IRMapping &firstToSecondPloopIndices,
-                          llvm::function_ref<bool(Value, Value)> mayAlias) {
-  return !hasNestedParallelOp(firstPloop) &&
-         !hasNestedParallelOp(secondPloop) &&
-         equalIterationSpaces(firstPloop, secondPloop) &&
-         succeeded(verifyDependencies(firstPloop, secondPloop,
-                                      firstToSecondPloopIndices, mayAlias));
-}
-
-void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
-                       OpBuilder builder,
-                       llvm::function_ref<bool(Value, Value)> mayAlias) {
-  Block *block1 = firstPloop.getBody();
-  Block *block2 = secondPloop.getBody();
-  IRMapping firstToSecondPloopIndices;
-  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
-
-  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
-                     mayAlias))
-    return;
-
-  DominanceInfo dom;
-  // We are fusing first loop into second, make sure there are no users of the
-  // first loop results between loops.
-  for (Operation *user : firstPloop->getUsers())
-    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
-      return;
-
-  ValueRange inits1 = firstPloop.getInitVals();
-  ValueRange inits2 = secondPloop.getInitVals();
-
-  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
-  newInitVars.append(inits2.begin(), inits2.end());
-
-  IRRewriter b(builder);
-  b.setInsertionPoint(secondPloop);
-  auto newSecondPloop = b.create<scf::ParallelOp>(
-      secondPloop.getLoc(), secondPloop.getLowerBound(),
-      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
-  Block *newBlock = newSecondPloop.getBody();
-  auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
-  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
-
-  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
-                      newBlock->getArguments());
-
-  ValueRange results = newSecondPloop.getResults();
-  if (!results.empty()) {
-    b.setInsertionPointToEnd(newBlock);
-
-    ValueRange reduceArgs1 = term1.getOperands();
-    ValueRange reduceArgs2 = term2.getOperands();
-    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
-    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
-    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
-    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
-             term1.getReductions(), term2.getReductions()))) {
-      Block &oldRedBlock = reg.front();
-      Block &newRedBlock = newReduceOp.getReductions()[i].front();
-      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
-                          newRedBlock.getArguments());
-    }
-
-    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
-    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
-  }
-  term1->erase();
-  term2->erase();
-  firstPloop.erase();
-  secondPloop.erase();
-  secondPloop = newSecondPloop;
-}
-
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
@@ -1393,7 +1306,54 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
 
 scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
     scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
-  auto mayAlias = [&](Value val1, Value val2) -> bool { return false; };
-  mlir::fuseIfLegal(target, source, rewriter, mayAlias);
-  return source;
+  Block *block1 = target.getBody();
+  Block *block2 = source.getBody();
+  auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+  ValueRange inits1 = target.getInitVals();
+  ValueRange inits2 = source.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  rewriter.setInsertionPoint(source);
+  auto fusedLoop = rewriter.create<scf::ParallelOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), newInitVars);
+  Block *newBlock = fusedLoop.getBody();
+  rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                             newBlock->getArguments());
+  rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+                             newBlock->getArguments());
+
+  ValueRange results = fusedLoop.getResults();
+  if (!results.empty()) {
+    rewriter.setInsertionPointToEnd(newBlock);
+
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp =
+        rewriter.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock,
+                                 newRedBlock.begin(),
+                                 newRedBlock.getArguments());
+    }
+    target.replaceAllUsesWith(results.take_front(inits1.size()));
+    source.replaceAllUsesWith(results.take_back(inits2.size()));
+  }
+  term1->erase();
+  term2->erase();
+  target.erase();
+  source.erase();
+
+  return fusedLoop;
 }

>From 694d589dc535892f3dda9d27c2a43052fc0b445e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 8 Jun 2024 23:35:30 -0500
Subject: [PATCH 14/34] remove unused functions

---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp | 113 ---------------------------
 1 file changed, 113 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5f58767be409d..e6cb88c427da8 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1071,119 +1071,6 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
 }
 
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
-    scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
-    const IRMapping &firstToSecondPloopIndices,
-    llvm::function_ref<bool(Value, Value)> mayAlias) {
-  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
-  SmallVector<Value> bufferStoresVec;
-  firstPloop.getBody()->walk([&](memref::StoreOp store) {
-    bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
-  });
-  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
-    Value loadMem = load.getMemRef();
-    // Stop if the memref is defined in secondPloop body. Careful alias analysis
-    // is needed.
-    auto *memrefDef = loadMem.getDefiningOp();
-    if (memrefDef && memrefDef->getBlock() == load->getBlock())
-      return WalkResult::interrupt();
-
-    for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
-        return WalkResult::interrupt();
-
-    auto write = bufferStores.find(loadMem);
-    if (write == bufferStores.end())
-      return WalkResult::advance();
-
-    // Check that at last one store was retrieved
-    if (!write->second.size())
-      return WalkResult::interrupt();
-
-    auto storeIndices = write->second.front();
-
-    // Multiple writes to the same memref are allowed only on the same indices
-    for (const auto &othStoreIndices : write->second) {
-      if (othStoreIndices != storeIndices)
-        return WalkResult::interrupt();
-    }
-
-    // Check that the load indices of secondPloop coincide with store indices of
-    // firstPloop for the same memrefs.
-    auto loadIndices = load.getIndices();
-    if (storeIndices.size() != loadIndices.size())
-      return WalkResult::interrupt();
-    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
-      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
-          loadIndices[i]) {
-        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
-        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
-        if (storeIndexDefOp && loadIndexDefOp) {
-          if (!isMemoryEffectFree(storeIndexDefOp))
-            return WalkResult::interrupt();
-          if (!isMemoryEffectFree(loadIndexDefOp))
-            return WalkResult::interrupt();
-          if (!OperationEquivalence::isEquivalentTo(
-                  storeIndexDefOp, loadIndexDefOp,
-                  [&](Value storeIndex, Value loadIndex) {
-                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
-                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
-                      return failure();
-                    else
-                      return success();
-                  },
-                  /*markEquivalent=*/nullptr,
-                  OperationEquivalence::Flags::IgnoreLocations)) {
-            return WalkResult::interrupt();
-          }
-        } else
-          return WalkResult::interrupt();
-      }
-    }
-    return WalkResult::advance();
-  });
-  return !walkResult.wasInterrupted();
-}
-
-/// Analyzes dependencies in the most primitive way by checking simple read and
-/// write patterns.
-static LogicalResult
-verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
-                   const IRMapping &firstToSecondPloopIndices,
-                   llvm::function_ref<bool(Value, Value)> mayAlias) {
-  if (!haveNoReadsAfterWriteExceptSameIndex(
-          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
-    return failure();
-
-  IRMapping secondToFirstPloopIndices;
-  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
-                                firstPloop.getBody()->getArguments());
-  return success(haveNoReadsAfterWriteExceptSameIndex(
-      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
-}
-
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(scf::ParallelOp firstPloop,
-                                 scf::ParallelOp secondPloop) {
-  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
-    return false;
-
-  auto matchOperands = [&](const OperandRange &lhs,
-                           const OperandRange &rhs) -> bool {
-    // TODO: Extend this to support aliases and equal constants.
-    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
-  };
-  return matchOperands(firstPloop.getLowerBound(),
-                       secondPloop.getLowerBound()) &&
-         matchOperands(firstPloop.getUpperBound(),
-                       secondPloop.getUpperBound()) &&
-         matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
 //===----------------------------------------------------------------------===//
 // Fusion related helpers
 //===----------------------------------------------------------------------===//

>From 67cb64f1a773795bfb2d4e9f0c981dd502572676 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 9 Jun 2024 14:45:57 -0500
Subject: [PATCH 15/34] refactor fuseIndependentSiblingForLoops to reuse
 replaceWithAdditionalYields

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td |   4 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp            |  54 ++++++++
 mlir/lib/Dialect/SCF/Utils/Utils.cpp       | 151 +++++++++++++++------
 3 files changed, 168 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index f35ea962bea16..e7b9665f797fa 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
-           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
+          ["getInitsMutable", "getRegionIterArgs", "getLoopResults", "getLoopInductionVars", 
+           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", "replaceWithAdditionalYields",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5e94f4dc612a7..6ad181e2f3d77 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -616,8 +616,62 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+std::optional<ResultRange> ForallOp::getLoopResults() { return getResults(); }
+
 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 
+FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
+    RewriterBase &rewriter, ValueRange newInitOperands,
+    bool replaceInitOperandUsesInLoop,
+    const NewYieldValuesFn &newYieldValuesFn) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  auto inits = llvm::to_vector(getOutputs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
+      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
+      inits, getMapping());
+
+  // Generate the new yield values and append them to the scf.yield operation.
+  auto yieldOp = cast<scf::InParallelOp>(getTerminator());
+  ArrayRef<BlockArgument> newIterArgs =
+      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
+    assert(newInitOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    // rewriter.modifyOpInPlace(yieldOp, [&]() {
+    //   yieldOp.getResultsMutable().append(newYieldedValues);
+    // });
+  }
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           getBody()->getNumArguments()));
+
+  if (replaceInitOperandUsesInLoop) {
+    // Replace all uses of `newInitOperands` with the corresponding basic block
+    // arguments.
+    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
+      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
+                                 [&](OpOperand &use) {
+                                   Operation *user = use.getOwner();
+                                   return newLoop->isProperAncestor(user);
+                                 });
+    }
+  }
+
+  // Replace the old loop.
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 /// Promotes the loop body of a forallOp to its containing block if it can be
 /// determined that the loop has a single iteration.
 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e6cb88c427da8..a61428208c405 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1089,6 +1089,92 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
   return iterSpaceEq;
 }
 
+template <typename LoopTy>
+void fuseTerminator(RewriterBase &rewriter, LoopTy target, LoopTy source,
+                    LoopTy &fused, IRMapping &mapping) {}
+
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForallOp target,
+                    scf::ForallOp source, scf::ForallOp &fused,
+                    IRMapping &mapping) {
+  // Fuse the old terminator in_parallel ops into the new one.
+  scf::InParallelOp targetTerm = target.getTerminator();
+  scf::InParallelOp sourceTerm = source.getTerminator();
+  scf::InParallelOp fusedTerm = fused.getTerminator();
+  rewriter.setInsertionPointToStart(fusedTerm.getBody());
+  for (Operation &op : targetTerm.getYieldingOps())
+    rewriter.clone(op, mapping);
+  for (Operation &op : sourceTerm.getYieldingOps())
+    rewriter.clone(op, mapping);
+}
+
+template <>
+void fuseTerminator(RewriterBase &rewriter, scf::ForOp target,
+                    scf::ForOp source, scf::ForOp &fused, IRMapping &mapping) {
+  // Build fused yield results by appropriately mapping original yield operands.
+  SmallVector<Value> yieldResults;
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  if (!yieldResults.empty())
+    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+}
+
+template <typename LoopTy>
+LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
+                   RewriterBase &rewriter) {
+  auto targetResults = target.getLoopResults();
+  auto sourceResults = source.getLoopResults();
+  int64_t numTargetOuts = (*targetResults).size();
+  int64_t numSourceOuts = (*sourceResults).size();
+  printf("numTargetOuts %ld\n", numTargetOuts);
+
+  // Create fused shared_outs.
+  SmallVector<Value> fusedOuts;
+  llvm::append_range(fusedOuts, *targetResults);
+  llvm::append_range(fusedOuts, *sourceResults);
+
+  // Create a new scf.forall op after the source loop.
+  rewriter.setInsertionPointAfter(source);
+  // LoopTy fusedLoop = builder.create<LoopTy>(
+  //     source.getLoc(), source.getLoopLowerBounds(),
+  //     source.getLoopUpperBounds(), source.getLoopSteps(), fusedOuts,
+  //     source->getAttrs());
+  LoopTy fusedLoop = rewriter.cloneWithoutRegions(cast<LoopTy>(source));
+
+  // Map control operands.
+  IRMapping mapping;
+  mapping.map(*target.getLoopInductionVars(),
+              *fusedLoop.getLoopInductionVars());
+  mapping.map(*source.getLoopInductionVars(),
+              *fusedLoop.getLoopInductionVars());
+
+  // Map shared outs.
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+  // Append everything except the terminator into the fused operation.
+  rewriter.setInsertionPointToStart(
+      &fusedLoop.getLoopRegions().front()->front());
+  for (Operation &op :
+       target.getLoopRegions().front()->front().without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op :
+       source.getLoopRegions().front()->front().without_terminator())
+    rewriter.clone(op, mapping);
+
+  fuseTerminator<LoopTy>(rewriter, cast<LoopTy>(target), cast<LoopTy>(source),
+                         cast<LoopTy>(fusedLoop), mapping);
+
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+  return fusedLoop;
+}
+
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
@@ -1144,50 +1230,37 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                 scf::ForOp source,
                                                 RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused init_args, with target's init_args before source's init_args.
-  SmallVector<Value> fusedInitArgs;
-  llvm::append_range(fusedInitArgs, target.getInitArgs());
-  llvm::append_range(fusedInitArgs, source.getInitArgs());
-
-  // Create a new scf.for op after the source loop (with scf.yield terminator
-  // (without arguments) only in case its init_args is empty).
-  rewriter.setInsertionPointAfter(source);
-  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
-      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
-      source.getStep(), fusedInitArgs);
-
+  auto targetIterArgs = target.getRegionIterArgs();
+  auto targetInductionVar = target.getInductionVar();
+  SmallVector<Value> targetYieldOperands(source.getYieldedValues());
+  auto sourceIterArgs = source.getRegionIterArgs();
+  auto sourceInductionVar = source.getInductionVar();
+  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
+  scf::ForOp fusedLoop = cast<scf::ForOp>(*target.replaceWithAdditionalYields(
+      rewriter, source.getInitArgs(), /*replaceInitOperandUsesInLoop=*/false,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        return sourceYieldOperands;
+      }));
   // Map original induction variables and operands to those of the fused loop.
   IRMapping mapping;
-  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
+  mapping.map(targetInductionVar, fusedLoop.getInductionVar());
+  mapping.map(targetIterArgs,
+              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+  mapping.map(targetYieldOperands,
+              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
+  mapping.map(sourceInductionVar, fusedLoop.getInductionVar());
+  mapping.map(sourceIterArgs,
+              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+  mapping.map(sourceYieldOperands,
+              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
   // Merge target's body into the new (fused) for loop and then source's body.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+  rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator());
   for (Operation &op : source.getBody()->without_terminator())
     rewriter.clone(op, mapping);
-
-  // Build fused yield results by appropriately mapping original yield operands.
-  SmallVector<Value> yieldResults;
-  for (Value operand : target.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  for (Value operand : source.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  if (!yieldResults.empty())
-    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
-
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
-
+  auto newTerm = rewriter.clone(*fusedLoop.getBody()->getTerminator(), mapping);
+  rewriter.replaceOp(fusedLoop.getBody()->getTerminator(), newTerm);
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
   return fusedLoop;
 }
 

>From cc8599f69a90f9b460bbff950505004c214ac72e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 9 Jun 2024 17:20:46 -0500
Subject: [PATCH 16/34] refactor fuseIndependentSiblingForallLoops to reuse
 replaceWithAdditionalYields

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 12 +---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 69 +++++++++----------
 .../SCF/transform-loop-fuse-sibling.mlir      |  3 +-
 3 files changed, 36 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6ad181e2f3d77..6850d632f10d0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -637,17 +637,7 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
   auto yieldOp = cast<scf::InParallelOp>(getTerminator());
   ArrayRef<BlockArgument> newIterArgs =
       newLoop.getBody()->getArguments().take_back(newInitOperands.size());
-  {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> newYieldedValues =
-        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
-    assert(newInitOperands.size() == newYieldedValues.size() &&
-           "expected as many new yield values as new iter operands");
-    // rewriter.modifyOpInPlace(yieldOp, [&]() {
-    //   yieldOp.getResultsMutable().append(newYieldedValues);
-    // });
-  }
+  newLoop.getTerminator().erase();
 
   // Move the loop body to the new op.
   rewriter.mergeBlocks(getBody(), newLoop.getBody(),
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a61428208c405..a822b4199fe9d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1128,7 +1128,6 @@ LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
   auto sourceResults = source.getLoopResults();
   int64_t numTargetOuts = (*targetResults).size();
   int64_t numSourceOuts = (*sourceResults).size();
-  printf("numTargetOuts %ld\n", numTargetOuts);
 
   // Create fused shared_outs.
   SmallVector<Value> fusedOuts;
@@ -1178,51 +1177,49 @@ LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
-  unsigned numTargetOuts = target.getNumResults();
-  unsigned numSourceOuts = source.getNumResults();
-
-  // Create fused shared_outs.
-  SmallVector<Value> fusedOuts;
-  llvm::append_range(fusedOuts, target.getOutputs());
-  llvm::append_range(fusedOuts, source.getOutputs());
-
-  // Create a new scf.forall op after the source loop.
-  rewriter.setInsertionPointAfter(source);
-  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
-      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
-      source.getMixedStep(), fusedOuts, source.getMapping());
-
+  auto targetIterArgs = target.getRegionIterArgs();
+  auto targetInductionVar = target.getInductionVars();
+  SmallVector<Value> targetYieldOperands(target.getYieldedValues());
+  auto sourceIterArgs = source.getRegionIterArgs();
+  auto sourceInductionVar = source.getInductionVars();
+  scf::InParallelOp sourceTerm = source.getTerminator();
+  auto sourceYieldOps = sourceTerm.getYieldingOps();
+  auto sourceBody = source.getBody();
+  SmallVector<Value> sourceYieldOperands(llvm::map_range(
+      sourceTerm.getDests(), [](auto arg) { return cast<Value>(arg); }));
+  scf::ForallOp fusedLoop =
+      cast<scf::ForallOp>(*target.replaceWithAdditionalYields(
+          rewriter, source.getOutputs(), /*replaceInitOperandUsesInLoop=*/false,
+          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+            for (Operation &op : sourceYieldOps)
+              b.clone(op);
+            return sourceYieldOperands;
+          }));
   // Map control operands.
   IRMapping mapping;
-  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
-
-  // Map shared outs.
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
+  mapping.map(targetInductionVar, fusedLoop.getInductionVars());
+  mapping.map(targetIterArgs,
+              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+  mapping.map(targetYieldOperands,
+              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
+  mapping.map(sourceInductionVar, fusedLoop.getInductionVars());
+  mapping.map(sourceIterArgs,
+              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+  mapping.map(sourceYieldOperands,
+              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
   // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op : source.getBody()->without_terminator())
+  rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator());
+  for (Operation &op : sourceBody->without_terminator())
     rewriter.clone(op, mapping);
 
   // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp targetTerm = target.getTerminator();
-  scf::InParallelOp sourceTerm = source.getTerminator();
   scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-  rewriter.setInsertionPointToStart(fusedTerm.getBody());
-  for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
+  rewriter.setInsertionPointToEnd(fusedTerm.getBody());
   for (Operation &op : sourceTerm.getYieldingOps())
     rewriter.clone(op, mapping);
 
-  // Replace old loops by substituting their uses by results of the fused loop.
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+  rewriter.replaceOp(source,
+                     fusedLoop.getResults().take_back(source.getNumResults()));
 
   return fusedLoop;
 }
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 46c6be36c3271..47bfe0baa7651 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -189,7 +189,8 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+// CHECK-LABEL: func.func @matmul_fuse_2nd_forall_into_1st
+// CHECK-SAME:  [[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
 func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>

>From 48b1af9cb4392b8ccad748e17ce40fa997db6a59 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 9 Jun 2024 19:12:40 -0500
Subject: [PATCH 17/34] wip

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td |   4 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp            |   4 +
 mlir/lib/Dialect/SCF/Utils/Utils.cpp       | 170 ++++++---------------
 3 files changed, 50 insertions(+), 128 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index e7b9665f797fa..b9345f6ecdbb2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -302,8 +302,8 @@ def ForallOp : SCF_Op<"forall", [
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
           ["getInitsMutable", "getRegionIterArgs", "getLoopResults", "getLoopInductionVars", 
-           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", "replaceWithAdditionalYields",
-           "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
+           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", "getYieldedValuesMutable",
+           "replaceWithAdditionalYields", "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6850d632f10d0..b4a16e519a15a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1472,6 +1472,10 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
+std::optional<MutableArrayRef<OpOperand>> ForallOp::getYieldedValuesMutable() {
+  return getOutputsMutable();
+}
+
 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
   return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a822b4199fe9d..fb2d1d11fb6ae 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1090,134 +1090,76 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
 }
 
 template <typename LoopTy>
-void fuseTerminator(RewriterBase &rewriter, LoopTy target, LoopTy source,
-                    LoopTy &fused, IRMapping &mapping) {}
+void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
+                    IRMapping &mapping) {}
 
 template <>
-void fuseTerminator(RewriterBase &rewriter, scf::ForallOp target,
-                    scf::ForallOp source, scf::ForallOp &fused,
-                    IRMapping &mapping) {
+void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
+                    scf::ForallOp &fused, IRMapping &mapping) {
   // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp targetTerm = target.getTerminator();
-  scf::InParallelOp sourceTerm = source.getTerminator();
   scf::InParallelOp fusedTerm = fused.getTerminator();
-  rewriter.setInsertionPointToStart(fusedTerm.getBody());
-  for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
-  for (Operation &op : sourceTerm.getYieldingOps())
+  rewriter.setInsertionPointToEnd(fusedTerm.getBody());
+  for (Operation &op : source.getTerminator().getYieldingOps())
     rewriter.clone(op, mapping);
 }
 
 template <>
-void fuseTerminator(RewriterBase &rewriter, scf::ForOp target,
-                    scf::ForOp source, scf::ForOp &fused, IRMapping &mapping) {
+void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
+                    scf::ForOp &fused, IRMapping &mapping) {
   // Build fused yield results by appropriately mapping original yield operands.
-  SmallVector<Value> yieldResults;
-  for (Value operand : target.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  for (Value operand : source.getBody()->getTerminator()->getOperands())
-    yieldResults.push_back(mapping.lookupOrDefault(operand));
-  if (!yieldResults.empty())
-    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+  auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping);
+  rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
 }
 
 template <typename LoopTy>
-LoopTy createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
-                   RewriterBase &rewriter) {
-  auto targetResults = target.getLoopResults();
-  auto sourceResults = source.getLoopResults();
-  int64_t numTargetOuts = (*targetResults).size();
-  int64_t numSourceOuts = (*sourceResults).size();
-
-  // Create fused shared_outs.
-  SmallVector<Value> fusedOuts;
-  llvm::append_range(fusedOuts, *targetResults);
-  llvm::append_range(fusedOuts, *sourceResults);
-
-  // Create a new scf.forall op after the source loop.
-  rewriter.setInsertionPointAfter(source);
-  // LoopTy fusedLoop = builder.create<LoopTy>(
-  //     source.getLoc(), source.getLoopLowerBounds(),
-  //     source.getLoopUpperBounds(), source.getLoopSteps(), fusedOuts,
-  //     source->getAttrs());
-  LoopTy fusedLoop = rewriter.cloneWithoutRegions(cast<LoopTy>(source));
-
-  // Map control operands.
-  IRMapping mapping;
-  mapping.map(*target.getLoopInductionVars(),
-              *fusedLoop.getLoopInductionVars());
-  mapping.map(*source.getLoopInductionVars(),
-              *fusedLoop.getLoopInductionVars());
-
-  // Map shared outs.
-  mapping.map(target.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
-  mapping.map(source.getRegionIterArgs(),
-              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
-
-  // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPointToStart(
-      &fusedLoop.getLoopRegions().front()->front());
-  for (Operation &op :
-       target.getLoopRegions().front()->front().without_terminator())
-    rewriter.clone(op, mapping);
-  for (Operation &op :
-       source.getLoopRegions().front()->front().without_terminator())
-    rewriter.clone(op, mapping);
-
-  fuseTerminator<LoopTy>(rewriter, cast<LoopTy>(target), cast<LoopTy>(source),
-                         cast<LoopTy>(fusedLoop), mapping);
-
-  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
-  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
-
-  return fusedLoop;
-}
-
-scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
-                                                      scf::ForallOp source,
-                                                      RewriterBase &rewriter) {
+LoopLikeOpInterface
+createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
+            RewriterBase &rewriter, NewYieldValuesFn newYieldValuesFn) {
   auto targetIterArgs = target.getRegionIterArgs();
-  auto targetInductionVar = target.getInductionVars();
+  auto targetInductionVar = *target.getLoopInductionVars();
   SmallVector<Value> targetYieldOperands(target.getYieldedValues());
   auto sourceIterArgs = source.getRegionIterArgs();
-  auto sourceInductionVar = source.getInductionVars();
-  scf::InParallelOp sourceTerm = source.getTerminator();
-  auto sourceYieldOps = sourceTerm.getYieldingOps();
-  auto sourceBody = source.getBody();
-  SmallVector<Value> sourceYieldOperands(llvm::map_range(
-      sourceTerm.getDests(), [](auto arg) { return cast<Value>(arg); }));
-  scf::ForallOp fusedLoop =
-      cast<scf::ForallOp>(*target.replaceWithAdditionalYields(
-          rewriter, source.getOutputs(), /*replaceInitOperandUsesInLoop=*/false,
-          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
-            for (Operation &op : sourceYieldOps)
-              b.clone(op);
-            return sourceYieldOperands;
-          }));
+  auto sourceInductionVar = *source.getLoopInductionVars();
+  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
+  auto sourceRegion = source.getLoopRegions().front();
+  LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
+      rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
+      newYieldValuesFn);
+
   // Map control operands.
   IRMapping mapping;
-  mapping.map(targetInductionVar, fusedLoop.getInductionVars());
+  mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
   mapping.map(targetIterArgs,
               fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
   mapping.map(targetYieldOperands,
               fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
-  mapping.map(sourceInductionVar, fusedLoop.getInductionVars());
+  mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
   mapping.map(sourceIterArgs,
               fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
   mapping.map(sourceYieldOperands,
               fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
   // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator());
-  for (Operation &op : sourceBody->without_terminator())
+  rewriter.setInsertionPoint(
+      fusedLoop.getLoopRegions().front()->front().getTerminator());
+  for (Operation &op : sourceRegion->front().without_terminator())
     rewriter.clone(op, mapping);
 
-  // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-  rewriter.setInsertionPointToEnd(fusedTerm.getBody());
-  for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, mapping);
+  fuseTerminator<LoopTy>(rewriter, cast<LoopTy>(source),
+                         cast<LoopTy>(fusedLoop), mapping);
+
+  return fusedLoop;
+}
 
+scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
+                                                      scf::ForallOp source,
+                                                      RewriterBase &rewriter) {
+  scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused<scf::ForallOp>(
+      target, source, rewriter,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        for (Operation &op : source.getTerminator().getYieldingOps())
+          b.clone(op);
+        return source.getYieldedValues();
+      }));
   rewriter.replaceOp(source,
                      fusedLoop.getResults().take_back(source.getNumResults()));
 
@@ -1227,35 +1169,11 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                 scf::ForOp source,
                                                 RewriterBase &rewriter) {
-  auto targetIterArgs = target.getRegionIterArgs();
-  auto targetInductionVar = target.getInductionVar();
-  SmallVector<Value> targetYieldOperands(source.getYieldedValues());
-  auto sourceIterArgs = source.getRegionIterArgs();
-  auto sourceInductionVar = source.getInductionVar();
-  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
-  scf::ForOp fusedLoop = cast<scf::ForOp>(*target.replaceWithAdditionalYields(
-      rewriter, source.getInitArgs(), /*replaceInitOperandUsesInLoop=*/false,
+  scf::ForOp fusedLoop = cast<scf::ForOp>(createFused<scf::ForOp>(
+      target, source, rewriter,
       [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
-        return sourceYieldOperands;
+        return source.getYieldedValues();
       }));
-  // Map original induction variables and operands to those of the fused loop.
-  IRMapping mapping;
-  mapping.map(targetInductionVar, fusedLoop.getInductionVar());
-  mapping.map(targetIterArgs,
-              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
-  mapping.map(targetYieldOperands,
-              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
-  mapping.map(sourceInductionVar, fusedLoop.getInductionVar());
-  mapping.map(sourceIterArgs,
-              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
-  mapping.map(sourceYieldOperands,
-              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
-  // Merge target's body into the new (fused) for loop and then source's body.
-  rewriter.setInsertionPoint(fusedLoop.getBody()->getTerminator());
-  for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
-  auto newTerm = rewriter.clone(*fusedLoop.getBody()->getTerminator(), mapping);
-  rewriter.replaceOp(fusedLoop.getBody()->getTerminator(), newTerm);
   rewriter.replaceOp(source,
                      fusedLoop.getResults().take_back(source.getNumResults()));
   return fusedLoop;

>From 7a51cb34afd5d8a2b67cceaef457f50c032affbd Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 17 Jun 2024 14:36:44 -0500
Subject: [PATCH 18/34] Decouple concrete loop type from `createFused` function

---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp | 35 +++++++++++++++++++++-------
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index fb2d1d11fb6ae..910c41b3e3d54 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1111,10 +1111,29 @@ void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
   rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
 }
 
-template <typename LoopTy>
-LoopLikeOpInterface
-createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
-            RewriterBase &rewriter, NewYieldValuesFn newYieldValuesFn) {
+// TODO: We should maybe add this as a method to LoopLikeOpInterface.
+//       For now, this acts as a placeholder.
+template <>
+void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
+                    LoopLikeOpInterface &fused, IRMapping &mapping) {
+  if (isa<scf::ForOp>(source) && isa<scf::ForOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForOp>(source), cast<scf::ForOp>(fused),
+                   mapping);
+  } else if (isa<scf::ForallOp>(source) && isa<scf::ForallOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ForallOp>(source),
+                   cast<scf::ForallOp>(fused), mapping);
+  } else if (isa<scf::ParallelOp>(source) && isa<scf::ParallelOp>(fused)) {
+    fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
+                   cast<scf::ParallelOp>(fused), mapping);
+  } else {
+    return;
+  }
+}
+
+LoopLikeOpInterface createFused(LoopLikeOpInterface target,
+                                LoopLikeOpInterface source,
+                                RewriterBase &rewriter,
+                                NewYieldValuesFn newYieldValuesFn) {
   auto targetIterArgs = target.getRegionIterArgs();
   auto targetInductionVar = *target.getLoopInductionVars();
   SmallVector<Value> targetYieldOperands(target.getYieldedValues());
@@ -1144,8 +1163,8 @@ createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
   for (Operation &op : sourceRegion->front().without_terminator())
     rewriter.clone(op, mapping);
 
-  fuseTerminator<LoopTy>(rewriter, cast<LoopTy>(source),
-                         cast<LoopTy>(fusedLoop), mapping);
+  // TODO: Replace with interface method if added
+  fuseTerminator(rewriter, source, fusedLoop, mapping);
 
   return fusedLoop;
 }
@@ -1153,7 +1172,7 @@ createFused(LoopLikeOpInterface target, LoopLikeOpInterface source,
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
-  scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused<scf::ForallOp>(
+  scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
       target, source, rewriter,
       [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
         for (Operation &op : source.getTerminator().getYieldingOps())
@@ -1169,7 +1188,7 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                 scf::ForOp source,
                                                 RewriterBase &rewriter) {
-  scf::ForOp fusedLoop = cast<scf::ForOp>(createFused<scf::ForOp>(
+  scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
       target, source, rewriter,
       [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
         return source.getYieldedValues();

>From 30873263faaab18267109231094af408b819059a Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 17 Jun 2024 15:10:44 -0500
Subject: [PATCH 19/34] Refactor ForallOp::replaceWithAdditionalYields

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp      | 9 +++------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp | 5 ++---
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b4a16e519a15a..c5a9e18e2610c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -633,12 +633,7 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
       getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
       inits, getMapping());
 
-  // Generate the new yield values and append them to the scf.yield operation.
-  auto yieldOp = cast<scf::InParallelOp>(getTerminator());
-  ArrayRef<BlockArgument> newIterArgs =
-      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
   newLoop.getTerminator().erase();
-
   // Move the loop body to the new op.
   rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                        newLoop.getBody()->getArguments().take_front(
@@ -647,7 +642,9 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
   if (replaceInitOperandUsesInLoop) {
     // Replace all uses of `newInitOperands` with the corresponding basic block
     // arguments.
-    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
+    for (auto it :
+         llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
+                                        newInitOperands.size()))) {
       rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                  [&](OpOperand &use) {
                                    Operation *user = use.getOwner();
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 910c41b3e3d54..5ef6718bc5346 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1175,9 +1175,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
   scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
       target, source, rewriter,
       [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
-        for (Operation &op : source.getTerminator().getYieldingOps())
-          b.clone(op);
-        return source.getYieldedValues();
+        // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
+        return ValueRange{};
       }));
   rewriter.replaceOp(source,
                      fusedLoop.getResults().take_back(source.getNumResults()));

>From bcf3d4aaed9e425f3a3b2d97660c6e816e333abe Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 17 Jun 2024 15:50:55 -0500
Subject: [PATCH 20/34] revert unnecessary changes

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 7 ++++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 6 ------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp       | 1 -
 3 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b9345f6ecdbb2..bf95fbe6721cf 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -301,9 +301,10 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getLoopResults", "getLoopInductionVars", 
-           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", "getYieldedValuesMutable",
-           "replaceWithAdditionalYields", "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
+          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
+           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
+           "replaceWithAdditionalYields", "promoteIfSingleIteration",
+           "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c5a9e18e2610c..deface43028b1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -616,8 +616,6 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
   regions.push_back(RegionSuccessor(getResults()));
 }
 
-std::optional<ResultRange> ForallOp::getLoopResults() { return getResults(); }
-
 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 
 FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
@@ -1469,10 +1467,6 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   return storeOps;
 }
 
-std::optional<MutableArrayRef<OpOperand>> ForallOp::getYieldedValuesMutable() {
-  return getOutputsMutable();
-}
-
 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
   return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5ef6718bc5346..2e61f9998a7d8 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"

>From 0cb3c4ea08b22eea318fa47634914f921f08f7f2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 18 Jun 2024 10:35:53 -0500
Subject: [PATCH 21/34] cleanup

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp                        | 4 ++--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp                   | 6 +++---
 mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir | 3 +--
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index deface43028b1..2baef9ca45db1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -625,13 +625,13 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
   // Create a new loop before the existing one, with the extra operands.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(getOperation());
-  auto inits = llvm::to_vector(getOutputs());
+  SmallVector<Value> inits(getOutputs());
   inits.append(newInitOperands.begin(), newInitOperands.end());
   scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
       getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
       inits, getMapping());
 
-  newLoop.getTerminator().erase();
+  rewriter.eraseOp(newLoop.getTerminator());
   // Move the loop body to the new op.
   rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                        newLoop.getBody()->getArguments().take_front(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 2e61f9998a7d8..09da6e6233ffc 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1110,8 +1110,8 @@ void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
   rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
 }
 
-// TODO: We should maybe add this as a method to LoopLikeOpInterface.
-//       For now, this acts as a placeholder.
+// TODO: We should maybe add a method to LoopLikeOpInterface that will
+// facilitate this transformation. For now, this acts as a placeholder.
 template <>
 void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
                     LoopLikeOpInterface &fused, IRMapping &mapping) {
@@ -1162,7 +1162,7 @@ LoopLikeOpInterface createFused(LoopLikeOpInterface target,
   for (Operation &op : sourceRegion->front().without_terminator())
     rewriter.clone(op, mapping);
 
-  // TODO: Replace with interface method if added
+  // TODO: Replace with corresponding interface method if added
   fuseTerminator(rewriter, source, fusedLoop, mapping);
 
   return fusedLoop;
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 47bfe0baa7651..46c6be36c3271 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -189,8 +189,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func.func @matmul_fuse_2nd_forall_into_1st
-// CHECK-SAME:  [[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
 func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>

>From 7e41a549f966956204f6f0971831e0423a9aeb9d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 21 Jun 2024 15:52:08 -0500
Subject: [PATCH 22/34] address some review comments

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp                     | 13 ++++++-------
 .../Dialect/SCF/TransformOps/SCFTransformOps.cpp    |  9 ++++++---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp                | 13 ++++++-------
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2baef9ca45db1..0c967ac68a081 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -626,7 +626,7 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(getOperation());
   SmallVector<Value> inits(getOutputs());
-  inits.append(newInitOperands.begin(), newInitOperands.end());
+  llvm::append_range(inits, newInitOperands);
   scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
       getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
       inits, getMapping());
@@ -640,14 +640,13 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
   if (replaceInitOperandUsesInLoop) {
     // Replace all uses of `newInitOperands` with the corresponding basic block
     // arguments.
-    for (auto it :
+    for (auto &&[newOperand, oldOperand] :
          llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
                                         newInitOperands.size()))) {
-      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
-                                 [&](OpOperand &use) {
-                                   Operation *user = use.getOwner();
-                                   return newLoop->isProperAncestor(user);
-                                 });
+      rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
+        Operation *user = use.getOwner();
+        return newLoop->isProperAncestor(user);
+      });
     }
   }
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 99f92d7e24840..0e13b503098f0 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
     return 1;
   };
 
-  std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubConstant =
+      getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> lbConstant =
+      getConstantIntValue(forOp.getLowerBound());
   DenseMap<Operation *, unsigned> opCycles;
   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
   for (Operation &op : forOp.getBody()->getOperations()) {
@@ -528,7 +530,8 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
            << "operations cannot be fused";
 
   Operation *fusedLoop;
-  /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
+  // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
+  // and scf.parallel.
   if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
     fusedLoop = fuseIndependentSiblingForLoops(
         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 09da6e6233ffc..dc15015e9bec2 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1076,7 +1076,7 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
 
 bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
                                          LoopLikeOpInterface &source) {
-  auto iterSpaceEq =
+  bool iterSpaceEq =
       target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
       target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
       target.getLoopSteps() == source.getLoopSteps();
@@ -1125,6 +1125,7 @@ void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
     fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
                    cast<scf::ParallelOp>(fused), mapping);
   } else {
+    llvm_unreachable("unsupported loop types.");
     return;
   }
 }
@@ -1239,13 +1240,11 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
                                  newRedBlock.begin(),
                                  newRedBlock.getArguments());
     }
-    target.replaceAllUsesWith(results.take_front(inits1.size()));
-    source.replaceAllUsesWith(results.take_back(inits2.size()));
   }
-  term1->erase();
-  term2->erase();
-  target.erase();
-  source.erase();
+  rewriter.replaceOp(target, results.take_front(inits1.size()));
+  rewriter.replaceOp(source, results.take_back(inits2.size()));
+  rewriter.eraseOp(term1);
+  rewriter.eraseOp(term2);
 
   return fusedLoop;
 }

>From cc95d75d2cc09f8a33850f3867c8313e374a0dfd Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Mon, 24 Jun 2024 14:56:48 -0500
Subject: [PATCH 23/34] move `createFused` to `LoopLikeInterface.h`

---
 .../mlir/Interfaces/LoopLikeInterface.h       |  20 ++++
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 101 ++++--------------
 mlir/lib/Interfaces/LoopLikeInterface.cpp     |  42 ++++++++
 3 files changed, 82 insertions(+), 81 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 42609e824c86a..d862439a07790 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -57,4 +57,24 @@ class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/LoopLikeInterface.h.inc"
 
+namespace mlir {
+/// A function that rewrites `target`'s terminator as a teminator obtained by
+/// fusing `source` into `target`.
+using FuseTerminatorFn =
+    std::function<void(RewriterBase &rewriter, LoopLikeOpInterface source,
+                       LoopLikeOpInterface &target, IRMapping mapping)>;
+
+/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
+/// `target`.  The `NewYieldValuesFn` callback is used to pass to the
+/// `replaceWithAdditionalYields` interface method to replace the loop with a
+/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
+/// callback is repsonsible for updating the fused loop terminator.
+LoopLikeOpInterface createFused(LoopLikeOpInterface target,
+                                LoopLikeOpInterface source,
+                                RewriterBase &rewriter,
+                                NewYieldValuesFn newYieldValuesFn,
+                                FuseTerminatorFn fuseTerminatorFn);
+
+} // namespace mlir
+
 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index dc15015e9bec2..93e7a40845b2e 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1082,93 +1082,14 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
       target.getLoopSteps() == source.getLoopSteps();
   auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
   auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+  // TODO: Decouple checks on concrete loop types and move this function
+  // somewhere for general utility for `LoopLikeOpInterface`
   if (forAllTarget && forAllSource)
     return iterSpaceEq &&
            forAllTarget.getMapping() == forAllSource.getMapping();
   return iterSpaceEq;
 }
 
-template <typename LoopTy>
-void fuseTerminator(RewriterBase &rewriter, LoopTy source, LoopTy &fused,
-                    IRMapping &mapping) {}
-
-template <>
-void fuseTerminator(RewriterBase &rewriter, scf::ForallOp source,
-                    scf::ForallOp &fused, IRMapping &mapping) {
-  // Fuse the old terminator in_parallel ops into the new one.
-  scf::InParallelOp fusedTerm = fused.getTerminator();
-  rewriter.setInsertionPointToEnd(fusedTerm.getBody());
-  for (Operation &op : source.getTerminator().getYieldingOps())
-    rewriter.clone(op, mapping);
-}
-
-template <>
-void fuseTerminator(RewriterBase &rewriter, scf::ForOp source,
-                    scf::ForOp &fused, IRMapping &mapping) {
-  // Build fused yield results by appropriately mapping original yield operands.
-  auto newTerm = rewriter.clone(*fused.getBody()->getTerminator(), mapping);
-  rewriter.replaceOp(fused.getBody()->getTerminator(), newTerm);
-}
-
-// TODO: We should maybe add a method to LoopLikeOpInterface that will
-// facilitate this transformation. For now, this acts as a placeholder.
-template <>
-void fuseTerminator(RewriterBase &rewriter, LoopLikeOpInterface source,
-                    LoopLikeOpInterface &fused, IRMapping &mapping) {
-  if (isa<scf::ForOp>(source) && isa<scf::ForOp>(fused)) {
-    fuseTerminator(rewriter, cast<scf::ForOp>(source), cast<scf::ForOp>(fused),
-                   mapping);
-  } else if (isa<scf::ForallOp>(source) && isa<scf::ForallOp>(fused)) {
-    fuseTerminator(rewriter, cast<scf::ForallOp>(source),
-                   cast<scf::ForallOp>(fused), mapping);
-  } else if (isa<scf::ParallelOp>(source) && isa<scf::ParallelOp>(fused)) {
-    fuseTerminator(rewriter, cast<scf::ParallelOp>(source),
-                   cast<scf::ParallelOp>(fused), mapping);
-  } else {
-    llvm_unreachable("unsupported loop types.");
-    return;
-  }
-}
-
-LoopLikeOpInterface createFused(LoopLikeOpInterface target,
-                                LoopLikeOpInterface source,
-                                RewriterBase &rewriter,
-                                NewYieldValuesFn newYieldValuesFn) {
-  auto targetIterArgs = target.getRegionIterArgs();
-  auto targetInductionVar = *target.getLoopInductionVars();
-  SmallVector<Value> targetYieldOperands(target.getYieldedValues());
-  auto sourceIterArgs = source.getRegionIterArgs();
-  auto sourceInductionVar = *source.getLoopInductionVars();
-  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
-  auto sourceRegion = source.getLoopRegions().front();
-  LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
-      rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
-      newYieldValuesFn);
-
-  // Map control operands.
-  IRMapping mapping;
-  mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
-  mapping.map(targetIterArgs,
-              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
-  mapping.map(targetYieldOperands,
-              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
-  mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
-  mapping.map(sourceIterArgs,
-              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
-  mapping.map(sourceYieldOperands,
-              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
-  // Append everything except the terminator into the fused operation.
-  rewriter.setInsertionPoint(
-      fusedLoop.getLoopRegions().front()->front().getTerminator());
-  for (Operation &op : sourceRegion->front().without_terminator())
-    rewriter.clone(op, mapping);
-
-  // TODO: Replace with corresponding interface method if added
-  fuseTerminator(rewriter, source, fusedLoop, mapping);
-
-  return fusedLoop;
-}
-
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                       scf::ForallOp source,
                                                       RewriterBase &rewriter) {
@@ -1177,6 +1098,15 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
       [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
         // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
         return ValueRange{};
+      },
+      [&](RewriterBase &b, LoopLikeOpInterface source,
+          LoopLikeOpInterface &target, IRMapping mapping) {
+        auto sourceForall = cast<scf::ForallOp>(source);
+        auto targetForall = cast<scf::ForallOp>(target);
+        scf::InParallelOp fusedTerm = targetForall.getTerminator();
+        b.setInsertionPointToEnd(fusedTerm.getBody());
+        for (Operation &op : sourceForall.getTerminator().getYieldingOps())
+          b.clone(op, mapping);
       }));
   rewriter.replaceOp(source,
                      fusedLoop.getResults().take_back(source.getNumResults()));
@@ -1191,12 +1121,21 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
       target, source, rewriter,
       [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
         return source.getYieldedValues();
+      },
+      [&](RewriterBase &b, LoopLikeOpInterface source,
+          LoopLikeOpInterface &target, IRMapping mapping) {
+        auto sourceFor = cast<scf::ForOp>(source);
+        auto targetFor = cast<scf::ForOp>(target);
+        auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
+        b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
       }));
   rewriter.replaceOp(source,
                      fusedLoop.getResults().take_back(source.getNumResults()));
   return fusedLoop;
 }
 
+// TODO: Finish refactoring this a la the above, but likely requires additional
+// interface methods.
 scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
     scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
   Block *block1 = target.getBody();
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 1e0e87b64e811..aefd388461570 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Interfaces/LoopLikeInterface.h"
 
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/DenseSet.h"
 
@@ -113,3 +115,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
 
   return success();
 }
+
+LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
+                                      LoopLikeOpInterface source,
+                                      RewriterBase &rewriter,
+                                      NewYieldValuesFn newYieldValuesFn,
+                                      FuseTerminatorFn fuseTerminatorFn) {
+  auto targetIterArgs = target.getRegionIterArgs();
+  auto targetInductionVar = *target.getLoopInductionVars();
+  SmallVector<Value> targetYieldOperands(target.getYieldedValues());
+  auto sourceIterArgs = source.getRegionIterArgs();
+  auto sourceInductionVar = *source.getLoopInductionVars();
+  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
+  auto sourceRegion = source.getLoopRegions().front();
+  LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
+      rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
+      newYieldValuesFn);
+
+  // Map control operands.
+  IRMapping mapping;
+  mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
+  mapping.map(targetIterArgs,
+              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
+  mapping.map(targetYieldOperands,
+              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
+  mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
+  mapping.map(sourceIterArgs,
+              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
+  mapping.map(sourceYieldOperands,
+              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
+  // Append everything except the terminator into the fused operation.
+  rewriter.setInsertionPoint(
+      fusedLoop.getLoopRegions().front()->front().getTerminator());
+  for (Operation &op : sourceRegion->front().without_terminator())
+    rewriter.clone(op, mapping);
+
+  // TODO: Replace with corresponding interface method if added
+  fuseTerminatorFn(rewriter, source, fusedLoop, mapping);
+
+  return fusedLoop;
+}

>From 3430a36fda3c53d466550a7d8fd13b331f96f005 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 26 Jun 2024 13:51:34 -0500
Subject: [PATCH 24/34] address more review comments

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp      |  4 ++--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0c967ac68a081..1e42376ce58ca 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -629,9 +629,9 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
   llvm::append_range(inits, newInitOperands);
   scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
       getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
-      inits, getMapping());
+      inits, getMapping(),
+      /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
 
-  rewriter.eraseOp(newLoop.getTerminator());
   // Move the loop body to the new op.
   rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                        newLoop.getBody()->getArguments().take_front(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 93e7a40845b2e..e7496cd97cd63 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1124,7 +1124,6 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
       },
       [&](RewriterBase &b, LoopLikeOpInterface source,
           LoopLikeOpInterface &target, IRMapping mapping) {
-        auto sourceFor = cast<scf::ForOp>(source);
         auto targetFor = cast<scf::ForOp>(target);
         auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
         b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
@@ -1151,8 +1150,9 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
 
   rewriter.setInsertionPoint(source);
   auto fusedLoop = rewriter.create<scf::ParallelOp>(
-      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
-      source.getStep(), newInitVars);
+      rewriter.getFusedLoc(target.getLoc(), source.getLoc()),
+      source.getLowerBound(), source.getUpperBound(), source.getStep(),
+      newInitVars);
   Block *newBlock = fusedLoop.getBody();
   rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
                              newBlock->getArguments());
@@ -1168,8 +1168,8 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
     SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
     newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
 
-    auto newReduceOp =
-        rewriter.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+    auto newReduceOp = rewriter.create<scf::ReduceOp>(
+        rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs);
 
     for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
              term1.getReductions(), term2.getReductions()))) {

>From 8447c121b95279a283b5e7b25f094f6abb062216 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 26 Jun 2024 20:06:43 -0500
Subject: [PATCH 25/34] switch to function_ref

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index d862439a07790..cfe2c14b838f6 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -61,8 +61,8 @@ namespace mlir {
 /// A function that rewrites `target`'s terminator as a teminator obtained by
 /// fusing `source` into `target`.
 using FuseTerminatorFn =
-    std::function<void(RewriterBase &rewriter, LoopLikeOpInterface source,
-                       LoopLikeOpInterface &target, IRMapping mapping)>;
+    function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
+                      LoopLikeOpInterface &target, IRMapping mapping)>;
 
 /// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
 /// `target`.  The `NewYieldValuesFn` callback is used to pass to the

>From fbd7b72bb44c7833a683d93fccaa9d992856ee8b Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 26 Jun 2024 22:52:20 -0500
Subject: [PATCH 26/34] check optional values

---
 mlir/lib/Interfaces/LoopLikeInterface.cpp | 27 +++++++++++++++++------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index aefd388461570..6f0ebec0519be 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -122,24 +122,37 @@ LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
                                       NewYieldValuesFn newYieldValuesFn,
                                       FuseTerminatorFn fuseTerminatorFn) {
   auto targetIterArgs = target.getRegionIterArgs();
-  auto targetInductionVar = *target.getLoopInductionVars();
+  std::optional<SmallVector<Value>> targetInductionVar =
+      target.getLoopInductionVars();
   SmallVector<Value> targetYieldOperands(target.getYieldedValues());
   auto sourceIterArgs = source.getRegionIterArgs();
-  auto sourceInductionVar = *source.getLoopInductionVars();
+  std::optional<SmallVector<Value>> sourceInductionVar =
+      *source.getLoopInductionVars();
   SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
   auto sourceRegion = source.getLoopRegions().front();
-  LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields(
-      rewriter, source.getInits(), /*replaceInitOperandUsesInLoop=*/false,
-      newYieldValuesFn);
+
+  FailureOr<LoopLikeOpInterface> maybeFusedLoop =
+      target.replaceWithAdditionalYields(rewriter, source.getInits(),
+                                         /*replaceInitOperandUsesInLoop=*/false,
+                                         newYieldValuesFn);
+  if (failed(maybeFusedLoop))
+    llvm_unreachable("failed to replace loop");
+  LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
 
   // Map control operands.
   IRMapping mapping;
-  mapping.map(targetInductionVar, *fusedLoop.getLoopInductionVars());
+  std::optional<SmallVector<Value>> fusedInductionVar =
+      fusedLoop.getLoopInductionVars();
+  if (fusedInductionVar) {
+    if (!targetInductionVar || !sourceInductionVar)
+      llvm_unreachable("expected target and source loops to have induction vars");
+    mapping.map(*targetInductionVar, *fusedInductionVar);
+    mapping.map(*sourceInductionVar, *fusedInductionVar);
+  }
   mapping.map(targetIterArgs,
               fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
   mapping.map(targetYieldOperands,
               fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
-  mapping.map(sourceInductionVar, *fusedLoop.getLoopInductionVars());
   mapping.map(sourceIterArgs,
               fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
   mapping.map(sourceYieldOperands,

>From ffb73a7a76b382414f8f8295f6d6dc14a3edfa99 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 26 Jun 2024 23:34:38 -0500
Subject: [PATCH 27/34] replace equalIterationSpaces with
 checkFusionStructuralLegality

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  5 +++--
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 20 +------------------
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  4 ++--
 3 files changed, 6 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ac4434b337890..ca3ab0aeae1de 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -160,8 +160,9 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
 // Fusion related helpers
 //===----------------------------------------------------------------------===//
 
-bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
-                                   LoopLikeOpInterface &source);
+/// Check structural compatibility between two loops such as iteration space.
+bool checkFusionStructuralLegality(LoopLikeOpInterface target,
+                                   LoopLikeOpInterface source);
 
 /// Given two scf.forall loops, `target` and `source`, fuses `target` into
 /// `source`. Assumes that the given loops are siblings and are independent of
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 326a8f93162b9..fd57a9228186e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -38,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
   return walkResult.wasInterrupted();
 }
 
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(ParallelOp firstPloop,
-                                 ParallelOp secondPloop) {
-  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
-    return false;
-
-  auto matchOperands = [&](const OperandRange &lhs,
-                           const OperandRange &rhs) -> bool {
-    // TODO: Extend this to support aliases and equal constants.
-    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
-  };
-  return matchOperands(firstPloop.getLowerBound(),
-                       secondPloop.getLowerBound()) &&
-         matchOperands(firstPloop.getUpperBound(),
-                       secondPloop.getUpperBound()) &&
-         matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
 /// Checks if the parallel loops have mixed access to the same buffers. Returns
 /// `true` if the first parallel loop writes to the same indices that the second
 /// loop reads.
@@ -156,7 +138,7 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                           llvm::function_ref<bool(Value, Value)> mayAlias) {
   return !hasNestedParallelOp(firstPloop) &&
          !hasNestedParallelOp(secondPloop) &&
-         equalIterationSpaces(firstPloop, secondPloop) &&
+         checkFusionStructuralLegality(firstPloop, secondPloop) &&
          succeeded(verifyDependencies(firstPloop, secondPloop,
                                       firstToSecondPloopIndices, mayAlias));
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e7496cd97cd63..fab6592d9eb2a 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1074,8 +1074,8 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
 // Fusion related helpers
 //===----------------------------------------------------------------------===//
 
-bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
-                                         LoopLikeOpInterface &source) {
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
+                                         LoopLikeOpInterface source) {
   bool iterSpaceEq =
       target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
       target.getLoopUpperBounds() == source.getLoopUpperBounds() &&

>From a6d0588da17170b1d3653efb51704b10d770dc58 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 27 Jun 2024 11:31:03 -0500
Subject: [PATCH 28/34] run isOpSibling check in checkFusionStructuralLegality

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  6 +-
 .../SCF/TransformOps/SCFTransformOps.cpp      | 84 +----------------
 .../SCF/Transforms/ParallelLoopFusion.cpp     |  3 +-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 91 ++++++++++++++++++-
 .../SCF/transform-loop-fuse-sibling.mlir      |  3 +-
 5 files changed, 99 insertions(+), 88 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ca3ab0aeae1de..59aeff2da14ea 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -160,9 +160,11 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
 // Fusion related helpers
 //===----------------------------------------------------------------------===//
 
-/// Check structural compatibility between two loops such as iteration space.
+/// Check structural compatibility between two loops such as iteration space
+/// and dominance.
 bool checkFusionStructuralLegality(LoopLikeOpInterface target,
-                                   LoopLikeOpInterface source);
+                                   LoopLikeOpInterface source,
+                                   Diagnostic &diag);
 
 /// Given two scf.forall loops, `target` and `source`, fuses `target` into
 /// `source`. Assumes that the given loops are siblings and are independent of
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 0e13b503098f0..3e0a483615a3d 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -425,78 +425,6 @@ void transform::TakeAssumedBranchOp::getEffects(
 // LoopFuseSiblingOp
 //===----------------------------------------------------------------------===//
 
-/// Check if `target` and `source` are siblings, in the context that `target`
-/// is being fused into `source`.
-///
-/// This is a simple check that just checks if both operations are in the same
-/// block and some checks to ensure that the fused IR does not violate
-/// dominance.
-static DiagnosedSilenceableFailure isOpSibling(Operation *target,
-                                               Operation *source) {
-  // Check if both operations are same.
-  if (target == source)
-    return emitSilenceableFailure(source)
-           << "target and source need to be different loops";
-
-  // Check if both operations are in the same block.
-  if (target->getBlock() != source->getBlock())
-    return emitSilenceableFailure(source)
-           << "target and source are not in the same block";
-
-  // Check if fusion will violate dominance.
-  DominanceInfo domInfo(source);
-  if (target->isBeforeInBlock(source)) {
-    // Since `target` is before `source`, all users of results of `target`
-    // need to be dominated by `source`.
-    for (Operation *user : target->getUsers()) {
-      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
-        return emitSilenceableFailure(target)
-               << "user of results of target should be properly dominated by "
-                  "source";
-      }
-    }
-  } else {
-    // Since `target` is after `source`, all values used by `target` need
-    // to dominate `source`.
-
-    // Check if operands of `target` are dominated by `source`.
-    for (Value operand : target->getOperands()) {
-      Operation *operandOp = operand.getDefiningOp();
-      // Operands without defining operations are block arguments. When `target`
-      // and `source` occur in the same block, these operands dominate `source`.
-      if (!operandOp)
-        continue;
-
-      // Operand's defining operation should properly dominate `source`.
-      if (!domInfo.properlyDominates(operandOp, source,
-                                     /*enclosingOpOk=*/false))
-        return emitSilenceableFailure(target)
-               << "operands of target should be properly dominated by source";
-    }
-
-    // Check if values used by `target` are dominated by `source`.
-    bool failed = false;
-    OpOperand *failedValue = nullptr;
-    visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
-      Operation *operandOp = operand->get().getDefiningOp();
-      if (operandOp && !domInfo.properlyDominates(operandOp, source,
-                                                  /*enclosingOpOk=*/false)) {
-        // `operand` is not an argument of an enclosing block and the defining
-        // op of `operand` is outside `target` but does not dominate `source`.
-        failed = true;
-        failedValue = operand;
-      }
-    });
-
-    if (failed)
-      return emitSilenceableFailure(failedValue->getOwner())
-             << "values used inside regions of target should be properly "
-                "dominated by source";
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
 DiagnosedSilenceableFailure
 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
@@ -520,14 +448,10 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
     return emitSilenceableFailure(target->getLoc())
            << "target or source is not a loop op";
 
-  // Check if the target and source are siblings.
-  DiagnosedSilenceableFailure diag = isOpSibling(target, source);
-  if (!diag.succeeded())
-    return diag;
-
-  if (!mlir::checkFusionStructuralLegality(target, source))
-    return emitSilenceableFailure(target->getLoc())
-           << "operations cannot be fused";
+  // Check if loops can be fused
+  Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
+  if (!mlir::checkFusionStructuralLegality(target, source, diag))
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
 
   Operation *fusedLoop;
   // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index fd57a9228186e..b46535078dd8b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -136,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                           const IRMapping &firstToSecondPloopIndices,
                           llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
   return !hasNestedParallelOp(firstPloop) &&
          !hasNestedParallelOp(secondPloop) &&
-         checkFusionStructuralLegality(firstPloop, secondPloop) &&
+         checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
          succeeded(verifyDependencies(firstPloop, secondPloop,
                                       firstToSecondPloopIndices, mayAlias));
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index fab6592d9eb2a..b1a367281a6ca 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -1074,8 +1075,86 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
 // Fusion related helpers
 //===----------------------------------------------------------------------===//
 
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// This is a simple check that just checks if both operations are in the same
+/// block and some checks to ensure that the fused IR does not violate
+/// dominance.
+static bool isOpSibling(Operation *target, Operation *source,
+                        Diagnostic &diag) {
+  // Check if both operations are same.
+  if (target == source) {
+    diag << "target and source need to be different loops";
+    return false;
+  }
+
+  // Check if both operations are in the same block.
+  if (target->getBlock() != source->getBlock()) {
+    diag << "target and source are not in the same block";
+    return false;
+  }
+
+  // Check if fusion will violate dominance.
+  DominanceInfo domInfo(source);
+  if (target->isBeforeInBlock(source)) {
+    // Since `target` is before `source`, all users of results of `target`
+    // need to be dominated by `source`.
+    for (Operation *user : target->getUsers()) {
+      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+        diag << "user of results of target should "
+                "be properly dominated by source";
+        return false;
+      }
+    }
+  } else {
+    // Since `target` is after `source`, all values used by `target` need
+    // to dominate `source`.
+
+    // Check if operands of `target` are dominated by `source`.
+    for (Value operand : target->getOperands()) {
+      Operation *operandOp = operand.getDefiningOp();
+      // Operands without defining operations are block arguments. When `target`
+      // and `source` occur in the same block, these operands dominate `source`.
+      if (!operandOp)
+        continue;
+
+      // Operand's defining operation should properly dominate `source`.
+      if (!domInfo.properlyDominates(operandOp, source,
+                                     /*enclosingOpOk=*/false)) {
+        diag << "operands of target should be properly dominated by source";
+        return false;
+      }
+    }
+
+    // Check if values used by `target` are dominated by `source`.
+    bool failed = false;
+    OpOperand *failedValue = nullptr;
+    visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not an argument of an enclosing block and the defining
+        // op of `operand` is outside `target` but does not dominate `source`.
+        failed = true;
+        failedValue = operand;
+      }
+    });
+
+    if (failed) {
+      diag << "values used inside regions of target should be properly "
+              "dominated by source";
+      diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
+      return false;
+    }
+  }
+
+  return true;
+}
+
 bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
-                                         LoopLikeOpInterface source) {
+                                         LoopLikeOpInterface source,
+                                         Diagnostic &diag) {
   bool iterSpaceEq =
       target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
       target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
@@ -1085,9 +1164,13 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
   // TODO: Decouple checks on concrete loop types and move this function
   // somewhere for general utility for `LoopLikeOpInterface`
   if (forAllTarget && forAllSource)
-    return iterSpaceEq &&
-           forAllTarget.getMapping() == forAllSource.getMapping();
-  return iterSpaceEq;
+    iterSpaceEq =
+        iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping();
+  if (!iterSpaceEq) {
+    diag << "target and source iteration spaces must be equal";
+    return false;
+  }
+  return isOpSibling(target, source, diag);
 }
 
 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 46c6be36c3271..b03aa5cf38bfa 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -335,8 +335,9 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>,
     %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
     scf.yield %6 : tensor<128xf32>
   }
-  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
   // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+    // expected-note @below {{see operation}}
     %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
     %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
     %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>

>From ff47980d71330f65ecf05451f4d2345145a24e21 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 27 Jun 2024 11:46:02 -0500
Subject: [PATCH 29/34] remove extra dominance check

---
 mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index b46535078dd8b..95ec8861aee2b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -158,13 +158,6 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
                      mayAlias))
     return;
 
-  DominanceInfo dom;
-  // We are fusing first loop into second, make sure there are no users of the
-  // first loop results between loops.
-  for (Operation *user : firstPloop->getUsers())
-    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
-      return;
-
   IRRewriter rewriter(builder);
   secondPloop = mlir::fuseIndependentSiblingParallelLoops(
       firstPloop, secondPloop, rewriter);

>From c6847ec9212aa1754ad27c16f568a8d16346197d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 27 Jun 2024 12:34:20 -0500
Subject: [PATCH 30/34] address more review comments (use auto for
 dyn_cast results, add InsertionGuard)

---
 mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 6 ++----
 mlir/lib/Dialect/SCF/Utils/Utils.cpp                  | 1 +
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 3e0a483615a3d..8c93554f4016e 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -440,10 +440,8 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
            << "source handle (got " << llvm::range_size(sourceOps) << ")";
   }
 
-  LoopLikeOpInterface target =
-      dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
-  LoopLikeOpInterface source =
-      dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
+  auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
+  auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
   if (!target || !source)
     return emitSilenceableFailure(target->getLoc())
            << "target or source is not a loop op";
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index b1a367281a6ca..666b67517f4d4 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1220,6 +1220,7 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
 // interface methods.
 scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
     scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
   Block *block1 = target.getBody();
   Block *block2 = source.getBody();
   auto term1 = cast<scf::ReduceOp>(block1->getTerminator());

>From f50c6aa14b36836950cc47909d4cca03d5ede8e3 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 27 Jun 2024 12:55:08 -0500
Subject: [PATCH 31/34] add more lit tests for scf.parallel

---
 .../SCF/transform-loop-fuse-sibling.mlir      | 144 ++++++++++++++++++
 1 file changed, 144 insertions(+)

diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index b03aa5cf38bfa..1d46a3d88f47d 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -100,6 +100,116 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @fuse_two_parallel_reverse
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+  %sum = memref.alloc()  : memref<2x2xf32>
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+// CHECK:      memref.dealloc [[SUM]]
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %parallel:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @fuse_reductions_two
+//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
+func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:   %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+//  CHECK-SAME:   to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+//  CHECK-SAME:   init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
+//       CHECK:   %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+//       CHECK:   %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   return %[[RES]]#0, %[[RES]]#1 : f32, f32
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2 : f32, f32
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %parallel:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
 func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
   // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
@@ -382,3 +492,37 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  // expected-error @below {{target and source iteration spaces must be equal}}
+  scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
+    %B_elem = memref.load %B[%i, %c0] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
+    memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %parallel:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}

>From 6dd68c1b8408f05acaff9d040d4a686044295fcc Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 27 Jun 2024 13:24:11 -0500
Subject: [PATCH 32/34] check for equal loop types in
 checkFusionStructuralLegality

---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  5 +++
 .../SCF/transform-loop-fuse-sibling.mlir      | 33 +++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 666b67517f4d4..0c966bf182cbd 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1155,6 +1155,11 @@ static bool isOpSibling(Operation *target, Operation *source,
 bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
                                          LoopLikeOpInterface source,
                                          Diagnostic &diag) {
+  if (target->getName() != source->getName()) {
+    diag << "target and source must be same loop type";
+    return false;
+  }
+
   bool iterSpaceEq =
       target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
       target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 1d46a3d88f47d..505013d328962 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -526,3 +526,36 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
+  %sum = memref.alloc()  : memref<2xf32>
+  // expected-error @below {{target and source must be same loop type}}
+  scf.for %i = %c0 to %c2 step %c1 {
+    %B_elem = memref.load %B[%i] : memref<2xf32>
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
+    memref.store %sum_elem, %sum[%i] : memref<2xf32>
+  }
+  scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
+    %sum_elem = memref.load %sum[%i] : memref<2xf32>
+    %A_elem = memref.load %A[%i] : memref<2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i] : memref<2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2xf32>
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}

>From 99d821b47cb731ac7a12b60c44d88af5ad2fb0d1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 27 Jun 2024 13:42:20 -0500
Subject: [PATCH 33/34] address more review comments (remove stray
 semicolon, simplify forall mapping check)

---
 mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 1 -
 mlir/lib/Dialect/SCF/Utils/Utils.cpp                   | 8 +++-----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 95ec8861aee2b..b775f988576e3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -161,7 +161,6 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
   IRRewriter rewriter(builder);
   secondPloop = mlir::fuseIndependentSiblingParallelLoops(
       firstPloop, secondPloop, rewriter);
-  ;
 }
 
 void mlir::scf::naivelyFuseParallelOps(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 0c966bf182cbd..a79aef34e48b1 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1164,13 +1164,11 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
       target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
       target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
       target.getLoopSteps() == source.getLoopSteps();
-  auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
-  auto forAllSource = dyn_cast<scf::ForallOp>(*source);
   // TODO: Decouple checks on concrete loop types and move this function
   // somewhere for general utility for `LoopLikeOpInterface`
-  if (forAllTarget && forAllSource)
-    iterSpaceEq =
-        iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping();
+  if (auto forAllTarget = dyn_cast<scf::ForallOp>(*target))
+    iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() ==
+                                     cast<scf::ForallOp>(*source).getMapping();
   if (!iterSpaceEq) {
     diag << "target and source iteration spaces must be equal";
     return false;

>From 4e4a96e376aaf778d013d384dc3c2b9dab405f35 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 3 Jul 2024 11:54:39 -0500
Subject: [PATCH 34/34] Fix fused loop insertion point in fusion refactor
 and add test

---
 mlir/lib/Interfaces/LoopLikeInterface.cpp     |  1 +
 .../SCF/transform-loop-fuse-sibling.mlir      | 56 +++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 6f0ebec0519be..effdf9d7ec57f 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -138,6 +138,7 @@ LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
   if (failed(maybeFusedLoop))
     llvm_unreachable("failed to replace loop");
   LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
+  rewriter.moveOpBefore(fusedLoop, source);
 
   // Map control operands.
   IRMapping mapping;
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 91ed2a5269d74..f8246b74a5744 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -371,6 +371,62 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32) 
+#map = affine_map<(d0) -> (d0 * 32)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  // CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}
+  func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) {
+  // CHECK:      %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16>
+  // CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) {
+  // CHECK-NEXT:  %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]])
+  // CHECK-NEXT:  %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
+  // CHECK-NEXT:  %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
+  // CHECK-NEXT:  %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
+  // CHECK-NEXT:  %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}})
+  // CHECK:       scf.forall.in_parallel {
+  // CHECK-NEXT:    tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32>
+  // CHECK-NEXT:    tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
+  // CHECK-NEXT:  }
+  // CHECK-NEXT: } {mapping = [#gpu.warp<linear_dim_0>]}
+  // CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1
+    %0 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg0) -> (tensor<128xf32>) {
+      %3 = affine.apply #map(%arg4)
+      %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [32] [1] : tensor<32xf32> into tensor<128xf32>
+      }
+    } {mapping = [#gpu.warp<linear_dim_0>]}
+    %1 = tensor.empty() : tensor<128x128xf16>
+    %2 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg1) -> (tensor<128x128xf16>) {
+      %3 = affine.apply #map(%arg4)
+      %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
+      %extracted_slice_0 = tensor.extract_slice %1[%3, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
+      %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32x128xf32>) outs(%extracted_slice_0 : tensor<32x128xf16>) {
+      ^bb0(%in: f32, %out: f16):
+        %5 = arith.truncf %in : f32 to f16 
+        linalg.yield %5 : f16 
+      } -> tensor<32x128xf16>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %4 into %arg5[%3, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
+      }   
+    } {mapping = [#gpu.warp<linear_dim_0>]}
+    return %0, %2 : tensor<128xf32>, tensor<128x128xf16>
+  }
+}
+
+module attributes { transform.with_named_sequence } { 
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loop1, %loop2 = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %loop3 = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
 // -----
 
 func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {



More information about the Mlir-commits mailing list