[Mlir-commits] [mlir] c8fc5c0 - [mlir][Affine] Add support for multi-store producer fusion

Diego Caballero llvmlistbot at llvm.org
Mon Jan 25 10:32:04 PST 2021


Author: Diego Caballero
Date: 2021-01-25T20:31:17+02:00
New Revision: c8fc5c0385dbb47623c1cca5efa0b96d5e5f8151

URL: https://github.com/llvm/llvm-project/commit/c8fc5c0385dbb47623c1cca5efa0b96d5e5f8151
DIFF: https://github.com/llvm/llvm-project/commit/c8fc5c0385dbb47623c1cca5efa0b96d5e5f8151.diff

LOG: [mlir][Affine] Add support for multi-store producer fusion

This patch adds support for producer-consumer fusion scenarios with
multiple producer stores to the AffineLoopFusion pass. The patch
introduces some changes to the producer-consumer algorithm, including:

* For a given consumer loop, producer-consumer fusion iterates over its
producer candidates until a fixed point is reached.

* Producer candidates are gathered beforehand for each iteration of the
consumer loop and visited in reverse program order (though this order is not
strictly guaranteed) to maximize the number of loops fused per iteration.

In general, these changes were needed to simplify the multi-store producer
support and remove some of the workarounds that were introduced in the past
to support more fusion cases under the single-store producer limitation.
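
In code, the new driver logic amounts to the following fixed-point loop (a
condensed sketch of the LoopFusion.cpp changes below; legality and
profitability checks are elided):

    bool dstNodeChanged;
    do {
      dstNodeChanged = false;
      // Gather producer candidates for 'dstId' and visit them in (quasi)
      // reverse program order.
      SmallVector<unsigned, 16> srcIdCandidates;
      getProducerCandidates(dstId, mdg, srcIdCandidates);
      for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
        // ... check legality and profitability; on success:
        //   fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
        //   dstNodeChanged = true;
      }
    } while (dstNodeChanged);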

This patch also preserves the existing functionality of AffineLoopFusion with
one minor change in behavior: producer-consumer fusion previously didn't fuse
scenarios with escaping memrefs and multiple outgoing edges (from a single
store). Multi-store producer scenarios will usually (if not always) have
multiple outgoing edges, so we couldn't fuse any of them with escaping
memrefs, which would greatly limit the applicability of this new feature.
Therefore, the patch enables fusion for these scenarios. Please see the
modified tests for specific details.
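
For illustration only, here is a hypothetical multi-store producer scenario of
the kind this enables (names invented; see also Example 1 in the updated pass
documentation below): %out escapes as a function argument, so the producer
loop has multiple outgoing edges and was previously rejected.

    func @multi_store_producer(%out: memref<10xf32>, %res: memref<10xf32>) {
      %tmp = alloc() : memref<10xf32>
      %cst = constant 0.000000e+00 : f32
      affine.for %i = 0 to 10 {
        affine.store %cst, %tmp[%i] : memref<10xf32>  // feeds the consumer
        affine.store %cst, %out[%i] : memref<10xf32>  // escaping memref
      }
      affine.for %i = 0 to 10 {
        %v = affine.load %tmp[%i] : memref<10xf32>
        %w = addf %v, %v : f32
        affine.store %w, %res[%i] : memref<10xf32>
      }
      return
    }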

Reviewed By: andydavis1, bondhugula

Differential Revision: https://reviews.llvm.org/D92876

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/include/mlir/Analysis/Utils.h
    mlir/include/mlir/Transforms/LoopFusionUtils.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
    mlir/test/Transforms/loop-fusion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index fa42db500684..6aa0a38243b9 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -232,6 +232,21 @@ class FlatAffineConstraints {
   //  TODO: add support for non-unit strides.
   LogicalResult addAffineForOpDomain(AffineForOp forOp);
 
+  /// Adds constraints (lower and upper bounds) for each loop in the loop nest
+  /// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
+  /// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
+  /// the nest, sorted outer-to-inner. 'operands' contains the bound operands
+  /// for a single bound map. All the bound maps will use the same bound
+  /// operands. Note that some loops described by a computation slice might not
+  /// exist yet in the IR so the Value attached to those dimension identifiers
+  /// might be empty. For that reason, this method doesn't perform Value
+  /// look-ups to retrieve the dimension identifier positions. Instead, it
+  /// assumes the position of the dim identifiers in the constraint system is
+  /// the same as the position of the loop in the loop nest.
+  LogicalResult addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
+                                       ArrayRef<AffineMap> ubMaps,
+                                       ArrayRef<Value> operands);
+
   /// Adds constraints imposed by the `affine.if` operation. These constraints
   /// are collected from the IntegerSet attached to the given `affine.if`
   /// instance argument (`ifOp`). It is asserted that:

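A minimal sketch of how the new method is meant to be driven, mirroring its
use in ComputationSliceState::isMaximal() further below (the slice state and
'consumerIVs' setup are assumed):

    // Build a constraint system whose dim identifier positions follow the
    // loop nest order of the slice, then add the slice bounds to it.
    FlatAffineConstraints sliceCst;
    sliceCst.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0,
                   /*numLocals=*/0, consumerIVs);
    if (failed(sliceCst.addDomainFromSliceMaps(slice.lbs, slice.ubs,
                                               slice.lbOperands[0])))
      return llvm::None; // Bounds couldn't be added; result is inconclusive.
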
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index 30b6272181f5..ee6f8095f25e 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -83,10 +83,25 @@ struct ComputationSliceState {
   // Clears all bounds and operands in slice state.
   void clearBounds();
 
-  /// Return true if the computation slice is empty.
+  /// Returns true if the computation slice is empty.
   bool isEmpty() const { return ivs.empty(); }
 
+  /// Returns true if the computation slice encloses all the iterations of the
+  /// sliced loop nest. Returns false if it does not. Returns llvm::None if it
+  /// cannot determine if the slice is maximal or not.
+  // TODO: Cache 'isMaximal' so that we don't recompute it when the slice
+  // information hasn't changed.
+  Optional<bool> isMaximal() const;
+
   void dump() const;
+
+private:
+  /// Fast check to determine if the computation slice is maximal. Returns true
+  /// if each slice dimension maps to an existing dst dimension and both the src
+  /// and the dst loops for those dimensions have the same bounds. Returns false
+  /// if the src and the dst loops don't have the same bounds. Returns
+  /// llvm::None if none of the above can be proven.
+  Optional<bool> isSliceMaximalFastCheck() const;
 };
 
 /// Computes the computation slice loop bounds for one loop nest as affine maps

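A short sketch of how the new query is consumed, mirroring
canRemoveSrcNodeAfterFusion() in the LoopFusion.cpp changes below:

    Optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal.hasValue())
      return false; // Inconclusive: conservatively keep the src loop nest.
    if (!isMaximal.getValue())
      return false; // Slice misses some src iterations: keep the src loop.
    // The slice encloses all iterations of the sliced loop nest.
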
diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index eade565e0325..10d6b83d022f 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -50,7 +50,8 @@ struct FusionResult {
 // TODO: Generalize utilities so that producer-consumer and sibling fusion
 // strategies can be used without the assumptions made in the AffineLoopFusion
 // pass.
-struct FusionStrategy {
+class FusionStrategy {
+public:
   enum StrategyEnum {
     // Generic loop fusion: Arbitrary loops are considered for fusion. No
     // assumptions about a specific fusion strategy from AffineLoopFusion pass
@@ -69,13 +70,34 @@ struct FusionStrategy {
     // implementation in AffineLoopFusion pass are made. See pass for specific
     // details.
     Sibling
-  } strategy;
+  };
 
-  // Target memref for this fusion transformation.
-  Value memref;
+  /// Construct a generic or producer-consumer fusion strategy.
+  FusionStrategy(StrategyEnum strategy) : strategy(strategy) {
+    assert(strategy != Sibling &&
+           "Sibling fusion strategy requires a specific memref");
+  }
+
+  /// Construct a sibling fusion strategy targeting 'memref'. This construct
+  /// should only be used for sibling fusion.
+  FusionStrategy(Value memref) : strategy(Sibling), memref(memref) {}
+
+  /// Returns the fusion strategy.
+  StrategyEnum getStrategy() const { return strategy; };
 
-  FusionStrategy(StrategyEnum strategy, Value memref)
-      : strategy(strategy), memref(memref) {}
+  /// Returns the memref attached to this sibling fusion strategy.
+  Value getSiblingFusionMemRef() const {
+    assert(strategy == Sibling && "Memref is only valid for sibling fusion");
+    return memref;
+  }
+
+private:
+  /// Fusion strategy.
+  StrategyEnum strategy;
+
+  /// Target memref for this fusion transformation. Only used for sibling
+  /// fusion.
+  Value memref;
 };
 
 /// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
@@ -86,11 +108,10 @@ struct FusionStrategy {
 /// NOTE: This function is not feature complete and should only be used in
 /// testing.
 /// TODO: Update comments when this function is fully implemented.
-FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
-                          unsigned dstLoopDepth,
-                          ComputationSliceState *srcSlice,
-                          FusionStrategy fusionStrategy = {
-                              FusionStrategy::Generic, Value()});
+FusionResult
+canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
+             ComputationSliceState *srcSlice,
+             FusionStrategy fusionStrategy = FusionStrategy::Generic);
 
 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
 /// and source slice loop bounds specified in 'srcSlice'.
@@ -134,6 +155,12 @@ bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                           const ComputationSliceState &slice,
                           int64_t *computeCost);
 
+/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
+/// producer-consumer dependence between write ops in 'srcOps' and read ops in
+/// 'dstOps'.
+void gatherProducerConsumerMemrefs(ArrayRef<Operation *> srcOps,
+                                   ArrayRef<Operation *> dstOps,
+                                   DenseSet<Value> &producerConsumerMemrefs);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H

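A minimal sketch of the two ways to construct a strategy after this change
(loop handles and slice state are assumed):

    // Producer-consumer (or generic) fusion: no memref is needed anymore.
    FusionStrategy pcStrategy(FusionStrategy::ProducerConsumer);
    FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp,
                                             /*dstLoopDepth=*/1, &sliceUnion,
                                             pcStrategy);

    // Sibling fusion: the target memref is now bound at construction time.
    FusionStrategy sibStrategy(memref); // Implies FusionStrategy::Sibling.
    Value target = sibStrategy.getSiblingFusionMemRef();
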
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 438a468673b5..a03b439af339 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -17,6 +17,111 @@ include "mlir/Pass/PassBase.td"
 
 def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
   let summary = "Fuse affine loop nests";
+  let description = [{
+    This pass performs fusion of loop nests using a slicing-based approach. It
+    combines two fusion strategies: producer-consumer fusion and sibling fusion.
+    Producer-consumer fusion is aimed at fusing pairs of loops where the first
+    one writes to a memref that the second reads. Sibling fusion targets pairs
+    of loops that share no dependences between them but that load from the same
+    memref. The fused loop nests, when possible, are rewritten to access
+    significantly smaller local buffers instead of the original memrefs, and
+    the latter are often either completely optimized away or contracted. This
+    transformation leads to enhanced locality and a lower memory footprint
+    through the elimination or contraction of temporary/intermediate memrefs.
+    These benefits sometimes come at the expense of redundant computation, as
+    determined by a cost model that evaluates the available choices, such as
+    the depth at which a source slice is materialized in the destination slice.
+
+    Example 1: Producer-consumer fusion.
+    Input:
+    ```mlir
+    func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
+      %0 = alloc() : memref<10xf32>
+      %1 = alloc() : memref<10xf32>
+      %cst = constant 0.000000e+00 : f32
+      affine.for %arg2 = 0 to 10 {
+        affine.store %cst, %0[%arg2] : memref<10xf32>
+        affine.store %cst, %1[%arg2] : memref<10xf32>
+      }
+      affine.for %arg2 = 0 to 10 {
+        %2 = affine.load %0[%arg2] : memref<10xf32>
+        %3 = addf %2, %2 : f32
+        affine.store %3, %arg0[%arg2] : memref<10xf32>
+      }
+      affine.for %arg2 = 0 to 10 {
+        %2 = affine.load %1[%arg2] : memref<10xf32>
+        %3 = mulf %2, %2 : f32
+        affine.store %3, %arg1[%arg2] : memref<10xf32>
+      }
+      return
+    }
+    ```
+    Output:
+    ```mlir
+    func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
+      %0 = alloc() : memref<1xf32>
+      %1 = alloc() : memref<1xf32>
+      %cst = constant 0.000000e+00 : f32
+      affine.for %arg2 = 0 to 10 {
+        affine.store %cst, %0[0] : memref<1xf32>
+        affine.store %cst, %1[0] : memref<1xf32>
+        %2 = affine.load %1[0] : memref<1xf32>
+        %3 = mulf %2, %2 : f32
+        affine.store %3, %arg1[%arg2] : memref<10xf32>
+        %4 = affine.load %0[0] : memref<1xf32>
+        %5 = addf %4, %4 : f32
+        affine.store %5, %arg0[%arg2] : memref<10xf32>
+      }
+      return
+    }
+    ```
+
+    Example 2: Sibling fusion.
+    Input:
+    ```mlir
+    func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
+                         %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
+                         %arg4: memref<10x10xf32>) {
+      affine.for %arg5 = 0 to 3 {
+        affine.for %arg6 = 0 to 3 {
+          %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
+          %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
+          %2 = mulf %0, %1 : f32
+          affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
+        }
+      }
+      affine.for %arg5 = 0 to 3 {
+        affine.for %arg6 = 0 to 3 {
+          %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
+          %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
+          %2 = addf %0, %1 : f32
+          affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32>
+        }
+      }
+      return
+    }
+    ```
+    Output:
+    ```mlir
+    func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
+                         %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
+                         %arg4: memref<10x10xf32>) {
+      affine.for %arg5 = 0 to 3 {
+        affine.for %arg6 = 0 to 3 {
+          %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
+          %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
+          %2 = mulf %0, %1 : f32
+          affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
+          %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
+          %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
+          %5 = addf %3, %4 : f32
+          affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32>
+        }
+      }
+      return
+    }
+    ```
+  }];
   let constructor = "mlir::createLoopFusionPass()";
   let options = [
     Option<"computeToleranceThreshold", "fusion-compute-tolerance", "double",

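The examples above can be reproduced by running the pass on its own; the
invocations below are illustrative (the tolerance value is arbitrary):

    mlir-opt input.mlir -affine-loop-fusion
    mlir-opt input.mlir -affine-loop-fusion="fusion-compute-tolerance=100.0"
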
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index bf4bcab4657f..e81d92ae4a00 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -709,6 +709,70 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
                               /*eq=*/false, /*lower=*/false);
 }
 
+/// Adds constraints (lower and upper bounds) for each loop in the loop nest
+/// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
+/// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
+/// the nest, sorted outer-to-inner. 'operands' contains the bound operands
+/// for a single bound map. All the bound maps will use the same bound
+/// operands. Note that some loops described by a computation slice might not
+/// exist yet in the IR so the Value attached to those dimension identifiers
+/// might be empty. For that reason, this method doesn't perform Value
+/// look-ups to retrieve the dimension identifier positions. Instead, it
+/// assumes the position of the dim identifiers in the constraint system is
+/// the same as the position of the loop in the loop nest.
+LogicalResult
+FlatAffineConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
+                                              ArrayRef<AffineMap> ubMaps,
+                                              ArrayRef<Value> operands) {
+  assert(lbMaps.size() == ubMaps.size());
+  assert(lbMaps.size() <= getNumDimIds());
+
+  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
+    AffineMap lbMap = lbMaps[i];
+    AffineMap ubMap = ubMaps[i];
+    assert(!lbMap || lbMap.getNumInputs() == operands.size());
+    assert(!ubMap || ubMap.getNumInputs() == operands.size());
+
+    // Check if this slice is just an equality along this dimension. If so,
+    // retrieve the existing loop it equates to and add it to the system.
+    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
+        ubMap.getNumResults() == 1 &&
+        lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
+        // The condition above will be true for maps describing a single
+        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
+        // Make sure we skip those cases by checking that the lb result is not
+        // just a constant.
+        !lbMap.getResult(0).isa<AffineConstantExpr>()) {
+      // Limited support: we expect the lb result to be just a loop dimension.
+      // Not supported otherwise for now.
+      AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
+      if (!result)
+        return failure();
+
+      AffineForOp loop =
+          getForInductionVarOwner(operands[result.getPosition()]);
+      if (!loop)
+        return failure();
+
+      if (failed(addAffineForOpDomain(loop)))
+        return failure();
+      continue;
+    }
+
+    // This slice refers to a loop that doesn't exist in the IR yet. Add its
+    // bounds to the system assuming its dimension identifier position is the
+    // same as the position of the loop in the loop nest.
+    if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false,
+                                             /*lower=*/true)))
+      return failure();
+
+    if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false,
+                                             /*lower=*/false)))
+      return failure();
+  }
+  return success();
+}
+
 void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
   // Create the base constraints from the integer set attached to ifOp.
   FlatAffineConstraints cst(ifOp.getIntegerSet());

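To make the equality special case above concrete, a slice bound pair such as

    lbMaps[i] = affine_map<(d0) -> (d0)>
    ubMaps[i] = affine_map<(d0) -> (d0 + 1)>

where d0 binds to an existing loop IV in 'operands', describes a slice
dimension that equates to that loop. Rather than adding these maps as bounds,
the code retrieves the loop that owns the IV and adds its full domain via
addAffineForOpDomain().
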
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index a1e7d1ffe844..383a6587bbef 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -12,8 +12,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Utils.h"
-
 #include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/PresburgerSet.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -127,6 +127,128 @@ void ComputationSliceState::dump() const {
   }
 }
 
+/// Fast check to determine if the computation slice is maximal. Returns true if
+/// each slice dimension maps to an existing dst dimension and both the src
+/// and the dst loops for those dimensions have the same bounds. Returns false
+/// if the src and the dst loops don't have the same bounds. Returns
+/// llvm::None if none of the above can be proven.
+Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
+  assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() &&
+         "Unexpected number of lbs, ubs and ivs in slice");
+
+  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
+    AffineMap lbMap = lbs[i];
+    AffineMap ubMap = ubs[i];
+
+    // Check if this slice is just an equality along this dimension.
+    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
+        ubMap.getNumResults() != 1 ||
+        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
+        // The condition above will be true for maps describing a single
+        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
+        // Make sure we skip those cases by checking that the lb result is not
+        // just a constant.
+        lbMap.getResult(0).isa<AffineConstantExpr>())
+      return llvm::None;
+
+    // Limited support: we expect the lb result to be just a loop dimension for
+    // now.
+    AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
+    if (!result)
+      return llvm::None;
+
+    // Retrieve dst loop bounds.
+    AffineForOp dstLoop =
+        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
+    if (!dstLoop)
+      return llvm::None;
+    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
+    AffineMap dstUbMap = dstLoop.getUpperBoundMap();
+
+    // Retrieve src loop bounds.
+    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
+    assert(srcLoop && "Expected affine for");
+    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
+    AffineMap srcUbMap = srcLoop.getUpperBoundMap();
+
+    // Limited support: we expect simple src and dst loops with a single
+    // constant component per bound for now.
+    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
+        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
+      return llvm::None;
+
+    AffineExpr srcLbResult = srcLbMap.getResult(0);
+    AffineExpr dstLbResult = dstLbMap.getResult(0);
+    AffineExpr srcUbResult = srcUbMap.getResult(0);
+    AffineExpr dstUbResult = dstUbMap.getResult(0);
+    if (!srcLbResult.isa<AffineConstantExpr>() ||
+        !srcUbResult.isa<AffineConstantExpr>() ||
+        !dstLbResult.isa<AffineConstantExpr>() ||
+        !dstUbResult.isa<AffineConstantExpr>())
+      return llvm::None;
+
+    // Check if src and dst loop bounds are the same. If not, we can guarantee
+    // that the slice is not maximal.
+    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult)
+      return false;
+  }
+
+  return true;
+}
+
+/// Returns true if the computation slice encloses all the iterations of the
+/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
+/// cannot determine if the slice is maximal or not.
+Optional<bool> ComputationSliceState::isMaximal() const {
+  // Fast check to determine if the computation slice is maximal. If the result
+  // is inconclusive, we proceed with a more expensive analysis.
+  Optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
+  if (isMaximalFastCheck.hasValue())
+    return isMaximalFastCheck;
+
+  // Create constraints for the src loop nest being sliced.
+  FlatAffineConstraints srcConstraints;
+  srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0,
+                       /*numLocals=*/0, ivs);
+  for (Value iv : ivs) {
+    AffineForOp loop = getForInductionVarOwner(iv);
+    assert(loop && "Expected affine for");
+    if (failed(srcConstraints.addAffineForOpDomain(loop)))
+      return llvm::None;
+  }
+
+  // Create constraints for the slice using the dst loop nest information. We
+  // retrieve existing dst loops from the lbOperands.
+  SmallVector<Value, 8> consumerIVs;
+  for (Value lbOp : lbOperands[0])
+    if (getForInductionVarOwner(lbOp))
+      consumerIVs.push_back(lbOp);
+
+  // Add empty IV Values for those new loops that are not equalities and,
+  // therefore, are not yet materialized in the IR.
+  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
+    consumerIVs.push_back(Value());
+
+  FlatAffineConstraints sliceConstraints;
+  sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0,
+                         /*numLocals=*/0, consumerIVs);
+
+  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
+    return llvm::None;
+
+  if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds())
+    // Constraint dims are different. The integer set difference can't be
+    // computed so we don't know if the slice is maximal.
+    return llvm::None;
+
+  // Compute the difference between the src loop nest and the slice integer
+  // sets.
+  PresburgerSet srcSet(srcConstraints);
+  PresburgerSet sliceSet(sliceConstraints);
+  PresburgerSet diffSet = srcSet.subtract(sliceSet);
+  return diffSet.isIntegerEmpty();
+}
+
 unsigned MemRefRegion::getRank() const {
   return memref.getType().cast<MemRefType>().getRank();
 }

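The maximality test above boils down to an integer-set difference; a
condensed sketch of the final step:

    // The slice is maximal iff no integer point of the src domain is
    // missing from the slice domain, i.e., src \ slice is empty.
    PresburgerSet srcSet(srcConstraints);
    PresburgerSet sliceSet(sliceConstraints);
    bool maximal = srcSet.subtract(sliceSet).isIntegerEmpty();

For example, if the src domain is 0 <= i < 10 and the slice domain is also
0 <= i < 10, the difference is empty and the slice is maximal; if the slice
domain were 0 <= i < 5, the difference would contain 5 <= i < 10 and the
slice would not be maximal.
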
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 6fe112b89baf..96889f15c502 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -270,64 +270,6 @@ struct MemRefDependenceGraph {
     return false;
   }
 
-  // Returns the unique AffineWriteOpInterface in `node` that meets all the
-  // following:
-  //   *) store is the only one that writes to a function-local memref live out
-  //      of `node`,
-  //   *) store is not the source of a self-dependence on `node`.
-  // Otherwise, returns a null AffineWriteOpInterface.
-  AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
-    AffineWriteOpInterface uniqueStore;
-
-    // Return null if `node` doesn't have any outgoing edges.
-    auto outEdgeIt = outEdges.find(node->id);
-    if (outEdgeIt == outEdges.end())
-      return nullptr;
-
-    const auto &nodeOutEdges = outEdgeIt->second;
-    for (auto *op : node->stores) {
-      auto storeOp = cast<AffineWriteOpInterface>(op);
-      auto memref = storeOp.getMemRef();
-      // Skip this store if there are no dependences on its memref. This means
-      // that store either:
-      // *) writes to a memref that is only read within the same loop nest
-      //    (self-dependence edges are not represented in graph at the moment),
-      // *) writes to a function live out memref (function parameter), or
-      // *) is dead.
-      if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
-            return (edge.value != memref);
-          }))
-        continue;
-
-      if (uniqueStore)
-        // Found multiple stores to function-local live-out memrefs.
-        return nullptr;
-      // Found first store to function-local live-out memref.
-      uniqueStore = storeOp;
-    }
-
-    return uniqueStore;
-  }
-
-  // Returns true if node 'id' can be removed from the graph. Returns false
-  // otherwise. A node can be removed from the graph iff the following
-  // conditions are met:
-  // *) The node does not write to any memref which escapes (or is a
-  //    function/block argument).
-  // *) The node has no successors in the dependence graph.
-  bool canRemoveNode(unsigned id) {
-    if (writesToLiveInOrEscapingMemrefs(id))
-      return false;
-    Node *node = getNode(id);
-    for (auto *storeOpInst : node->stores) {
-      // Return false if there exist out edges from 'id' on 'memref'.
-      auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
-      if (getOutEdgeCount(id, storeMemref) > 0)
-        return false;
-    }
-    return true;
-  }
-
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
@@ -495,42 +437,49 @@ struct MemRefDependenceGraph {
     return dstNodeInst;
   }
 
-  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
-  // has been replaced in node at 'dstId' by a private memref depending
-  // on the value of 'createPrivateMemRef'.
-  void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef,
-                   bool createPrivateMemRef) {
+  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
+  // taking into account that:
+  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
+  //   *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
+  //      private memref.
+  void updateEdges(unsigned srcId, unsigned dstId,
+                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
     // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
     if (inEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
       for (auto &inEdge : oldInEdges) {
-        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
-        if (inEdge.value != oldMemRef)
+        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
+        if (privateMemRefs.count(inEdge.value) == 0)
           addEdge(inEdge.id, dstId, inEdge.value);
       }
     }
     // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
+    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
     if (outEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
       for (auto &outEdge : oldOutEdges) {
         // Remove any out edges from 'srcId' to 'dstId' across memrefs.
         if (outEdge.id == dstId)
           removeEdge(srcId, outEdge.id, outEdge.value);
+        else if (removeSrcId) {
+          addEdge(dstId, outEdge.id, outEdge.value);
+          removeEdge(srcId, outEdge.id, outEdge.value);
+        }
       }
     }
     // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
     // replaced by a private memref). These edges could come from nodes
     // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
+    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
       SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
       for (auto &inEdge : oldInEdges)
-        if (inEdge.value == oldMemRef)
+        if (privateMemRefs.count(inEdge.value) > 0)
           removeEdge(inEdge.id, dstId, inEdge.value);
     }
   }
 
   // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
-  // of sibling node 'sidId' into node 'dstId'.
+  // of sibling node 'sibId' into node 'dstId'.
   void updateEdges(unsigned sibId, unsigned dstId) {
     // For each edge in 'inEdges[sibId]':
     // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
@@ -624,6 +573,141 @@ struct MemRefDependenceGraph {
   void dump() const { print(llvm::errs()); }
 };
 
+/// Returns true if node 'srcId' can be removed after fusing it with node
+/// 'dstId'. The node can be removed if any of the following conditions are met:
+///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.
+///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
+///       and the fusion slice is maximal.
+///   3. 'srcId' has output dependences after fusion, the fusion slice is
+///      maximal and the fusion insertion point dominates all the dependences.
+static bool canRemoveSrcNodeAfterFusion(
+    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
+    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
+    MemRefDependenceGraph *mdg) {
+
+  Operation *dstNodeOp = mdg->getNode(dstId)->op;
+  bool hasOutDepsAfterFusion = false;
+
+  for (auto &outEdge : mdg->outEdges[srcId]) {
+    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
+    // Skip dependence with dstOp since it will be removed after fusion.
+    if (depNodeOp == dstNodeOp)
+      continue;
+
+    // Only fusion within the same block is supported. Use domination analysis
+    // when needed.
+    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
+      return false;
+
+    // Check if the insertion point of the fused loop dominates the dependence.
+    // Otherwise, the src loop can't be removed.
+    if (fusedLoopInsPoint != depNodeOp &&
+        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
+                                 "dominate dependence\n");
+      return false;
+    }
+
+    hasOutDepsAfterFusion = true;
+  }
+
+  // If src loop has dependences after fusion or it writes to a live-out or
+  // escaping memref, we can only remove it if the fusion slice is maximal so
+  // that all the dependences are preserved.
+  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
+    Optional<bool> isMaximal = fusionSlice.isMaximal();
+    if (!isMaximal.hasValue()) {
+      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
+                                 "if fusion is maximal\n");
+      return false;
+    }
+
+    if (!isMaximal.getValue()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Src loop can't be removed: fusion is not maximal\n");
+      return false;
+    }
+  }
+
+  return true;
+}
+
+/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
+/// 'dstId'. Candidates are sorted in node id order. This order corresponds to
+/// the program order when the 'mdg' is created. However, program order is not
+/// guaranteed and must not be relied upon by the client. Program order won't
+/// hold if the 'mdg' is reused from a previous fusion step or if the node
+/// creation order changes in the future to support more advanced cases.
+// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
+static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
+                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
+  // Skip if no input edges along which to fuse.
+  if (mdg->inEdges.count(dstId) == 0)
+    return;
+
+  // Gather memrefs from loads in 'dstId'.
+  auto *dstNode = mdg->getNode(dstId);
+  DenseSet<Value> consumedMemrefs;
+  for (Operation *load : dstNode->loads)
+    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
+
+  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
+  // to one of the consumed memrefs.
+  for (auto &srcEdge : mdg->inEdges[dstId]) {
+    auto *srcNode = mdg->getNode(srcEdge.id);
+    // Skip if 'srcNode' is not a loop nest.
+    if (!isa<AffineForOp>(srcNode->op))
+      continue;
+
+    if (any_of(srcNode->stores, [&](Operation *op) {
+          auto storeOp = cast<AffineWriteOpInterface>(op);
+          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
+        }))
+      srcIdCandidates.push_back(srcNode->id);
+  }
+
+  std::sort(srcIdCandidates.begin(), srcIdCandidates.end());
+  srcIdCandidates.erase(
+      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
+      srcIdCandidates.end());
+}
+
+/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
+/// producer-consumer dependence between 'srcId' and 'dstId'.
+static void
+gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
+                              MemRefDependenceGraph *mdg,
+                              DenseSet<Value> &producerConsumerMemrefs) {
+  auto *dstNode = mdg->getNode(dstId);
+  auto *srcNode = mdg->getNode(srcId);
+  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
+                                producerConsumerMemrefs);
+}
+
+/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
+/// that escape the function. A memref escapes the function if either:
+///   1. It's a function argument, or
+///   2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
+void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
+                           DenseSet<Value> &escapingMemRefs) {
+  auto *node = mdg->getNode(id);
+  for (auto *storeOpInst : node->stores) {
+    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+    if (escapingMemRefs.count(memref))
+      continue;
+    // Check if 'memref' escapes because it's a block argument.
+    if (memref.isa<BlockArgument>()) {
+      escapingMemRefs.insert(memref);
+      continue;
+    }
+    // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
+    // call op, etc.).
+    for (Operation *user : memref.getUsers())
+      if (!isMemRefDereferencingOp(*user))
+        escapingMemRefs.insert(memref);
+  }
+}
+
 } // end anonymous namespace
 
 // Initializes the data dependence graph by walking operations in 'f'.
@@ -631,6 +715,7 @@ struct MemRefDependenceGraph {
 // TODO: Add support for taking a Block arg to construct the
 // dependence graph at a different depth.
 bool MemRefDependenceGraph::init(FuncOp f) {
+  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
   DenseMap<Value, SetVector<unsigned>> memrefAccesses;
 
   // TODO: support multi-block functions.
@@ -686,6 +771,12 @@ bool MemRefDependenceGraph::init(FuncOp f) {
     }
   }
 
+  for (auto &idAndNode : nodes) {
+    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
+                            << *(idAndNode.second.op) << "\n");
+    (void)idAndNode;
+  }
+
   // Add dependence edges between nodes which produce SSA values and their
   // users.
   for (auto &idAndNode : nodes) {
@@ -725,22 +816,6 @@ bool MemRefDependenceGraph::init(FuncOp f) {
   return true;
 }
 
-// Removes load operations from 'srcLoads' which operate on 'memref', and
-// adds them to 'dstLoads'.
-static void moveLoadsAccessingMemrefTo(Value memref,
-                                       SmallVectorImpl<Operation *> *srcLoads,
-                                       SmallVectorImpl<Operation *> *dstLoads) {
-  dstLoads->clear();
-  SmallVector<Operation *, 4> srcLoadsToKeep;
-  for (auto *load : *srcLoads) {
-    if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
-      dstLoads->push_back(load);
-    else
-      srcLoadsToKeep.push_back(load);
-  }
-  srcLoads->swap(srcLoadsToKeep);
-}
-
 // Sinks all sequential loops to the innermost levels (while preserving
 // relative order among them) and moves all parallel loops to the
 // outermost (while again preserving relative order among them).
@@ -932,75 +1007,6 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
   return false;
 }
 
-// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
-// may write to multiple memrefs but it is required that only one of them,
-// 'srcLiveOutStoreOp', has output edges.
-// Returns true if 'dstNode's read/write region to 'memref' is a super set of
-// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
-// TODO: Generalize this to handle more live in/out cases.
-static bool
-canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
-                               AffineWriteOpInterface srcLiveOutStoreOp,
-                               MemRefDependenceGraph *mdg) {
-  assert(srcLiveOutStoreOp && "Expected a valid store op");
-  auto *dstNode = mdg->getNode(dstId);
-  Value memref = srcLiveOutStoreOp.getMemRef();
-  // Return false if 'srcNode' has more than one output edge on 'memref'.
-  if (mdg->getOutEdgeCount(srcId, memref) > 1)
-    return false;
-
-  // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
-  MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
-  if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Unable to compute MemRefRegion for source operation\n.");
-    return false;
-  }
-  SmallVector<int64_t, 4> srcShape;
-  // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
-  // by 'srcStoreOp' at depth 'dstLoopDepth'.
-  Optional<int64_t> srcNumElements =
-      srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
-  if (!srcNumElements.hasValue())
-    return false;
-
-  // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
-  // TODO: Compute 'unionboundingbox' of all write regions (one for
-  // each store op in 'dstStoreOps').
-  SmallVector<Operation *, 2> dstStoreOps;
-  dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
-  SmallVector<Operation *, 2> dstLoadOps;
-  dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
-
-  auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
-  MemRefRegion dstRegion(dstOpInst->getLoc());
-  if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Unable to compute MemRefRegion for dest operation\n.");
-    return false;
-  }
-  SmallVector<int64_t, 4> dstShape;
-  // Query 'dstRegion' for 'dstShape' and 'dstNumElements'.
-  // by 'dstOpInst' at depth 'dstLoopDepth'.
-  Optional<int64_t> dstNumElements =
-      dstRegion.getConstantBoundingSizeAndShape(&dstShape);
-  if (!dstNumElements.hasValue())
-    return false;
-
-  // Return false if write region is not a superset of 'srcNodes' write
-  // region to 'memref'.
-  // TODO: Check the shape and lower bounds here too.
-  if (srcNumElements != dstNumElements)
-    return false;
-
-  // Return false if 'memref' is used by a non-affine operation that is
-  // between node 'srcId' and node 'dstId'.
-  if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
-    return false;
-
-  return true;
-}
-
 // Checks the profitability of fusing a backwards slice of the loop nest
 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
@@ -1029,9 +1035,6 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
 //    the largest computation slice at the maximal dst loop depth (closest to
 //    the load) to minimize reuse distance and potentially enable subsequent
 //    load/store forwarding.
-//    NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
-//    the same memref as is written by 'srcOpInst', then the union of slice
-//    loop bounds is used to compute the slice and associated slice cost.
 //    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
 //    nest, at which the src computation slice is inserted/fused.
 //    NOTE: We attempt to maximize the dst loop depth, but there are cases
@@ -1041,18 +1044,18 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
 // *) Compares the total cost of the unfused loop nests to the min cost fused
 //    loop nest computed in the previous step, and returns true if the latter
 //    is lower.
+// TODO: Extend profitability analysis to support scenarios with multiple
+// stores.
 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
-                               ArrayRef<Operation *> dstLoadOpInsts,
+                               AffineForOp dstForOp,
                                ArrayRef<ComputationSliceState> depthSliceUnions,
                                unsigned maxLegalFusionDepth,
                                unsigned *dstLoopDepth,
                                double computeToleranceThreshold) {
   LLVM_DEBUG({
     llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
-    llvm::dbgs() << ' ' << *srcOpInst << " and destination op(s)\n";
-    for (auto dstOpInst : dstLoadOpInsts) {
-      llvm::dbgs() << " " << *dstOpInst << "\n";
-    };
+    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
+    llvm::dbgs() << dstForOp << "\n";
   });
 
   if (maxLegalFusionDepth == 0) {
@@ -1070,11 +1073,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     return false;
 
   // Compute cost of dst loop nest.
-  SmallVector<AffineForOp, 4> dstLoopIVs;
-  getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
-
   LoopNestStats dstLoopNestStats;
-  if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
+  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
     return false;
 
   // Search for min cost value for 'dstLoopDepth'. At each value of
@@ -1108,18 +1108,19 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
 
   // Compute op instance count for the src loop nest.
-  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
+  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
 
   // Evaluate all depth choices for materializing the slice in the destination
   // loop nest.
   for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
+    const ComputationSliceState &slice = depthSliceUnions[i - 1];
     // Skip slice union if it wasn't computed for this depth.
-    if (depthSliceUnions[i - 1].isEmpty())
+    if (slice.isEmpty())
       continue;
 
     int64_t fusedLoopNestComputeCost;
-    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
-                              dstLoopNestStats, depthSliceUnions[i - 1],
+    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
+                              dstLoopNestStats, slice,
                               &fusedLoopNestComputeCost)) {
       LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
       continue;
@@ -1131,11 +1132,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
         1;
 
     // Determine what the slice write MemRefRegion would be, if the src loop
-    // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst
-    // loop nest at loop depth 'i'.
+    // nest slice 'slice' were to be inserted into the dst loop nest at loop
+    // depth 'i'.
     MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
     if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
-                                        &depthSliceUnions[i - 1]))) {
+                                        &slice))) {
       LLVM_DEBUG(llvm::dbgs()
                  << "Failed to compute slice write region at loopDepth: " << i
                  << "\n");
@@ -1218,7 +1219,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                    << "\n  fused loop nest compute cost: "
                    << minFusedLoopNestComputeCost << "\n");
 
-  auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
+  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
   auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
 
   Optional<double> storageReduction = None;
@@ -1322,8 +1323,6 @@ struct GreedyFusion {
   MemRefDependenceGraph *mdg;
   // Worklist of graph nodes visited during the fusion pass.
   SmallVector<unsigned, 8> worklist;
-  // Set of graph nodes which are present on the worklist.
-  llvm::SmallDenseSet<unsigned, 16> worklistSet;
   // Parameter for local buffer size threshold.
   unsigned localBufSizeThreshold;
   // Parameter for fast memory space.
@@ -1344,16 +1343,14 @@ struct GreedyFusion {
         fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
         computeToleranceThreshold(computeToleranceThreshold) {}
 
-  // Initializes 'worklist' with nodes from 'mdg'
+  /// Initializes 'worklist' with nodes from 'mdg'.
   void init() {
      // TODO: Add a priority queue for prioritizing nodes by different
     // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
     worklist.clear();
-    worklistSet.clear();
     for (auto &idAndNode : mdg->nodes) {
       const Node &node = idAndNode.second;
       worklist.push_back(node.id);
-      worklistSet.insert(node.id);
     }
   }
 
@@ -1372,11 +1369,11 @@ struct GreedyFusion {
   }
 
   void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
+    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
     init();
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
-      worklistSet.erase(dstId);
 
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
@@ -1386,114 +1383,85 @@ struct GreedyFusion {
       // Skip if 'dstNode' is not a loop nest.
       if (!isa<AffineForOp>(dstNode->op))
         continue;
+
+      LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+
       // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
       // while preserving relative order. This can increase the maximum loop
       // depth at which we can fuse a slice of a producer loop nest into a
       // consumer loop nest.
       sinkSequentialLoops(dstNode);
-
-      SmallVector<Operation *, 4> loads = dstNode->loads;
-      SmallVector<Operation *, 4> dstLoadOpInsts;
-      DenseSet<Value> visitedMemrefs;
-      while (!loads.empty()) {
-        // Get memref of load on top of the stack.
-        auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
-        if (visitedMemrefs.count(memref) > 0)
-          continue;
-        visitedMemrefs.insert(memref);
-        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
-        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
-        // Skip if no input edges along which to fuse.
-        if (mdg->inEdges.count(dstId) == 0)
-          continue;
-        // Iterate through in-edges for 'dstId' and src node id for any
-        // edges on 'memref'.
-        SmallVector<unsigned, 2> srcNodeIds;
-        for (auto &srcEdge : mdg->inEdges[dstId]) {
-          // Skip 'srcEdge' if not for 'memref'.
-          if (srcEdge.value != memref)
-            continue;
-          srcNodeIds.push_back(srcEdge.id);
-        }
-        for (unsigned srcId : srcNodeIds) {
-          // Skip if this node was removed (fused into another node).
-          if (mdg->nodes.count(srcId) == 0)
-            continue;
+      auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
+      // Try to fuse 'dstNode' with candidate producer loops until a fixed point
+      // is reached. Fusing two loops may expose new fusion opportunities.
+      bool dstNodeChanged;
+      do {
+        // Gather src loop candidates for 'dstNode' and visit them in "quasi"
+        // reverse program order to minimize the number of iterations needed to
+        // reach the fixed point. Note that this is a best effort approach since
+        // 'getProducerCandidates' does not always guarantee program order
+        // in 'srcIdCandidates'.
+        dstNodeChanged = false;
+        SmallVector<unsigned, 16> srcIdCandidates;
+        getProducerCandidates(dstId, mdg, srcIdCandidates);
+
+        for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
           // Get 'srcNode' from which to attempt fusion into 'dstNode'.
           auto *srcNode = mdg->getNode(srcId);
-          // Skip if 'srcNode' is not a loop nest.
-          if (!isa<AffineForOp>(srcNode->op))
+          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+          LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
+                                  << " for dst loop " << dstId << "\n");
+
+          DenseSet<Value> producerConsumerMemrefs;
+          gatherProducerConsumerMemrefs(srcId, dstId, mdg,
+                                        producerConsumerMemrefs);
+
+          // Skip if 'srcNode' out edge count on any memref is greater than
+          // 'maxSrcUserCount'.
+          if (any_of(producerConsumerMemrefs, [&](Value memref) {
+                return mdg->getOutEdgeCount(srcNode->id, memref) >
+                       maxSrcUserCount;
+              }))
             continue;
-          // Skip if 'srcNode' has more than one live-out store to a
-          // function-local memref.
-          // TODO: Support more generic multi-output src loop nests
-          // fusion.
-          auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
-          if (!srcStoreOp) {
-            // Get the src store op at the deepest loop depth.
-            // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
-            // feasibility for loops with multiple stores.
-            unsigned maxLoopDepth = 0;
-            for (auto *op : srcNode->stores) {
-              auto storeOp = cast<AffineWriteOpInterface>(op);
-              if (storeOp.getMemRef() != memref) {
-                srcStoreOp = nullptr;
-                break;
-              }
-              unsigned loopDepth = getNestingDepth(storeOp);
-              if (loopDepth > maxLoopDepth) {
-                maxLoopDepth = loopDepth;
-                srcStoreOp = storeOp;
-              }
-            }
-            if (!srcStoreOp)
-              continue;
-          }
 
-          // Unique outgoing store found must write to 'memref' since 'memref'
-          // is the one that established the producer-consumer relationship
-          // between 'srcNode' and 'dstNode'.
-          assert(srcStoreOp.getMemRef() == memref &&
-                 "Found store to unexpected memref");
-
-          // Skip if 'srcNode' writes to any live in or escaping memrefs,
-          // and cannot be fused.
-          bool writesToLiveInOrOut =
-              mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
-          if (writesToLiveInOrOut &&
-              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
+          // Gather memrefs in 'srcNode' that are written and escape out of
+          // the function (e.g., memref function arguments, returned memrefs,
+          // memrefs passed to function calls, etc.).
+          DenseSet<Value> srcEscapingMemRefs;
+          gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
+
+          // Skip if there are non-affine operations in between the 'srcNode'
+          // and 'dstNode' using their memrefs. If so, we wouldn't be able to
+          // compute a legal insertion point for now. 'srcNode' and 'dstNode'
+          // memrefs with non-affine operation users would be considered
+          // escaping memrefs so we can limit this check to only scenarios with
+          // escaping memrefs.
+          if (!srcEscapingMemRefs.empty() &&
+              hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
+            LLVM_DEBUG(
+                llvm::dbgs()
+                << "Can't fuse: non-affine users in between the loops\n.");
             continue;
-
-          // Don't create a private memref if 'writesToLiveInOrOut'.
-          bool createPrivateMemref = !writesToLiveInOrOut;
-          // Don't create a private memref if 'srcNode' has in edges on
-          // 'memref', or if 'dstNode' has out edges on 'memref'.
-          if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
-              mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
-            createPrivateMemref = false;
           }
 
-          // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
-          if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
-            continue;
-
           // Compute an operation list insertion point for the fused loop
           // nest which preserves dependences.
-          Operation *insertPointInst =
+          Operation *fusedLoopInsPoint =
               mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
-          if (insertPointInst == nullptr)
+          if (fusedLoopInsPoint == nullptr)
             continue;
 
-          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
-          auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
-
-          // Compute the innermost common loop depth for dstNode loads/stores.
+          // Compute the innermost common loop depth for dstNode
+          // producer-consumer loads/stores.
           SmallVector<Operation *, 2> dstMemrefOps;
           for (Operation *op : dstNode->loads)
-            if (cast<AffineReadOpInterface>(op).getMemRef() == memref)
+            if (producerConsumerMemrefs.count(
+                    cast<AffineReadOpInterface>(op).getMemRef()) > 0)
               dstMemrefOps.push_back(op);
           for (Operation *op : dstNode->stores)
-            if (cast<AffineWriteOpInterface>(op).getMemRef() == memref)
+            if (producerConsumerMemrefs.count(
+                    cast<AffineWriteOpInterface>(op).getMemRef()))
               dstMemrefOps.push_back(op);
           unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
 
@@ -1502,7 +1470,7 @@ struct GreedyFusion {
           unsigned maxLegalFusionDepth = 0;
           SmallVector<ComputationSliceState, 8> depthSliceUnions;
           depthSliceUnions.resize(dstLoopDepthTest);
-          FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref);
+          FusionStrategy strategy(FusionStrategy::ProducerConsumer);
           for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
             FusionResult result = mlir::canFuseLoops(
                 srcAffineForOp, dstAffineForOp,
@@ -1512,27 +1480,82 @@ struct GreedyFusion {
               maxLegalFusionDepth = i;
           }
 
-          // Skip if fusion is not feasible at any loop depths.
-          if (maxLegalFusionDepth == 0)
+          if (maxLegalFusionDepth == 0) {
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Can't fuse: fusion is not legal at any depth\n");
             continue;
+          }
 
           // Check if fusion would be profitable. We skip profitability analysis
           // for maximal fusion since we already know the maximal legal depth to
           // fuse.
           unsigned bestDstLoopDepth = maxLegalFusionDepth;
-          if (!maximalFusion &&
-              !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
-                                  depthSliceUnions, maxLegalFusionDepth,
-                                  &bestDstLoopDepth, computeToleranceThreshold))
-            continue;
+          if (!maximalFusion) {
+            // Retrieve producer stores from the src loop.
+            SmallVector<Operation *, 2> producerStores;
+            for (Operation *op : srcNode->stores)
+              if (producerConsumerMemrefs.count(
+                      cast<AffineWriteOpInterface>(op).getMemRef()))
+                producerStores.push_back(op);
+
+            // TODO: Support multiple producer stores in profitability
+            // analysis. We limit profitability analysis to only scenarios
+            // with a single producer store for now. Note that some
+            // multi-store producer scenarios will still go through
+            // profitability analysis if only one of the stores is involved
+            // in the producer-consumer relationship of the candidate loops.
+            assert(producerStores.size() > 0 && "Expected producer store");
+            if (producerStores.size() > 1)
+              LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
+                                         "supported for this case\n");
+            else if (!isFusionProfitable(producerStores[0], producerStores[0],
+                                         dstAffineForOp, depthSliceUnions,
+                                         maxLegalFusionDepth, &bestDstLoopDepth,
+                                         computeToleranceThreshold))
+              continue;
+          }
 
           assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
-          assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
-                 "Missing slice union for depth");
+          ComputationSliceState &bestSlice =
+              depthSliceUnions[bestDstLoopDepth - 1];
+          assert(!bestSlice.isEmpty() && "Missing slice union for depth");
+
+          // Determine if 'srcId' can be removed after fusion, taking into
+          // account remaining dependences, escaping memrefs and the fusion
+          // insertion point.
+          bool removeSrcNode = canRemoveSrcNodeAfterFusion(
+              srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
+              mdg);
+
+          DenseSet<Value> privateMemrefs;
+          for (Value memref : producerConsumerMemrefs) {
+            // Don't create a private memref if 'srcNode' writes to escaping
+            // memrefs.
+            if (srcEscapingMemRefs.count(memref) > 0)
+              continue;
+
+            // Don't create a private memref if 'srcNode' has in edges on
+            // 'memref' or 'dstNode' has out edges on 'memref'.
+            if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
+                mdg->getOutEdgeCount(dstId, memref) > 0)
+              continue;
+
+            // If 'srcNode' will be removed but it has out edges on 'memref' to
+            // nodes other than 'dstNode', we have to preserve dependences and
+            // cannot create a private memref.
+            if (removeSrcNode &&
+                any_of(mdg->outEdges[srcId], [&](const auto &edge) {
+                  return edge.value == memref && edge.id != dstId;
+                }))
+              continue;
+
+            // Create a private version of this memref.
+            privateMemrefs.insert(memref);
+          }
 
           // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
-          fuseLoops(srcAffineForOp, dstAffineForOp,
-                    depthSliceUnions[bestDstLoopDepth - 1]);
+          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
+          dstNodeChanged = true;
 
           LLVM_DEBUG(llvm::dbgs()
                      << "Fused src loop " << srcId << " into dst loop " << dstId
@@ -1540,92 +1563,63 @@ struct GreedyFusion {
                      << dstAffineForOp << "\n");
 
           // Move 'dstAffineForOp' before 'fusedLoopInsPoint' if needed.
-          if (insertPointInst != dstAffineForOp.getOperation())
-            dstAffineForOp->moveBefore(insertPointInst);
+          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
+            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);
 
           // Update edges between 'srcNode' and 'dstNode'.
-          mdg->updateEdges(srcNode->id, dstNode->id, memref,
-                           createPrivateMemref);
-
-          // Collect slice loop stats.
-          LoopNestStateCollector dstForCollector;
-          dstForCollector.collect(dstAffineForOp);
-          if (createPrivateMemref) {
-            // Create private memref for 'memref' in 'dstAffineForOp'.
-            SmallVector<Operation *, 4> storesForMemref;
-            for (auto *storeOpInst : dstForCollector.storeOpInsts) {
-              if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
-                  memref)
-                storesForMemref.push_back(storeOpInst);
+          mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
+                           removeSrcNode);
+
+          // Create private memrefs.
+          if (!privateMemrefs.empty()) {
+            // Gather stores for all the private-to-be memrefs.
+            DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
+            dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
+              Value storeMemRef = storeOp.getMemRef();
+              if (privateMemrefs.count(storeMemRef) > 0)
+                privateMemRefToStores[storeMemRef].push_back(
+                    storeOp.getOperation());
+            });
+
+            // Replace original memrefs with private memrefs. Note that all
+            // the loads and stores on these memrefs will be replaced with new
+            // loads and stores. Any reference to the original ones becomes
+            // invalid after this point.
+            for (auto &memrefToStoresPair : privateMemRefToStores) {
+              // TODO: Use union of memref write regions to compute
+              // private memref footprint.
+              SmallVector<Operation *, 4> &storesForMemref =
+                  memrefToStoresPair.second;
+              Value newMemRef = createPrivateMemRef(
+                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+                  fastMemorySpace, localBufSizeThreshold);
+              // Create new node in dependence graph for 'newMemRef' alloc op.
+              unsigned newMemRefNodeId =
+                  mdg->addNode(newMemRef.getDefiningOp());
+              // Add edge from 'newMemRef' node to dstNode.
+              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
             }
-            // TODO: Use union of memref write regions to compute
-            // private memref footprint.
-            auto newMemRef = createPrivateMemRef(
-                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
-                fastMemorySpace, localBufSizeThreshold);
-            visitedMemrefs.insert(newMemRef);
-            // Create new node in dependence graph for 'newMemRef' alloc op.
-            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
-            // Add edge from 'newMemRef' node to dstNode.
-            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
           }
 
           // Collect dst loop stats after memref privatization transformation.
           LoopNestStateCollector dstLoopCollector;
           dstLoopCollector.collect(dstAffineForOp.getOperation());
 
-          // Add new load ops to current Node load op list 'loads' to continue
-          // fusing based on new operands.
-          for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-            // NOTE: Change 'loads' to a hash set in case efficiency is an
-            // issue. We still use a vector since it's expected to be small.
-            if (!llvm::is_contained(loads, loadOpInst))
-              loads.push_back(loadOpInst);
-          }
-          // Clear visited memrefs after fusion so that previously visited src
-          // nodes are considered for fusion again in the context of the new
-          // fused node.
-          // TODO: This shouldn't be necessary if we visited candidates in the
-          // dependence graph in post-order or once we fully support multi-store
-          // producers. Currently, in a multi-store producer scenario such as
-          // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing
-          // edges. However, after fusing B+C, A has a single outgoing edge and
-          // can be fused if we revisit it in the context of the new fused B+C
-          // node.
-          visitedMemrefs.clear();
-
           // Clear and add back loads and stores.
           mdg->clearNodeLoadAndStores(dstNode->id);
           mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                          dstLoopCollector.storeOpInsts);
-          // Remove old src loop nest if it no longer has outgoing dependence
-          // edges, and if it does not write to a memref which escapes the
-          // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been
-          // fused into 'dstNode' and write region of 'dstNode' covers the write
-          // region of 'srcNode', and 'srcNode' has no other users so it is safe
-          // to remove.
-          if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
-            mdg->removeNode(srcNode->id);
-            srcNode->op->erase();
-          } else {
-            // Add remaining users of 'oldMemRef' back on the worklist (if not
-            // already there), as its replacement with a local/private memref
-            // has reduced dependences on 'oldMemRef' which may have created new
-            // fusion opportunities.
-            if (mdg->outEdges.count(srcNode->id) > 0) {
-              SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
-                  mdg->outEdges[srcNode->id];
-              for (auto &outEdge : oldOutEdges) {
-                if (outEdge.value == memref &&
-                    worklistSet.count(outEdge.id) == 0) {
-                  worklist.push_back(outEdge.id);
-                  worklistSet.insert(outEdge.id);
-                }
-              }
-            }
+
+          if (removeSrcNode) {
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Removing src loop " << srcId << " after fusion\n");
+            // srcNode is no longer valid after it is removed from mdg.
+            srcAffineForOp.erase();
+            mdg->removeNode(srcId);
+            srcNode = nullptr;
           }
         }
-      }
+      } while (dstNodeChanged);
     }
   }
 
@@ -1636,7 +1630,6 @@ struct GreedyFusion {
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
-      worklistSet.erase(dstId);
 
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
@@ -1698,7 +1691,7 @@ struct GreedyFusion {
       SmallVector<ComputationSliceState, 8> depthSliceUnions;
       depthSliceUnions.resize(dstLoopDepthTest);
       unsigned maxLegalFusionDepth = 0;
-      FusionStrategy strategy(FusionStrategy::Sibling, memref);
+      FusionStrategy strategy(memref);
       for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
         FusionResult result = mlir::canFuseLoops(
             sibAffineForOp, dstAffineForOp,
@@ -1712,10 +1705,10 @@ struct GreedyFusion {
       if (maxLegalFusionDepth == 0)
         continue;
 
-      unsigned bestDstLoopDepth = dstLoopDepthTest;
+      unsigned bestDstLoopDepth = maxLegalFusionDepth;
       if (!maximalFusion) {
         // Check if fusion would be profitable.
-        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
+        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
                                 depthSliceUnions, maxLegalFusionDepth,
                                 &bestDstLoopDepth, computeToleranceThreshold))
           continue;
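
For illustration only (hand-written, not part of the patch): the new do/while
over 'dstNodeChanged' above realizes the fixed-point iteration over producer
candidates mentioned in the commit message. A minimal sketch of the
A->B, A->C, B->C scenario it targets, where fusing B into C on one iteration
makes A fusable on the next:

  func @fixed_point_sketch() {
    %m1 = alloc() : memref<10xf32>
    %m2 = alloc() : memref<10xf32>
    %m3 = alloc() : memref<10xf32>
    %cst = constant 0.000000e+00 : f32
    // A: multi-store producer with two outgoing edges (%m1 -> B, %m2 -> C).
    affine.for %i = 0 to 10 {
      affine.store %cst, %m1[%i] : memref<10xf32>
      affine.store %cst, %m2[%i] : memref<10xf32>
    }
    // B: consumes %m1, produces %m3.
    affine.for %j = 0 to 10 {
      %v = affine.load %m1[%j] : memref<10xf32>
      affine.store %v, %m3[%j] : memref<10xf32>
    }
    // C: consumes %m2 and %m3.
    affine.for %k = 0 to 10 {
      %v1 = affine.load %m2[%k] : memref<10xf32>
      %v2 = affine.load %m3[%k] : memref<10xf32>
    }
    return
  }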

diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 9759300f2e42..9749a8de2351 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -191,11 +191,8 @@ gatherLoadsAndStores(AffineForOp forOp,
 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
 // TODO: Generalize this check for sibling and more generic fusion scenarios.
 // TODO: Support forward slice fusion.
-static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps,
-                                FusionStrategy fusionStrategy) {
-  assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer &&
-         "Fusion strategy not supported");
-
+static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
+                                ArrayRef<Operation *> dstOps) {
   if (dstOps.empty())
     // Expected at least one memory operation.
     // TODO: Revisit this case with a specific example.
@@ -203,15 +200,14 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps,
 
   // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
   // that they are not considered for analysis.
-  // TODO: Currently, we pass the producer-consumer memref through
-  // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we
-  // generalize the algorithm.
+  DenseSet<Value> producerConsumerMemrefs;
+  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
   SmallVector<Operation *, 4> targetDstOps;
   for (Operation *dstOp : dstOps) {
     auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
     Value memref = loadOp ? loadOp.getMemRef()
                           : cast<AffineWriteOpInterface>(dstOp).getMemRef();
-    if (memref == fusionStrategy.memref)
+    if (producerConsumerMemrefs.count(memref) > 0)
       targetDstOps.push_back(dstOp);
   }
 
@@ -308,10 +304,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   // loop dependences.
   // TODO: Enable this check for sibling and more generic loop fusion
   // strategies.
-  if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) {
+  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
     // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
     assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
-    if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) {
+    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
       LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
       return FusionResult::FailFusionDependence;
     }
@@ -324,7 +320,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   // Filter out ops in 'opsA' to compute the slice union based on the
   // assumptions made by the fusion strategy.
   SmallVector<Operation *, 4> strategyOpsA;
-  switch (fusionStrategy.strategy) {
+  switch (fusionStrategy.getStrategy()) {
   case FusionStrategy::Generic:
     // Generic fusion. Take into account all the memory operations to compute
     // the slice union.
@@ -332,10 +328,9 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     break;
   case FusionStrategy::ProducerConsumer:
     // Producer-consumer fusion (AffineLoopFusion pass) only takes into
-    // account stores to 'memref' in 'srcForOp' to compute the slice union.
+    // account stores in 'srcForOp' to compute the slice union.
     for (Operation *op : opsA) {
-      auto store = dyn_cast<AffineWriteOpInterface>(op);
-      if (store && store.getMemRef() == fusionStrategy.memref)
+      if (isa<AffineWriteOpInterface>(op))
         strategyOpsA.push_back(op);
     }
     break;
@@ -344,7 +339,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     // to 'memref' in 'srcForOp' to compute the slice union.
     for (Operation *op : opsA) {
       auto load = dyn_cast<AffineReadOpInterface>(op);
-      if (load && load.getMemRef() == fusionStrategy.memref)
+      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
         strategyOpsA.push_back(op);
     }
     break;
@@ -628,3 +623,23 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
   return true;
 }
+
+/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
+/// producer-consumer dependence between write ops in 'srcOps' and read ops in
+/// 'dstOps'.
+void mlir::gatherProducerConsumerMemrefs(
+    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
+    DenseSet<Value> &producerConsumerMemrefs) {
+  // Gather memrefs from stores in 'srcOps'.
+  DenseSet<Value> srcStoreMemRefs;
+  for (Operation *op : srcOps)
+    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
+      srcStoreMemRefs.insert(storeOp.getMemRef());
+
+  // Compute the intersection between memrefs from stores in 'srcOps' and
+  // memrefs from loads in 'dstOps'.
+  for (Operation *op : dstOps)
+    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
+      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
+        producerConsumerMemrefs.insert(loadOp.getMemRef());
+}
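
A hand-written example (not from the patch) of what this helper computes: for
the two nests below, 'producerConsumerMemrefs' would contain only '%a', since
'%b' is written in the source loop but never read in the destination loop:

  func @producer_consumer_memrefs_sketch() {
    %a = alloc() : memref<10xf32>
    %b = alloc() : memref<10xf32>
    %cst = constant 0.000000e+00 : f32
    affine.for %i = 0 to 10 {
      // Store to '%a' has a matching load below: '%a' is producer-consumer.
      affine.store %cst, %a[%i] : memref<10xf32>
      // Store to '%b' has no matching load below: '%b' is filtered out.
      affine.store %cst, %b[%i] : memref<10xf32>
    }
    affine.for %j = 0 to 10 {
      %v = affine.load %a[%j] : memref<10xf32>
    }
    return
  }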

diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index a23f0e2ee430..c1bccea4c9f5 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -364,8 +364,8 @@ func @should_fuse_and_move_to_preserve_war_dep() {
 
 // -----
 
-// CHECK-LABEL: func @should_fuse_with_private_memref_if_top_level_access() {
-func @should_fuse_with_private_memref_if_top_level_access() {
+// CHECK-LABEL: func @should_fuse_if_top_level_access() {
+func @should_fuse_if_top_level_access() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
 
@@ -378,14 +378,45 @@ func @should_fuse_with_private_memref_if_top_level_access() {
 
   %c0 = constant 4 : index
   %v1 = affine.load %m[%c0] : memref<10xf32>
-  // Top-level load to '%{{.*}}' should prevent fusion.
-  // CHECK:      affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // Top-level load to '%m' should prevent creating a private memref but the
+  // loop nests should still be fused and '%i0' should be removed.
+  // CHECK:      %[[m:.*]] = alloc() : memref<10xf32>
+  // CHECK-NOT:  alloc
+
+  // CHECK:      affine.for %[[i1:.*]] = 0 to 10 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %[[m]][%[[i1]]] : memref<10xf32>
+  // CHECK-NEXT:   affine.load %[[m]][%[[i1]]] : memref<10xf32>
   // CHECK-NEXT: }
-  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+  // CHECK:      affine.load %[[m]][%{{.*}}] : memref<10xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_but_not_remove_src() {
+func @should_fuse_but_not_remove_src() {
+  %m = alloc() : memref<100xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 100 {
+    affine.store %cf7, %m[%i0] : memref<100xf32>
+  }
+  affine.for %i1 = 0 to 17 {
+    %v0 = affine.load %m[%i1] : memref<100xf32>
+  }
+  %v1 = affine.load %m[99] : memref<100xf32>
+
+  // Loops '%i0' and '%i1' should be fused, but '%i0' shouldn't be removed,
+  // to preserve the dependence with the top-level access.
+  // CHECK:      affine.for %{{.*}} = 0 to 100 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<100xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 17 {
   // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: }
+  // CHECK-NEXT: affine.load %{{.*}}[99] : memref<100xf32>
+  // CHECK-NEXT: return
   return
 }
 
@@ -1110,8 +1141,8 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() {
 
 // -----
 
-// CHECK-LABEL: func @should_not_fuse_live_out_arg(%{{.*}}: memref<10xf32>) {
-func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) {
+// CHECK-LABEL: func @should_fuse_live_out_arg_but_preserve_src_loop(%{{.*}}: memref<10xf32>) {
+func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) {
   %cf7 = constant 7.0 : f32
 
   affine.for %i0 = 0 to 10 {
@@ -1129,6 +1160,7 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) {
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 9 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return
@@ -1160,8 +1192,8 @@ func @should_fuse_live_out_arg(%arg0: memref<10xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32>
-func @should_not_fuse_escaping_memref() -> memref<10xf32> {
+// CHECK-LABEL: func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32>
+func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> {
   %cf7 = constant 7.0 : f32
   %m = alloc() : memref<10xf32>
   affine.for %i0 = 0 to 10 {
@@ -1170,19 +1202,21 @@ func @should_not_fuse_escaping_memref() -> memref<10xf32> {
   affine.for %i1 = 0 to 9 {
     %v0 = affine.load %m[%i1] : memref<10xf32>
   }
-  // This tests that the loop nest '%{{.*}}' should not be removed after fusion
-  // because it writes to memref '%{{.*}}' which is returned by the function.
+  // This tests that the loop nest '%i0' should not be removed after fusion
+  // because it writes to memref '%m', which is returned by the function, and
+  // the memory region written by '%i1' does not cover the region written by
+  // '%i0'.
+
   // CHECK-DAG:   alloc() : memref<10xf32>
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 9 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return %{{.*}} : memref<10xf32>
   return %m : memref<10xf32>
 }
-
 // -----
 
 // This should fuse with the %in becoming a 1x1x1.
@@ -1230,7 +1264,7 @@ func @R3_to_R2_reshape() {
 
 // -----
 
-func @should_not_fuse_multi_output_producer() {
+func @should_fuse_multi_output_producer() {
   %a = alloc() : memref<10xf32>
   %b = alloc() : memref<10xf32>
 
@@ -1246,12 +1280,10 @@ func @should_not_fuse_multi_output_producer() {
   }
 
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:  }
-  // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return
   return
@@ -1504,8 +1536,8 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>
 
 // -----
 
-// CHECK-LABEL: func @should_fuse_after_private_memref_creation() {
-func @should_fuse_after_private_memref_creation() {
+// CHECK-LABEL: func @should_fuse_only_two_loops_and_remove_producer() {
+func @should_fuse_only_two_loops_and_remove_producer() {
   %a = alloc() : memref<10xf32>
   %b = alloc() : memref<10xf32>
 
@@ -1525,18 +1557,21 @@ func @should_fuse_after_private_memref_creation() {
 
   // On the first visit to '%i2', the fusion algorithm can not fuse loop nest
   // '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on
-  // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a
-  // private memref, the dependence between '%i0' and '%i1' on memref '%a' no
-  // longer exists, so '%i0' can now be fused into '%i2'.
-
+  // '%i1'. Then, '%i0' is fused into '%i1' and no private memref is created
+  // for memref '%a', so that '%i0' can be removed while still preserving the
+  // dependence on '%a' with '%i2'.
+  // TODO: Alternatively, we could fuse '%i0' into '%i1' with a private
+  // memref; the dependence between '%i0' and '%i1' on memref '%a' would then
+  // no longer exist, and '%i0' could be fused into '%i2' as well. Note that
+  // this approach would duplicate the computation in loop nest '%i0' into
+  // loop nests '%i1' and '%i2', which would limit its profitability.
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:   return
@@ -2220,7 +2255,7 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10
     }
   }
 
-  // CHECK:  affine.for %{{.*}} = 0 to 1024 {
+  // CHECK:        affine.for %{{.*}} = 0 to 1024 {
   // CHECK-NEXT:     affine.for %{{.*}} = 0 to 1024 {
   // CHECK-NEXT:       affine.for %{{.*}} = 0 to 1024 {
   // CHECK-NEXT:         affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
@@ -2311,8 +2346,8 @@ func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref
   }
   // CHECK:      affine.for %[[i0:.*]] = 0 to 10 {
   // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
-  // CHECK-NEXT:   affine.load %{{.*}}[%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   return
@@ -2373,12 +2408,11 @@ func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x
 
 // -----
 
-// Verify that 'fuseProducerConsumerNodes' doesn't fuse a producer loop with
-// a store that has multiple outgoing edges. Sibling loop fusion should not fuse
-// any of these loops due to dependencies on external memref '%a'.
+// Verify that 'fuseProducerConsumerNodes' fuses a producer loop with a store
+// that has multiple outgoing edges.
 
-// CHECK-LABEL: func @should_not_fuse_multi_outgoing_edge_store_producer1
-func @should_not_fuse_multi_outgoing_edge_store_producer1(%a : memref<1xf32>) {
+// CHECK-LABEL: func @should_fuse_multi_outgoing_edge_store_producer
+func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) {
   %cst = constant 0.000000e+00 : f32
   affine.for %arg0 = 0 to 1 {
     affine.store %cst, %a[%arg0] : memref<1xf32>
@@ -2391,9 +2425,12 @@ func @should_not_fuse_multi_outgoing_edge_store_producer1(%a : memref<1xf32>) {
   affine.for %arg0 = 0 to 1 {
     %0 = affine.load %a[%arg0] : memref<1xf32>
   }
-  // CHECK: affine.for %{{.*}} = 0 to 1
-  // CHECK: affine.for %{{.*}} = 0 to 1
-  // CHECK: affine.for %{{.*}} = 0 to 1
+  // CHECK:      affine.for %{{.*}} = 0 to 1 {
+  // CHECK-NEXT:   affine.store
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT: }
+
   return
 }
 
@@ -2663,3 +2700,109 @@ func @fuse_minor_affine_map(%in: memref<128xf32>, %out: memref<20x512xf32>) {
 // MAXIMAL:       affine.for
 // MAXIMAL-NEXT:    affine.for
 // MAXIMAL-NOT:   affine.for
+// MAXIMAL:       return
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_multi_store_producer_and_privatize_memrefs
+func @should_fuse_multi_store_producer_and_privatize_memrefs() {
+  %a = alloc() : memref<10xf32>
+  %b = alloc() : memref<10xf32>
+  %c = alloc() : memref<10xf32>
+  %cst = constant 0.000000e+00 : f32
+  affine.for %arg0 = 0 to 10 {
+    affine.store %cst, %a[%arg0] : memref<10xf32>
+    affine.store %cst, %b[%arg0] : memref<10xf32>
+    affine.store %cst, %c[%arg0] : memref<10xf32>
+    %0 = affine.load %c[%arg0] : memref<10xf32>
+  }
+
+  affine.for %arg0 = 0 to 10 {
+    %0 = affine.load %a[%arg0] : memref<10xf32>
+  }
+
+  affine.for %arg0 = 0 to 10 {
+    %0 = affine.load %b[%arg0] : memref<10xf32>
+  }
+
+  // All the memrefs should be privatized except '%c', which is not involved
+  // in the producer-consumer fusion.
+  // CHECK:      affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT: }
+
+  return
+}
+
+// -----
+
+func @should_fuse_multi_store_producer_with_escaping_memrefs_and_remove_src(
+    %a : memref<10xf32>, %b : memref<10xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.store %cst, %a[%i0] : memref<10xf32>
+    affine.store %cst, %b[%i0] : memref<10xf32>
+  }
+
+  affine.for %i1 = 0 to 10 {
+    %0 = affine.load %a[%i1] : memref<10xf32>
+  }
+
+  affine.for %i2 = 0 to 10 {
+    %0 = affine.load %b[%i2] : memref<10xf32>
+  }
+
+  // Producer loop '%i0' should be removed after fusion since fusion is
+  // maximal. No memref should be privatized since both memrefs escape the
+  // function.
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK-NOT:   affine.for
+
+  return
+}
+
+// -----
+
+func @should_fuse_multi_store_producer_with_escaping_memrefs_and_preserve_src(
+    %a : memref<10xf32>, %b : memref<10xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.store %cst, %a[%i0] : memref<10xf32>
+    affine.store %cst, %b[%i0] : memref<10xf32>
+  }
+
+  affine.for %i1 = 0 to 5 {
+    %0 = affine.load %a[%i1] : memref<10xf32>
+  }
+
+  affine.for %i2 = 0 to 10 {
+    %0 = affine.load %b[%i2] : memref<10xf32>
+  }
+
+  // Loops '%i0' and '%i2' should be fused first and '%i0' should be removed
+  // since fusion is maximal. Then the fused loop and '%i1' should be fused
+  // and the fused loop shouldn't be removed since fusion is not maximal.
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK:       affine.for %{{.*}} = 0 to 5 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK-NOT:   affine.for
+
+  return
+}
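
To make the two-step ordering described in the comment above concrete, a
hand-derived sketch (not part of the patch) of the intermediate IR after the
first step, once '%i0' has been fused into '%i2' and removed; the second step
then fuses this nest into '%i1' while preserving it, yielding the CHECK lines
above:

  func @intermediate_state_sketch(%a : memref<10xf32>, %b : memref<10xf32>) {
    %cst = constant 0.000000e+00 : f32
    // Fused '%i0' + '%i2' nest, placed before '%i1' to preserve the
    // dependence on '%a'.
    affine.for %i2 = 0 to 10 {
      affine.store %cst, %a[%i2] : memref<10xf32>
      affine.store %cst, %b[%i2] : memref<10xf32>
      %0 = affine.load %b[%i2] : memref<10xf32>
    }
    affine.for %i1 = 0 to 5 {
      %0 = affine.load %a[%i1] : memref<10xf32>
    }
    return
  }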
