[Mlir-commits] [mlir] c1ba9c4 - [mlir][Affine] Refactor affine fusion code in pass to utilities

Diego Caballero llvmlistbot at llvm.org
Wed Nov 18 14:06:13 PST 2020


Author: Diego Caballero
Date: 2020-11-18T13:50:32-08:00
New Revision: c1ba9c43adb7ee101048e88ab33c94a1ceda398e

URL: https://github.com/llvm/llvm-project/commit/c1ba9c43adb7ee101048e88ab33c94a1ceda398e
DIFF: https://github.com/llvm/llvm-project/commit/c1ba9c43adb7ee101048e88ab33c94a1ceda398e.diff

LOG: [mlir][Affine] Refactor affine fusion code in pass to utilities

Refactoring/clean-up step needed to add support for producer-consumer fusion
with multi-store producer loops and, more broadly, to implement more general
loop fusion strategies in Affine. It introduces the following changes:
  - AffineLoopFusion pass now uses loop fusion utilities more broadly to compute
    fusion legality (canFuseLoops utility) and perform the fusion transformation
    (fuseLoops utility).
  - Loop fusion utilities have been extended to deal with AffineLoopFusion
    requirements and assumptions while preserving both loop fusion utilities and
    AffineLoopFusion current functionality within a unified implementation.
    'FusionStrategy' has been introduced for this purpose and, in the future, it
    will allow us to have a single loop fusion core implementation that will produce
    different fusion outputs depending on the strategy used.
  - Improve separation of concerns for legality and profitability analysis:
    'isFusionProfitable' no longer filters out illegal scenarios that 'canFuse'
    didn't detect, or the other way around. 'canFuse' now takes loop dependences
    into account to determine the fusion loop depth (producer-consumer fusion only).
  - As a result, maximal fusion now doesn't require any profitability analysis.
  - Slices are now computed only once and reused across the legality, profitability
    and fusion transformation steps (producer-consumer).
  - Refactor some utilities and remove redundant copies of them.
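
The legality/profitability split described in these bullets can be modeled in standalone C++. This is a simplified sketch, not the MLIR implementation. `Slice`, `LegalityFn`, `MetricFn` and `pickFusionDepth` are hypothetical stand-ins for `ComputationSliceState`, the per-depth `canFuseLoops` calls and the cost model in `isFusionProfitable`. The sketch assumes that an empty slice marks a depth where fusion is illegal. It also omits the fused-vs-unfused memory check that the real cost model performs.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-in for ComputationSliceState: a slice without IVs is
// "empty", mirroring ComputationSliceState::isEmpty().
struct Slice {
  std::vector<int> ivs;
  bool isEmpty() const { return ivs.empty(); }
};

// Hypothetical legality query: fills 'slice' and returns true if the source
// nest can legally be fused into the destination at 'depth'.
using LegalityFn = std::function<bool(unsigned depth, Slice &slice)>;
// Hypothetical cost-model queries, evaluated on an already computed slice.
using MetricFn = std::function<double(const Slice &slice)>;

// Returns the chosen fusion depth, or std::nullopt if fusion is rejected.
std::optional<unsigned> pickFusionDepth(unsigned maxCandidateDepth,
                                        const LegalityFn &canFuseAtDepth,
                                        const MetricFn &storageReduction,
                                        const MetricFn &extraComputePercent,
                                        bool maximalFusion,
                                        double computeToleranceThreshold) {
  // Legality: compute each depth's slice exactly once and remember the
  // deepest legal depth. Illegal depths keep an empty slice.
  std::vector<Slice> depthSlices(maxCandidateDepth);
  unsigned maxLegalDepth = 0;
  for (unsigned d = 1; d <= maxCandidateDepth; ++d) {
    if (canFuseAtDepth(d, depthSlices[d - 1]))
      maxLegalDepth = d;
    else
      depthSlices[d - 1] = Slice{};
  }
  if (maxLegalDepth == 0)
    return std::nullopt; // Not legal at any depth.

  // Maximal fusion: no profitability analysis, use the deepest legal depth.
  if (maximalFusion) {
    assert(!depthSlices[maxLegalDepth - 1].isEmpty() && "Missing slice");
    return maxLegalDepth;
  }

  // Profitability: reuse the slices from the legality step, preferring the
  // largest storage reduction within the compute tolerance.
  std::optional<unsigned> best;
  double bestReduction = 0.0;
  for (unsigned d = maxLegalDepth; d >= 1; --d) {
    const Slice &slice = depthSlices[d - 1];
    if (slice.isEmpty())
      continue; // Illegal at this depth.
    double reduction = storageReduction(slice);
    if (reduction > bestReduction &&
        extraComputePercent(slice) < computeToleranceThreshold) {
      bestReduction = reduction;
      best = d;
    }
  }
  return best;
}
```

With maximal fusion the sketch never calls the cost callbacks, which reflects the "maximal fusion now doesn't require any profitability analysis" bullet.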

This patch is NFCI (no functional change intended) and should preserve the
existing functionality of both the AffineLoopFusion pass and the affine fusion
utilities.

Reviewed By: andydavis1, bondhugula

Differential Revision: https://reviews.llvm.org/D90798

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Utils.h
    mlir/include/mlir/Transforms/LoopFusionUtils.h
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
    mlir/test/lib/Transforms/TestLoopFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index b502d909d5c0..30b6272181f5 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -82,6 +82,11 @@ struct ComputationSliceState {
 
   // Clears all bounds and operands in slice state.
   void clearBounds();
+
+  /// Return true if the computation slice is empty.
+  bool isEmpty() const { return ivs.empty(); }
+
+  void dump() const;
 };
 
 /// Computes the computation slice loop bounds for one loop nest as affine maps
@@ -212,7 +217,7 @@ struct MemRefRegion {
   /// The last field is a 2-d FlatAffineConstraints symbolic in %i.
   ///
   LogicalResult compute(Operation *op, unsigned loopDepth,
-                        ComputationSliceState *sliceState = nullptr,
+                        const ComputationSliceState *sliceState = nullptr,
                         bool addMemRefDimBounds = true);
 
   FlatAffineConstraints *getConstraints() { return &cst; }
@@ -309,6 +314,11 @@ bool isLoopParallel(AffineForOp forOp);
 /// number of constraints.
 IntegerSet simplifyIntegerSet(IntegerSet set);
 
+/// Returns the innermost common loop depth for the set of operations in 'ops'.
+unsigned getInnermostCommonLoopDepth(
+    ArrayRef<Operation *> ops,
+    SmallVectorImpl<AffineForOp> *surroundingLoops = nullptr);
+
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_UTILS_H

diff  --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index 36d2520b7c85..eade565e0325 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
 #define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
 
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
@@ -38,6 +39,45 @@ struct FusionResult {
   FusionResult(ResultEnum v) : value(v) {}
 };
 
+/// Describes the fusion strategy to be used in the Affine loop fusion
+/// utilities. Currently, it is used to specialize the loop fusion utilities
+/// with the assumptions made in the AffineLoopFusion pass for producer-consumer
+/// and sibling fusion, while sharing a single implementation. The latter
+/// strategies are also limited to scenarios where a single memref is involved
+/// in the producer-consumer or sibling relationship between the candidate
+/// loops. We use 'memref' to keep track of such a memref.
+// TODO: Remove 'memref' when we support more generic scenarios.
+// TODO: Generalize utilities so that producer-consumer and sibling fusion
+// strategies can be used without the assumptions made in the AffineLoopFusion
+// pass.
+struct FusionStrategy {
+  enum StrategyEnum {
+    // Generic loop fusion: Arbitrary loops are considered for fusion. No
+    // assumptions about a specific fusion strategy from AffineLoopFusion pass
+    // are made.
+    // TODO: Generic fusion is not fully implemented by fusion utilities yet.
+    // It should only be used for testing.
+    Generic,
+    // Producer-consumer fusion: Only loops with a producer-consumer
+    // memref dependence are considered for fusion. Currently, assumptions from
+    // the producer-consumer fusion implementation in AffineLoopFusion pass are
+    // made. See pass for specific details.
+    ProducerConsumer,
+    // Sibling fusion: Only sibling loops with no producer-consumer memref
+    // dependences are considered for fusion. Memref reuse is taken into account
+    // for profitability. Currently, assumptions from the sibling fusion
+    // implementation in AffineLoopFusion pass are made. See pass for specific
+    // details.
+    Sibling
+  } strategy;
+
+  // Target memref for this fusion transformation.
+  Value memref;
+
+  FusionStrategy(StrategyEnum strategy, Value memref)
+      : strategy(strategy), memref(memref) {}
+};
+
 /// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
 /// loop nest rooted at 'dstForOp' at 'dstLoopDepth'. Returns FusionResult
 /// 'Success' if fusion of the src/dst loop nests is feasible (i.e. they are
@@ -48,12 +88,14 @@ struct FusionResult {
 /// TODO: Update comments when this function is fully implemented.
 FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                           unsigned dstLoopDepth,
-                          ComputationSliceState *srcSlice);
+                          ComputationSliceState *srcSlice,
+                          FusionStrategy fusionStrategy = {
+                              FusionStrategy::Generic, Value()});
 
 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
 /// and source slice loop bounds specified in 'srcSlice'.
 void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
-               ComputationSliceState *srcSlice);
+               const ComputationSliceState &srcSlice);
 
 /// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
 /// and operation count) for a loop nest up until (and including) the innermost
@@ -89,7 +131,8 @@ int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats);
 // TODO: Improve this cost model.
 bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                           AffineForOp dstForOp, LoopNestStats &dstStats,
-                          ComputationSliceState *slice, int64_t *computeCost);
+                          const ComputationSliceState &slice,
+                          int64_t *computeCost);
 
 } // end namespace mlir
 

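The defaulted `fusionStrategy` parameter keeps existing `canFuseLoops` callers on the Generic strategy, while the pass can pass `ProducerConsumer` together with its target memref. The following standalone sketch shows the same defaulted-parameter pattern. `Value` is a hypothetical stand-in for `mlir::Value` and `describeStrategy` is a hypothetical function, not part of the MLIR API. The assertion that non-generic strategies carry a memref is an assumption made for the sketch.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for mlir::Value: a default-constructed value is null.
struct Value {
  const void *impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

// Same shape as FusionStrategy in LoopFusionUtils.h.
struct FusionStrategy {
  enum StrategyEnum { Generic, ProducerConsumer, Sibling } strategy;
  Value memref; // Target memref; null in the default Generic strategy.
  FusionStrategy(StrategyEnum strategy, Value memref)
      : strategy(strategy), memref(memref) {}
};

// Hypothetical function with the same defaulted trailing parameter as
// canFuseLoops: call sites that omit it get the Generic strategy.
std::string describeStrategy(FusionStrategy fusionStrategy = {
                                 FusionStrategy::Generic, Value()}) {
  // Sketch assumption: pass-specific strategies always carry a target memref.
  assert((fusionStrategy.strategy == FusionStrategy::Generic ||
          fusionStrategy.memref) &&
         "expected a target memref");
  switch (fusionStrategy.strategy) {
  case FusionStrategy::Generic:
    return "generic";
  case FusionStrategy::ProducerConsumer:
    return "producer-consumer";
  case FusionStrategy::Sibling:
    return "sibling";
  }
  return "unknown";
}
```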
diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index b02212a09bba..35678432dd16 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -105,6 +105,28 @@ void ComputationSliceState::clearBounds() {
   ubOperands.clear();
 }
 
+void ComputationSliceState::dump() const {
+  llvm::errs() << "\tIVs:\n";
+  for (Value iv : ivs)
+    llvm::errs() << "\t\t" << iv << "\n";
+
+  llvm::errs() << "\tLBs:\n";
+  for (auto &en : llvm::enumerate(lbs)) {
+    llvm::errs() << "\t\t" << en.value() << "\n";
+    llvm::errs() << "\t\tOperands:\n";
+    for (Value lbOp : lbOperands[en.index()])
+      llvm::errs() << "\t\t\t" << lbOp << "\n";
+  }
+
+  llvm::errs() << "\tUBs:\n";
+  for (auto &en : llvm::enumerate(ubs)) {
+    llvm::errs() << "\t\t" << en.value() << "\n";
+    llvm::errs() << "\t\tOperands:\n";
+    for (Value ubOp : ubOperands[en.index()])
+      llvm::errs() << "\t\t\t" << ubOp << "\n";
+  }
+}
+
 unsigned MemRefRegion::getRank() const {
   return memref.getType().cast<MemRefType>().getRank();
 }
@@ -211,7 +233,7 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
 // TODO: extend this to any other memref dereferencing ops
 // (dma_start, dma_wait).
 LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
-                                    ComputationSliceState *sliceState,
+                                    const ComputationSliceState *sliceState,
                                     bool addMemRefDimBounds) {
   assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
          "affine read/write op expected");
@@ -541,13 +563,12 @@ static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
   return success();
 }
 
-// Returns the innermost common loop depth for the set of operations in 'ops'.
+/// Returns the innermost common loop depth for the set of operations in 'ops'.
 // TODO: Move this to LoopUtils.
-static unsigned
-getInnermostCommonLoopDepth(ArrayRef<Operation *> ops,
-                            SmallVectorImpl<AffineForOp> &surroundingLoops) {
+unsigned mlir::getInnermostCommonLoopDepth(
+    ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
   unsigned numOps = ops.size();
-  assert(numOps > 0);
+  assert(numOps > 0 && "Expected at least one operation");
 
   std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
   unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
@@ -564,7 +585,8 @@ getInnermostCommonLoopDepth(ArrayRef<Operation *> ops,
       if (loops[i - 1][d] != loops[i][d])
         return loopDepth;
     }
-    surroundingLoops.push_back(loops[i - 1][d]);
+    if (surroundingLoops)
+      surroundingLoops->push_back(loops[i - 1][d]);
     ++loopDepth;
   }
   return loopDepth;
@@ -684,7 +706,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
   }
   SmallVector<AffineForOp, 4> surroundingLoops;
   unsigned innermostCommonLoopDepth =
-      getInnermostCommonLoopDepth(ops, surroundingLoops);
+      getInnermostCommonLoopDepth(ops, &surroundingLoops);
   if (loopDepth > innermostCommonLoopDepth) {
     LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
     return failure();

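This hunk makes the former static helper public as `mlir::getInnermostCommonLoopDepth` and turns `surroundingLoops` into an optional pointer. Callers that only need the depth can therefore pass nothing, as the GreedyFusion code in LoopFusion.cpp does. The standalone sketch below computes the same common prefix of enclosing loops, with loops modeled as integer IDs. `LoopChain` and `innermostCommonLoopDepth` are hypothetical names, not MLIR API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical model: each operation is described by the IDs of its enclosing
// loops, outermost first (what getLoopIVs collects for a real operation).
using LoopChain = std::vector<int>;

// Returns how many outer loops all chains share. If 'surroundingLoops' is
// non-null, the shared loops are appended to it, outermost first.
unsigned innermostCommonLoopDepth(const std::vector<LoopChain> &ops,
                                  std::vector<int> *surroundingLoops = nullptr) {
  assert(!ops.empty() && "Expected at least one operation");
  std::size_t limit = std::numeric_limits<std::size_t>::max();
  for (const LoopChain &chain : ops)
    limit = std::min(limit, chain.size());

  unsigned depth = 0;
  for (std::size_t d = 0; d < limit; ++d) {
    // Stop at the first level where any two operations diverge.
    for (std::size_t i = 1; i < ops.size(); ++i)
      if (ops[i - 1][d] != ops[i][d])
        return depth;
    if (surroundingLoops)
      surroundingLoops->push_back(ops[0][d]);
    ++depth;
  }
  return depth;
}
```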
diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index ed79be02b816..6716260aa0d1 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -741,77 +741,6 @@ static void moveLoadsAccessingMemrefTo(Value memref,
   srcLoads->swap(srcLoadsToKeep);
 }
 
-// Returns the innermost common loop depth for the set of operations in 'ops'.
-static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) {
-  unsigned numOps = ops.size();
-  assert(numOps > 0);
-
-  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
-  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
-  for (unsigned i = 0; i < numOps; ++i) {
-    getLoopIVs(*ops[i], &loops[i]);
-    loopDepthLimit =
-        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
-  }
-
-  unsigned loopDepth = 0;
-  for (unsigned d = 0; d < loopDepthLimit; ++d) {
-    unsigned i;
-    for (i = 1; i < numOps; ++i) {
-      if (loops[i - 1][d] != loops[i][d])
-        break;
-    }
-    if (i != numOps)
-      break;
-    ++loopDepth;
-  }
-  return loopDepth;
-}
-
-// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
-// and 'storeOpInsts' are satisfied.
-static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
-                                ArrayRef<Operation *> storeOpInsts) {
-  // Merge loads and stores into the same array.
-  SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
-  ops.append(storeOpInsts.begin(), storeOpInsts.end());
-
-  // Compute the innermost common loop depth for loads and stores.
-  unsigned loopDepth = getInnermostCommonLoopDepth(ops);
-
-  // Return common loop depth for loads if there are no store ops.
-  if (storeOpInsts.empty())
-    return loopDepth;
-
-  // Check dependences on all pairs of ops in 'ops' and store the minimum
-  // loop depth at which a dependence is satisfied.
-  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
-    auto *srcOpInst = ops[i];
-    MemRefAccess srcAccess(srcOpInst);
-    for (unsigned j = 0; j < e; ++j) {
-      auto *dstOpInst = ops[j];
-      MemRefAccess dstAccess(dstOpInst);
-
-      unsigned numCommonLoops =
-          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
-      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
-        FlatAffineConstraints dependenceConstraints;
-        // TODO: Cache dependence analysis results, check cache here.
-        DependenceResult result = checkMemrefAccessDependence(
-            srcAccess, dstAccess, d, &dependenceConstraints,
-            /*dependenceComponents=*/nullptr);
-        if (hasDependence(result)) {
-          // Store minimum loop depth and break because we want the min 'd' at
-          // which there is a dependence.
-          loopDepth = std::min(loopDepth, d - 1);
-          break;
-        }
-      }
-    }
-  }
-  return loopDepth;
-}
-
 // Sinks all sequential loops to the innermost levels (while preserving
 // relative order among them) and moves all parallel loops to the
 // outermost (while again preserving relative order among them).
@@ -1077,14 +1006,16 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
 // the memref being produced and consumed, which is an input to the cost model.
 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
-// 'srcOpInst', as we are slicing w.r.t to that producer.
-// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
-// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
-// will be the unique store op in the src node, which will be used to check
-// that the write region is the same after input-reuse fusion.
-// Returns true if it is profitable to fuse the candidate loop nests. Returns
-// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
-// to materialize the source loop nest slice.
+// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
+// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
+// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
+// unique store op in the src node, which will be used to check that the write
+// region is the same after input-reuse fusion. Computation slices are provided
+// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
+// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
+// profitable to fuse the candidate loop nests. Returns false otherwise.
+// `dstLoopDepth` is set to the most profitable depth at which to materialize
+// the source loop nest slice.
 // The profitability model executes the following steps:
 // *) Computes the backward computation slice at 'srcOpInst'. This
 //    computation slice of the loop nest surrounding 'srcOpInst' is
@@ -1112,9 +1043,9 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
 //    is lower.
 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                                ArrayRef<Operation *> dstLoadOpInsts,
-                               ArrayRef<Operation *> dstStoreOpInsts,
-                               ComputationSliceState *sliceState,
-                               unsigned *dstLoopDepth, bool maximalFusion,
+                               ArrayRef<ComputationSliceState> depthSliceUnions,
+                               unsigned maxLegalFusionDepth,
+                               unsigned *dstLoopDepth,
                                double computeToleranceThreshold) {
   LLVM_DEBUG({
     llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
@@ -1124,10 +1055,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     };
   });
 
+  if (maxLegalFusionDepth == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0 .\n");
+    return false;
+  }
+
   // Compute cost of sliced and unsliced src loop nest.
   SmallVector<AffineForOp, 4> srcLoopIVs;
   getLoopIVs(*srcOpInst, &srcLoopIVs);
-  unsigned numSrcLoopIVs = srcLoopIVs.size();
 
   // Walk src loop nest and collect stats.
   LoopNestStats srcLoopNestStats;
@@ -1142,19 +1077,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
     return false;
 
-  // Compute the maximum loop depth at which we can can insert the src slice
-  // and still satisfy dest loop nest dependences, for producer-consumer fusion.
-  unsigned maxDstLoopDepth =
-      (srcOpInst == srcStoreOpInst)
-          ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
-          : dstLoopIVs.size();
-  if (maxDstLoopDepth == 0) {
-    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n");
-    return false;
-  }
-
   // Search for min cost value for 'dstLoopDepth'. At each value of
-  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
+  // 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute computation slice
   // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
   // of these bounds). Next the union slice bounds are used to calculate
   // the cost of the slice and the cost of the slice inserted into the dst
@@ -1163,8 +1087,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   double maxStorageReduction = 0.0;
   Optional<uint64_t> sliceMemEstimate = None;
 
-  SmallVector<ComputationSliceState, 4> sliceStates;
-  sliceStates.resize(maxDstLoopDepth);
   // The best loop depth at which to materialize the slice.
   Optional<unsigned> bestDstLoopDepth = None;
 
@@ -1190,21 +1112,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
 
   // Evaluate all depth choices for materializing the slice in the destination
   // loop nest.
-  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
-    // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
-    if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
-                                       /*loopDepth=*/i,
-                                       /*numCommonLoops=*/0,
-                                       /*isBackwardSlice=*/true,
-                                       &sliceStates[i - 1]))) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "computeSliceUnion failed for loopDepth: " << i << "\n");
+  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
+    // Skip slice union if it wasn't computed for this depth.
+    if (depthSliceUnions[i - 1].isEmpty())
       continue;
-    }
 
     int64_t fusedLoopNestComputeCost;
     if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
-                              dstLoopNestStats, &sliceStates[i - 1],
+                              dstLoopNestStats, depthSliceUnions[i - 1],
                               &fusedLoopNestComputeCost)) {
       LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
       continue;
@@ -1216,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
         1;
 
     // Determine what the slice write MemRefRegion would be, if the src loop
-    // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
-    // nest at loop depth 'i'
+    // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst
+    // loop nest at loop depth 'i'.
     MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
     if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
-                                        &sliceStates[i - 1]))) {
+                                        &depthSliceUnions[i - 1]))) {
       LLVM_DEBUG(llvm::dbgs()
                  << "Failed to compute slice write region at loopDepth: " << i
                  << "\n");
@@ -1269,8 +1184,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     // (as per computeToleranceThreshold), we will simply pick the one that
     // reduces the intermediary size the most.
     if ((storageReduction > maxStorageReduction) &&
-        (maximalFusion ||
-         (additionalComputeFraction < computeToleranceThreshold))) {
+        (additionalComputeFraction < computeToleranceThreshold)) {
       maxStorageReduction = storageReduction;
       bestDstLoopDepth = i;
       minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
@@ -1278,10 +1192,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     }
   }
 
-  // A simple cost model: fuse if it reduces the memory footprint. If
-  // -maximal-fusion is set, fuse nevertheless.
+  // A simple cost model: fuse if it reduces the memory footprint.
 
-  if (!maximalFusion && !bestDstLoopDepth.hasValue()) {
+  if (!bestDstLoopDepth.hasValue()) {
     LLVM_DEBUG(
         llvm::dbgs()
         << "All fusion choices involve more than the threshold amount of "
@@ -1310,33 +1223,30 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
 
   Optional<double> storageReduction = None;
 
-  if (!maximalFusion) {
-    if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
-      LLVM_DEBUG(
-          llvm::dbgs()
-          << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
-      return false;
-    }
+  if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
+    return false;
+  }
 
-    auto srcMemSizeVal = srcMemSize.getValue();
-    auto dstMemSizeVal = dstMemSize.getValue();
+  auto srcMemSizeVal = srcMemSize.getValue();
+  auto dstMemSizeVal = dstMemSize.getValue();
 
-    assert(sliceMemEstimate.hasValue() && "expected value");
-    auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
+  assert(sliceMemEstimate.hasValue() && "expected value");
+  auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
 
-    LLVM_DEBUG(llvm::dbgs() << "   src mem: " << srcMemSizeVal << "\n"
-                            << "   dst mem: " << dstMemSizeVal << "\n"
-                            << "   fused mem: " << fusedMem << "\n"
-                            << "   slice mem: " << sliceMemEstimate << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "   src mem: " << srcMemSizeVal << "\n"
+                          << "   dst mem: " << dstMemSizeVal << "\n"
+                          << "   fused mem: " << fusedMem << "\n"
+                          << "   slice mem: " << sliceMemEstimate << "\n");
 
-    if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
-      LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
-      return false;
-    }
-    storageReduction =
-        100.0 *
-        (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
+  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+    return false;
   }
+  storageReduction =
+      100.0 *
+      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
 
   double additionalComputeFraction =
       100.0 * (minFusedLoopNestComputeCost /
@@ -1355,24 +1265,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     llvm::dbgs() << msg.str();
   });
 
-  // Update return parameter 'sliceState' with 'bestSliceState'.
-  ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
-  sliceState->lbs = bestSliceState->lbs;
-  sliceState->ubs = bestSliceState->ubs;
-  sliceState->lbOperands = bestSliceState->lbOperands;
-  sliceState->ubOperands = bestSliceState->ubOperands;
-
-  // Canonicalize slice bound affine maps.
-  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
-    if (sliceState->lbs[i] != AffineMap()) {
-      canonicalizeMapAndOperands(&sliceState->lbs[i],
-                                 &sliceState->lbOperands[i]);
-    }
-    if (sliceState->ubs[i] != AffineMap()) {
-      canonicalizeMapAndOperands(&sliceState->ubs[i],
-                                 &sliceState->ubOperands[i]);
-    }
-  }
   return true;
 }
 
@@ -1592,138 +1484,142 @@ struct GreedyFusion {
           if (insertPointInst == nullptr)
             continue;
 
+          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+          auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
           // Compute the innermost common loop depth for dstNode loads/stores.
-          SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(),
-                                             dstNode->loads.end());
-          dstOps.append(dstNode->stores.begin(), dstNode->stores.end());
-          unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps);
+          SmallVector<Operation *, 2> dstMemrefOps;
+          for (Operation *op : dstNode->loads)
+            if (cast<AffineReadOpInterface>(op).getMemRef() == memref)
+              dstMemrefOps.push_back(op);
+          for (Operation *op : dstNode->stores)
+            if (cast<AffineWriteOpInterface>(op).getMemRef() == memref)
+              dstMemrefOps.push_back(op);
+          unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
+
           // Check the feasibility of fusing src loop nest into dst loop nest
           // at loop depths in range [1, dstLoopDepthTest].
-          // TODO: Use slice union computation and union of memref
-          // read/write regions to cost model and fusion.
-          bool canFuse = false;
+          unsigned maxLegalFusionDepth = 0;
+          SmallVector<ComputationSliceState, 8> depthSliceUnions;
+          depthSliceUnions.resize(dstLoopDepthTest);
+          FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref);
           for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
-            ComputationSliceState sliceUnion;
             FusionResult result = mlir::canFuseLoops(
-                cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
-                /*dstLoopDepth=*/i, &sliceUnion);
+                srcAffineForOp, dstAffineForOp,
+                /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+
             if (result.value == FusionResult::Success)
-              canFuse = true;
+              maxLegalFusionDepth = i;
           }
 
-          // Skip if fusion is not feasible at all loop depths.
-          if (!canFuse)
+          // Skip if fusion is not feasible at any loop depth.
+          if (maxLegalFusionDepth == 0)
             continue;
 
-          // Gather 'dstNode' store ops to 'memref'.
-          SmallVector<Operation *, 2> dstStoreOpInsts;
-          for (auto *storeOpInst : dstNode->stores)
-            if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref)
-              dstStoreOpInsts.push_back(storeOpInst);
-
-          unsigned bestDstLoopDepth;
-          mlir::ComputationSliceState sliceState;
-          // Check if fusion would be profitable.
-          if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
-                                  dstStoreOpInsts, &sliceState,
-                                  &bestDstLoopDepth, maximalFusion,
-                                  computeToleranceThreshold))
+          // Check if fusion would be profitable. We skip profitability analysis
+          // for maximal fusion since we already know the maximal legal depth to
+          // fuse.
+          unsigned bestDstLoopDepth = maxLegalFusionDepth;
+          if (!maximalFusion &&
+              !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+                                  depthSliceUnions, maxLegalFusionDepth,
+                                  &bestDstLoopDepth, computeToleranceThreshold))
             continue;
 
+          assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
+          assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+                 "Missing slice union for depth");
+
           // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
-          auto sliceLoopNest = mlir::insertBackwardComputationSlice(
-              srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
-          if (sliceLoopNest) {
-            LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
-                                    << *sliceLoopNest.getOperation() << "\n");
-            // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-            auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
-            if (insertPointInst != dstAffineForOp.getOperation()) {
-              dstAffineForOp.getOperation()->moveBefore(insertPointInst);
-            }
-            // Update edges between 'srcNode' and 'dstNode'.
-            mdg->updateEdges(srcNode->id, dstNode->id, memref,
-                             createPrivateMemref);
-
-            // Collect slice loop stats.
-            LoopNestStateCollector sliceCollector;
-            sliceCollector.collect(sliceLoopNest.getOperation());
-            // Promote single iteration slice loops to single IV value.
-            for (auto forOp : sliceCollector.forOps) {
-              promoteIfSingleIteration(forOp);
-            }
-            if (createPrivateMemref) {
-              // Create private memref for 'memref' in 'dstAffineForOp'.
-              SmallVector<Operation *, 4> storesForMemref;
-              for (auto *storeOpInst : sliceCollector.storeOpInsts) {
-                if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
-                    memref)
-                  storesForMemref.push_back(storeOpInst);
-              }
-              // TODO: Use union of memref write regions to compute
-              // private memref footprint.
-              auto newMemRef = createPrivateMemRef(
-                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
-                  fastMemorySpace, localBufSizeThreshold);
-              visitedMemrefs.insert(newMemRef);
-              // Create new node in dependence graph for 'newMemRef' alloc op.
-              unsigned newMemRefNodeId =
-                  mdg->addNode(newMemRef.getDefiningOp());
-              // Add edge from 'newMemRef' node to dstNode.
-              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
+          fuseLoops(srcAffineForOp, dstAffineForOp,
+                    depthSliceUnions[bestDstLoopDepth - 1]);
+
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Fused src loop " << srcId << " into dst loop " << dstId
+                     << " at depth " << bestDstLoopDepth << ":\n"
+                     << dstAffineForOp << "\n");
+
+          // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+          if (insertPointInst != dstAffineForOp.getOperation())
+            dstAffineForOp.getOperation()->moveBefore(insertPointInst);
+
+          // Update edges between 'srcNode' and 'dstNode'.
+          mdg->updateEdges(srcNode->id, dstNode->id, memref,
+                           createPrivateMemref);
+
+          // Collect dst loop stats after fusion.
+          LoopNestStateCollector dstForCollector;
+          dstForCollector.collect(dstAffineForOp);
+          if (createPrivateMemref) {
+            // Create private memref for 'memref' in 'dstAffineForOp'.
+            SmallVector<Operation *, 4> storesForMemref;
+            for (auto *storeOpInst : dstForCollector.storeOpInsts) {
+              if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
+                  memref)
+                storesForMemref.push_back(storeOpInst);
             }
+            // TODO: Use union of memref write regions to compute
+            // private memref footprint.
+            auto newMemRef = createPrivateMemRef(
+                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+                fastMemorySpace, localBufSizeThreshold);
+            visitedMemrefs.insert(newMemRef);
+            // Create new node in dependence graph for 'newMemRef' alloc op.
+            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
+            // Add edge from 'newMemRef' node to dstNode.
+            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
+          }
 
-            // Collect dst loop stats after memref privatization transformation.
-            LoopNestStateCollector dstLoopCollector;
-            dstLoopCollector.collect(dstAffineForOp.getOperation());
-
-            // Add new load ops to current Node load op list 'loads' to
-            // continue fusing based on new operands.
-            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-              // NOTE: Change 'loads' to a hash set in case efficiency is an
-              // issue. We still use a vector since it's expected to be small.
-              if (!llvm::is_contained(loads, loadOpInst))
-                loads.push_back(loadOpInst);
-            }
-            // Clear visited memrefs after fusion so that previously visited src
-            // nodes are considered for fusion again in the context of the new
-            // fused node.
-            // TODO: This shouldn't be necessary if we visited candidates in the
-            // dependence graph in post-order or once we fully support
-            // multi-store producers. Currently, in a multi-store producer
-            // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to the
-            // multiple outgoing edges. However, after fusing B+C, A has a
-            // single outgoing edge and can be fused if we revisit it in the
-            // context of the new fused B+C node.
-            visitedMemrefs.clear();
-
-            // Clear and add back loads and stores.
-            mdg->clearNodeLoadAndStores(dstNode->id);
-            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
-                           dstLoopCollector.storeOpInsts);
-            // Remove old src loop nest if it no longer has outgoing dependence
-            // edges, and if it does not write to a memref which escapes the
-            // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
-            // been fused into 'dstNode' and write region of 'dstNode' covers
-            // the write region of 'srcNode', and 'srcNode' has no other users
-            // so it is safe to remove.
-            if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
-              mdg->removeNode(srcNode->id);
-              srcNode->op->erase();
-            } else {
-              // Add remaining users of 'oldMemRef' back on the worklist (if not
-              // already there), as its replacement with a local/private memref
-              // has reduced dependences on 'oldMemRef' which may have created
-              // new fusion opportunities.
-              if (mdg->outEdges.count(srcNode->id) > 0) {
-                SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
-                    mdg->outEdges[srcNode->id];
-                for (auto &outEdge : oldOutEdges) {
-                  if (outEdge.value == memref &&
-                      worklistSet.count(outEdge.id) == 0) {
-                    worklist.push_back(outEdge.id);
-                    worklistSet.insert(outEdge.id);
-                  }
+          // Collect dst loop stats after memref privatization transformation.
+          LoopNestStateCollector dstLoopCollector;
+          dstLoopCollector.collect(dstAffineForOp.getOperation());
+
+          // Add new load ops to current Node load op list 'loads' to continue
+          // fusing based on new operands.
+          for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
+            // NOTE: Change 'loads' to a hash set in case efficiency is an
+            // issue. We still use a vector since it's expected to be small.
+            if (!llvm::is_contained(loads, loadOpInst))
+              loads.push_back(loadOpInst);
+          }
+          // Clear visited memrefs after fusion so that previously visited src
+          // nodes are considered for fusion again in the context of the new
+          // fused node.
+          // TODO: This shouldn't be necessary if we visited candidates in the
+          // dependence graph in post-order or once we fully support multi-store
+          // producers. Currently, in a multi-store producer scenario such as
+          // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing
+          // edges. However, after fusing B+C, A has a single outgoing edge and
+          // can be fused if we revisit it in the context of the new fused B+C
+          // node.
+          visitedMemrefs.clear();
+
+          // Clear and add back loads and stores.
+          mdg->clearNodeLoadAndStores(dstNode->id);
+          mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
+                         dstLoopCollector.storeOpInsts);
+          // Remove old src loop nest if it no longer has outgoing dependence
+          // edges, and if it does not write to a memref which escapes the
+          // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been
+          // fused into 'dstNode' and write region of 'dstNode' covers the write
+          // region of 'srcNode', and 'srcNode' has no other users so it is safe
+          // to remove.
+          if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
+            mdg->removeNode(srcNode->id);
+            srcNode->op->erase();
+          } else {
+            // Add remaining users of 'oldMemRef' back on the worklist (if not
+            // already there), as its replacement with a local/private memref
+            // has reduced dependences on 'oldMemRef' which may have created new
+            // fusion opportunities.
+            if (mdg->outEdges.count(srcNode->id) > 0) {
+              SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+                  mdg->outEdges[srcNode->id];
+              for (auto &outEdge : oldOutEdges) {
+                if (outEdge.value == memref &&
+                    worklistSet.count(outEdge.id) == 0) {
+                  worklist.push_back(outEdge.id);
+                  worklistSet.insert(outEdge.id);
                 }
               }
             }
@@ -1759,6 +1655,8 @@ struct GreedyFusion {
   void fuseWithSiblingNodes(Node *dstNode) {
     DenseSet<unsigned> visitedSibNodeIds;
     std::pair<unsigned, Value> idAndMemref;
+    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
     while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
       unsigned sibId = idAndMemref.first;
       Value memref = idAndMemref.second;
@@ -1791,31 +1689,53 @@ struct GreedyFusion {
       SmallVector<Operation *, 2> dstLoadOpInsts;
       dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
 
-      // Gather 'dstNode' store ops to 'memref'.
-      SmallVector<Operation *, 2> dstStoreOpInsts;
-      dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
-
-      unsigned bestDstLoopDepth;
-      mlir::ComputationSliceState sliceState;
+      SmallVector<AffineForOp, 4> dstLoopIVs;
+      getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
+      unsigned dstLoopDepthTest = dstLoopIVs.size();
+      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
+
+      // Compute loop depth and slice union for fusion.
+      SmallVector<ComputationSliceState, 8> depthSliceUnions;
+      depthSliceUnions.resize(dstLoopDepthTest);
+      unsigned maxLegalFusionDepth = 0;
+      FusionStrategy strategy(FusionStrategy::Sibling, memref);
+      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+        FusionResult result = mlir::canFuseLoops(
+            sibAffineForOp, dstAffineForOp,
+            /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+
+        if (result.value == FusionResult::Success)
+          maxLegalFusionDepth = i;
+      }
 
-      // Check if fusion would be profitable.
-      if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
-                              dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
-                              maximalFusion, computeToleranceThreshold))
+      // Skip if fusion is not feasible at any loop depth.
+      if (maxLegalFusionDepth == 0)
         continue;
 
+      unsigned bestDstLoopDepth = dstLoopDepthTest;
+      if (!maximalFusion) {
+        // Check if fusion would be profitable.
+        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
+                                depthSliceUnions, maxLegalFusionDepth,
+                                &bestDstLoopDepth, computeToleranceThreshold))
+          continue;
+      }
+
+      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
+      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+             "Fusion depth has no computed slice union");
+
       // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
-      auto sliceLoopNest = mlir::insertBackwardComputationSlice(
-          sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
-      if (sliceLoopNest != nullptr) {
-        auto dstForInst = cast<AffineForOp>(dstNode->op);
-        // Update operation position of fused loop nest (if needed).
-        if (insertPointInst != dstForInst.getOperation()) {
-          dstForInst.getOperation()->moveBefore(insertPointInst);
-        }
-        // Update data dependence graph state post fusion.
-        updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
+      mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
+                      depthSliceUnions[bestDstLoopDepth - 1]);
+
+      auto dstForInst = cast<AffineForOp>(dstNode->op);
+      // Update operation position of fused loop nest (if needed).
+      if (insertPointInst != dstForInst.getOperation()) {
+        dstForInst.getOperation()->moveBefore(insertPointInst);
       }
+      // Update data dependence graph state post fusion.
+      updateStateAfterSiblingFusion(sibNode, dstNode);
     }
   }
 
@@ -1943,19 +1863,12 @@ struct GreedyFusion {
     return false;
   }
 
-  void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
-                                     Node *dstNode) {
+  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
+  /// into 'dstNode'.
+  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
     // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
     mdg->updateEdges(sibNode->id, dstNode->id);
 
-    // Collect slice loop stats.
-    LoopNestStateCollector sliceCollector;
-    sliceCollector.collect(sliceLoopNest.getOperation());
-    // Promote single iteration slice loops to single IV value.
-    for (auto forOp : sliceCollector.forOps) {
-      promoteIfSingleIteration(forOp);
-    }
-
     // Collect dst loop stats after memref privatization transformation.
     auto dstForInst = cast<AffineForOp>(dstNode->op);
     LoopNestStateCollector dstLoopCollector;

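In the reworked sibling fusion path above, legality is checked once per candidate depth with `canFuseLoops`, each per-depth slice union is kept, the deepest legal depth is recorded, and `isFusionProfitable` is consulted only when maximal fusion is off. The sketch below models that control flow in plain C++. `ToySlice`, `DepthChoice`, `chooseFusionDepth`, and the two callback types are hypothetical stand-ins for `ComputationSliceState`, `canFuseLoops`, and `isFusionProfitable`, not MLIR APIs. The sketch also assumes that maximal fusion simply picks the deepest legal depth.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for mlir::ComputationSliceState.
struct ToySlice {
  bool computed = false;
};

struct DepthChoice {
  unsigned bestDepth = 0;        // 0 means "do not fuse"
  std::vector<ToySlice> slices;  // slices[d - 1] holds the slice for depth d
};

// Hypothetical stand-in for canFuseLoops: fills 'slice', returns legality.
using LegalityFn = std::function<bool(unsigned depth, ToySlice *slice)>;
// Hypothetical stand-in for isFusionProfitable: may lower '*bestDepth'.
using ProfitabilityFn =
    std::function<bool(const std::vector<ToySlice> &slices,
                       unsigned maxLegalDepth, unsigned *bestDepth)>;

DepthChoice chooseFusionDepth(unsigned dstLoopDepth, bool maximalFusion,
                              const LegalityFn &canFuseAt,
                              const ProfitabilityFn &isProfitable) {
  DepthChoice choice;
  choice.slices.resize(dstLoopDepth);
  // Legality first: one slice per depth, computed exactly once.
  unsigned maxLegalDepth = 0;
  for (unsigned d = 1; d <= dstLoopDepth; ++d)
    if (canFuseAt(d, &choice.slices[d - 1]))
      maxLegalDepth = d;
  if (maxLegalDepth == 0)
    return choice;  // not legal at any depth
  // Profitability only runs on already-legal depths, and not at all for
  // maximal fusion.
  unsigned best = maxLegalDepth;
  if (!maximalFusion && !isProfitable(choice.slices, maxLegalDepth, &best))
    return choice;
  choice.bestDepth = best;
  return choice;
}
```

Because the slices stay in `DepthChoice::slices`, the transformation step can reuse `slices[bestDepth - 1]` instead of recomputing it, in the same spirit as passing `depthSliceUnions[bestDstLoopDepth - 1]` to `fuseLoops`.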
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 4ac710c02dea..7ce88ef5796a 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -47,9 +47,9 @@ static void getLoadAndStoreMemRefAccesses(Operation *opA,
   });
 }
 
-// Returns true if 'op' is a load or store operation which access an memref
-// accessed 'values' and at least one of the access is a store operation.
-// Returns false otherwise.
+/// Returns true if 'op' is a load or store operation that accesses a memref
+/// in 'values' and at least one of the accesses is a store operation.
+/// Returns false otherwise.
 static bool isDependentLoadOrStoreOp(Operation *op,
                                      DenseMap<Value, bool> &values) {
   if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
@@ -187,26 +187,99 @@ gatherLoadsAndStores(AffineForOp forOp,
   return !hasIfOp;
 }
 
+/// Returns the maximum loop depth at which we could fuse producer loop
+/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
+// TODO: Generalize this check for sibling and more generic fusion scenarios.
+// TODO: Support forward slice fusion.
+static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps,
+                                FusionStrategy fusionStrategy) {
+  assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer &&
+         "Fusion strategy not supported");
+
+  if (dstOps.empty())
+    // Expected at least one memory operation.
+    // TODO: Revisit this case with a specific example.
+    return 0;
+
+  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
+  // that they are not considered for analysis.
+  // TODO: Currently, we pass the producer-consumer memref through
+  // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we
+  // generalize the algorithm.
+  SmallVector<Operation *, 4> targetDstOps;
+  for (Operation *dstOp : dstOps) {
+    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
+    Value memref = loadOp ? loadOp.getMemRef()
+                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
+    if (memref == fusionStrategy.memref)
+      targetDstOps.push_back(dstOp);
+  }
+
+  assert(!targetDstOps.empty() &&
+         "No dependences between 'srcForOp' and 'dstForOp'?");
+
+  // Compute the innermost common loop depth for loads and stores.
+  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
+
+  // Return common loop depth for loads if there are no store ops.
+  if (all_of(targetDstOps,
+             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
+    return loopDepth;
+
+  // Check dependences on all pairs of ops in 'targetDstOps' and store the
+  // minimum loop depth at which a dependence is satisfied.
+  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
+    auto *srcOpInst = targetDstOps[i];
+    MemRefAccess srcAccess(srcOpInst);
+    for (unsigned j = 0; j < e; ++j) {
+      auto *dstOpInst = targetDstOps[j];
+      MemRefAccess dstAccess(dstOpInst);
+
+      unsigned numCommonLoops =
+          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
+      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
+        FlatAffineConstraints dependenceConstraints;
+        // TODO: Cache dependence analysis results, check cache here.
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints,
+            /*dependenceComponents=*/nullptr);
+        if (hasDependence(result)) {
+          // Store minimum loop depth and break because we want the min 'd' at
+          // which there is a dependence.
+          loopDepth = std::min(loopDepth, d - 1);
+          break;
+        }
+      }
+    }
+  }
+
+  return loopDepth;
+}
+
 // TODO: Prevent fusion of loop nests with side-effecting operations.
+// TODO: This utility performs some computation that is the same for all the depths
+// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
+// all the depths at once or only the legal maximal depth for maximal fusion.
 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                                 unsigned dstLoopDepth,
-                                ComputationSliceState *srcSlice) {
+                                ComputationSliceState *srcSlice,
+                                FusionStrategy fusionStrategy) {
   // Return 'failure' if 'dstLoopDepth == 0'.
   if (dstLoopDepth == 0) {
-    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
+    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
     return FusionResult::FailPrecondition;
   }
   // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
   auto *block = srcForOp.getOperation()->getBlock();
   if (block != dstForOp.getOperation()->getBlock()) {
-    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
+    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
     return FusionResult::FailPrecondition;
   }
 
   // Return 'failure' if no valid insertion point for fused loop nest in 'block'
   // exists which would preserve dependences.
   if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
-    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
+    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
     return FusionResult::FailBlockDependence;
   }
 
@@ -220,25 +293,68 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
   SmallVector<Operation *, 4> opsA;
   if (!gatherLoadsAndStores(forOpA, opsA)) {
-    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
     return FusionResult::FailPrecondition;
   }
 
   // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
   SmallVector<Operation *, 4> opsB;
   if (!gatherLoadsAndStores(forOpB, opsB)) {
-    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
     return FusionResult::FailPrecondition;
   }
 
+  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
+  // loop dependences.
+  // TODO: Enable this check for sibling and more generic loop fusion
+  // strategies.
+  if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) {
+    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
+    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
+    if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) {
+      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
+      return FusionResult::FailFusionDependence;
+    }
+  }
+
   // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
   unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
       *srcForOp.getOperation(), *dstForOp.getOperation());
 
+  // Filter out ops in 'opsA' to compute the slice union based on the
+  // assumptions made by the fusion strategy.
+  SmallVector<Operation *, 4> strategyOpsA;
+  switch (fusionStrategy.strategy) {
+  case FusionStrategy::Generic:
+    // Generic fusion. Take into account all the memory operations to compute
+    // the slice union.
+    strategyOpsA.append(opsA.begin(), opsA.end());
+    break;
+  case FusionStrategy::ProducerConsumer:
+    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
+    // account stores to 'memref' in 'srcForOp' to compute the slice union.
+    for (Operation *op : opsA) {
+      auto store = dyn_cast<AffineWriteOpInterface>(op);
+      if (store && store.getMemRef() == fusionStrategy.memref)
+        strategyOpsA.push_back(op);
+    }
+    break;
+  case FusionStrategy::Sibling:
+    // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
+    // to 'memref' in 'srcForOp' to compute the slice union.
+    for (Operation *op : opsA) {
+      auto load = dyn_cast<AffineReadOpInterface>(op);
+      if (load && load.getMemRef() == fusionStrategy.memref)
+        strategyOpsA.push_back(op);
+    }
+    break;
+  }
+
   // Compute union of computation slices computed between all pairs of ops
   // from 'forOpA' and 'forOpB'.
-  if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops,
-                                     isSrcForOpBeforeDstForOp, srcSlice))) {
+  if (failed(mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth,
+                                     numCommonLoops, isSrcForOpBeforeDstForOp,
+                                     srcSlice))) {
     LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
     return FusionResult::FailPrecondition;
   }
@@ -249,24 +365,30 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
 /// and source slice loop bounds specified in 'srcSlice'.
 void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
-                     ComputationSliceState *srcSlice) {
+                     const ComputationSliceState &srcSlice) {
   // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
-  OpBuilder b(srcSlice->insertPoint->getBlock(), srcSlice->insertPoint);
+  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
   BlockAndValueMapping mapper;
   b.clone(*srcForOp, mapper);
 
   // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
   SmallVector<AffineForOp, 4> sliceLoops;
-  for (unsigned i = 0, e = srcSlice->ivs.size(); i < e; ++i) {
-    auto loopIV = mapper.lookupOrNull(srcSlice->ivs[i]);
+  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
+    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
     if (!loopIV)
       continue;
     auto forOp = getForInductionVarOwner(loopIV);
     sliceLoops.push_back(forOp);
-    if (AffineMap lbMap = srcSlice->lbs[i])
-      forOp.setLowerBound(srcSlice->lbOperands[i], lbMap);
-    if (AffineMap ubMap = srcSlice->ubs[i])
-      forOp.setUpperBound(srcSlice->ubOperands[i], ubMap);
+    if (AffineMap lbMap = srcSlice.lbs[i]) {
+      auto lbOperands = srcSlice.lbOperands[i];
+      canonicalizeMapAndOperands(&lbMap, &lbOperands);
+      forOp.setLowerBound(lbOperands, lbMap);
+    }
+    if (AffineMap ubMap = srcSlice.ubs[i]) {
+      auto ubOperands = srcSlice.ubOperands[i];
+      canonicalizeMapAndOperands(&ubMap, &ubOperands);
+      forOp.setUpperBound(ubOperands, ubMap);
+    }
   }
 
   // Promote any single iteration slice loops.
@@ -393,15 +515,15 @@ static uint64_t getSliceIterationCount(
 // was encountered).
 // TODO: Make this work with non-unit step loops.
 static bool buildSliceTripCountMap(
-    ComputationSliceState *slice,
+    const ComputationSliceState &slice,
     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
-  unsigned numSrcLoopIVs = slice->ivs.size();
+  unsigned numSrcLoopIVs = slice.ivs.size();
   // Populate map from AffineForOp -> trip count
   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
-    AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]);
+    AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
     auto *op = forOp.getOperation();
-    AffineMap lbMap = slice->lbs[i];
-    AffineMap ubMap = slice->ubs[i];
+    AffineMap lbMap = slice.lbs[i];
+    AffineMap ubMap = slice.ubs[i];
     if (lbMap == AffineMap() || ubMap == AffineMap()) {
       // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
       if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
@@ -442,7 +564,7 @@ int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
 /// the entire loop nest.
 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                                 AffineForOp dstForOp, LoopNestStats &dstStats,
-                                ComputationSliceState *slice,
+                                const ComputationSliceState &slice,
                                 int64_t *computeCost) {
   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
   DenseMap<Operation *, int64_t> computeCostMap;
@@ -454,7 +576,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
   int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
   assert(sliceIterationCount > 0);
   bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
-  auto *insertPointParent = slice->insertPoint->getParentOp();
+  auto *insertPointParent = slice.insertPoint->getParentOp();
 
   // The store and loads to this memref will disappear.
   // TODO: Add load coalescing to memref data flow opt pass.

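The new `getMaxLoopDepth` clamps the fusion depth to one less than the outermost loop that carries a dependence among the consumer's accesses to the fused memref. A toy version of its pairwise loop is sketched below under a strong simplifying assumption: every access indexes the memref with the loop IVs plus constant offsets, so each dependence has a constant distance vector, and loop bounds are ignored. `ToyAccess`, `dependenceDepth`, and `toyMaxLoopDepth` are hypothetical; the real code relies on `checkMemrefAccessDependence` and `FlatAffineConstraints` rather than distance vectors.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical model: every access sits in the same perfect loop nest and
// indexes the memref as A[iv_1 + offset[0]][iv_2 + offset[1]]..., so each
// dependence has a constant (uniform) distance vector.
struct ToyAccess {
  bool isStore;
  std::vector<int> offset;  // one constant offset per surrounding loop
};

// Returns the 1-based depth of the loop carrying a dependence from 'src' to
// 'dst', numLoops + 1 for a same-iteration (loop-independent) dependence,
// or 0 if there is no dependence in the src -> dst direction.
static unsigned dependenceDepth(const ToyAccess &src, const ToyAccess &dst) {
  if (!src.isStore && !dst.isStore)
    return 0;  // load-load pairs impose no ordering
  const unsigned numLoops = static_cast<unsigned>(src.offset.size());
  for (unsigned k = 0; k < numLoops; ++k) {
    // 'dst' at iteration i' touches the element 'src' touched at iteration i
    // when i' - i == src.offset - dst.offset.
    const int distance = src.offset[k] - dst.offset[k];
    if (distance > 0)
      return k + 1;  // carried by loop k + 1
    if (distance < 0)
      return 0;  // lexicographically negative: no src -> dst dependence
  }
  return numLoops + 1;
}

// Same shape as the pairwise loop in getMaxLoopDepth: start from the
// innermost common depth and clamp it to (carrying depth - 1) for every pair.
static unsigned toyMaxLoopDepth(const std::vector<ToyAccess> &ops,
                                unsigned innermostCommonDepth) {
  unsigned loopDepth = innermostCommonDepth;
  const bool hasStore =
      std::any_of(ops.begin(), ops.end(),
                  [](const ToyAccess &op) { return op.isStore; });
  if (!hasStore)
    return loopDepth;  // loads only: no dependences to respect
  for (const ToyAccess &src : ops)
    for (const ToyAccess &dst : ops)
      if (unsigned d = dependenceDepth(src, dst))
        loopDepth = std::min(loopDepth, d - 1);
  return loopDepth;
}
```

For example, a consumer that loads `A[i][j]` and stores `A[i][j+1]` has a dependence carried by the inner loop, so in this model fusion is limited to depth 1.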
diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp
index 6cf1ae0895bd..034abfe603c6 100644
--- a/mlir/test/lib/Transforms/TestLoopFusion.cpp
+++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp
@@ -129,7 +129,7 @@ static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
     mlir::ComputationSliceState sliceUnion;
     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
     if (result.value == FusionResult::Success) {
-      mlir::fuseLoops(forOpA, forOpB, &sliceUnion);
+      mlir::fuseLoops(forOpA, forOpB, sliceUnion);
       // Note: 'forOpA' is removed to simplify test output. A proper loop
       // fusion pass should check the data dependence graph and run memref
       // region analysis to ensure removing 'forOpA' is safe.

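In `canFuseLoops`, the source-side memory ops passed to `computeSliceUnion` are now narrowed according to the `FusionStrategy`: all ops for `Generic`, only stores to the fused memref for `ProducerConsumer`, and only loads from it for `Sibling`. That filtering step can be sketched as follows. `ToyStrategy`, `ToyOp`, and `filterSrcOps` are hypothetical names; the real code works on `Operation *` and distinguishes accesses with `AffineReadOpInterface`/`AffineWriteOpInterface` rather than a boolean flag and a memref name.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mirror of the three FusionStrategy kinds.
enum class ToyStrategy { Generic, ProducerConsumer, Sibling };

struct ToyOp {
  bool isStore;        // true for an affine store, false for an affine load
  std::string memref;  // name of the accessed memref
};

// Keeps only the source ops that the given strategy uses for the slice union.
std::vector<ToyOp> filterSrcOps(const std::vector<ToyOp> &opsA,
                                ToyStrategy strategy,
                                const std::string &memref) {
  std::vector<ToyOp> result;
  switch (strategy) {
  case ToyStrategy::Generic:
    result = opsA;  // every memory op participates
    break;
  case ToyStrategy::ProducerConsumer:
    for (const ToyOp &op : opsA)
      if (op.isStore && op.memref == memref)
        result.push_back(op);  // producer stores to the fused memref
    break;
  case ToyStrategy::Sibling:
    for (const ToyOp &op : opsA)
      if (!op.isStore && op.memref == memref)
        result.push_back(op);  // sibling loads of the shared memref
    break;
  }
  return result;
}
```

Keeping the filter inside the utility, keyed on a strategy value, is what lets one `canFuseLoops` implementation serve both the pass and generic callers such as this test, which uses the default strategy.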

        


More information about the Mlir-commits mailing list