[llvm-branch-commits] [mlir] 735a07f - Revert "[mlir][Affine] Add support for multi-store producer fusion"

Diego Caballero via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 20 14:43:06 PST 2021


Author: Diego Caballero
Date: 2021-01-21T00:37:23+02:00
New Revision: 735a07f0478566f6f7c60a8a98eb8884db574113

URL: https://github.com/llvm/llvm-project/commit/735a07f0478566f6f7c60a8a98eb8884db574113
DIFF: https://github.com/llvm/llvm-project/commit/735a07f0478566f6f7c60a8a98eb8884db574113.diff

LOG: Revert "[mlir][Affine] Add support for multi-store producer fusion"

This reverts commit 7dd198852b4db52ae22242dfeda4eccda83aa8b2.

ASAN issue.

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/include/mlir/Analysis/Utils.h
    mlir/include/mlir/Transforms/LoopFusionUtils.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
    mlir/test/Transforms/loop-fusion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 893d4ea4ff46..fa80db7d4b63 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -234,21 +234,6 @@ class FlatAffineConstraints {
   //  TODO: add support for non-unit strides.
   LogicalResult addAffineForOpDomain(AffineForOp forOp);
 
-  /// Adds constraints (lower and upper bounds) for each loop in the loop nest
-  /// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
-  /// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
-  /// the nest, sorted outer-to-inner. 'operands' contains the bound operands
-  /// for a single bound map. All the bound maps will use the same bound
-  /// operands. Note that some loops described by a computation slice might not
-  /// exist yet in the IR so the Value attached to those dimension identifiers
-  /// might be empty. For that reason, this method doesn't perform Value
-  /// look-ups to retrieve the dimension identifier positions. Instead, it
-  /// assumes the position of the dim identifiers in the constraint system is
-  /// the same as the position of the loop in the loop nest.
-  LogicalResult addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
-                                       ArrayRef<AffineMap> ubMaps,
-                                       ArrayRef<Value> operands);
-
   /// Adds constraints imposed by the `affine.if` operation. These constraints
   /// are collected from the IntegerSet attached to the given `affine.if`
   /// instance argument (`ifOp`). It is asserted that:

diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index ee6f8095f25e..30b6272181f5 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -83,25 +83,10 @@ struct ComputationSliceState {
   // Clears all bounds and operands in slice state.
   void clearBounds();
 
-  /// Returns true if the computation slice is empty.
+  /// Return true if the computation slice is empty.
   bool isEmpty() const { return ivs.empty(); }
 
-  /// Returns true if the computation slice encloses all the iterations of the
-  /// sliced loop nest. Returns false if it does not. Returns llvm::None if it
-  /// cannot determine if the slice is maximal or not.
-  // TODO: Cache 'isMaximal' so that we don't recompute it when the slice
-  // information hasn't changed.
-  Optional<bool> isMaximal() const;
-
   void dump() const;
-
-private:
-  /// Fast check to determine if the computation slice is maximal. Returns true
-  /// if each slice dimension maps to an existing dst dimension and both the src
-  /// and the dst loops for those dimensions have the same bounds. Returns false
-  /// if both the src and the dst loops don't have the same bounds. Returns
-  /// llvm::None if none of the above can be proven.
-  Optional<bool> isSliceMaximalFastCheck() const;
 };
 
 /// Computes the computation slice loop bounds for one loop nest as affine maps

diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index 10d6b83d022f..eade565e0325 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -50,8 +50,7 @@ struct FusionResult {
 // TODO: Generalize utilities so that producer-consumer and sibling fusion
 // strategies can be used without the assumptions made in the AffineLoopFusion
 // pass.
-class FusionStrategy {
-public:
+struct FusionStrategy {
   enum StrategyEnum {
     // Generic loop fusion: Arbitrary loops are considered for fusion. No
     // assumptions about a specific fusion strategy from AffineLoopFusion pass
@@ -70,34 +69,13 @@ class FusionStrategy {
     // implementation in AffineLoopFusion pass are made. See pass for specific
     // details.
     Sibling
-  };
+  } strategy;
 
-  /// Construct a generic or producer-consumer fusion strategy.
-  FusionStrategy(StrategyEnum strategy) : strategy(strategy) {
-    assert(strategy != Sibling &&
-           "Sibling fusion strategy requires a specific memref");
-  }
-
-  /// Construct a sibling fusion strategy targeting 'memref'. This construct
-  /// should only be used for sibling fusion.
-  FusionStrategy(Value memref) : strategy(Sibling), memref(memref) {}
-
-  /// Returns the fusion strategy.
-  StrategyEnum getStrategy() const { return strategy; };
-
-  /// Returns the memref attached to this sibling fusion strategy.
-  Value getSiblingFusionMemRef() const {
-    assert(strategy == Sibling && "Memref is only valid for sibling fusion");
-    return memref;
-  }
-
-private:
-  /// Fusion strategy.
-  StrategyEnum strategy;
-
-  /// Target memref for this fusion transformation. Only used for sibling
-  /// fusion.
+  // Target memref for this fusion transformation.
   Value memref;
+
+  FusionStrategy(StrategyEnum strategy, Value memref)
+      : strategy(strategy), memref(memref) {}
 };
 
 /// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
@@ -108,10 +86,11 @@ class FusionStrategy {
 /// NOTE: This function is not feature complete and should only be used in
 /// testing.
 /// TODO: Update comments when this function is fully implemented.
-FusionResult
-canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
-             ComputationSliceState *srcSlice,
-             FusionStrategy fusionStrategy = FusionStrategy::Generic);
+FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
+                          unsigned dstLoopDepth,
+                          ComputationSliceState *srcSlice,
+                          FusionStrategy fusionStrategy = {
+                              FusionStrategy::Generic, Value()});
 
 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
 /// and source slice loop bounds specified in 'srcSlice'.
@@ -155,12 +134,6 @@ bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                           const ComputationSliceState &slice,
                           int64_t *computeCost);
 
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between write ops in 'srcOps' and read ops in
-/// 'dstOps'.
-void gatherProducerConsumerMemrefs(ArrayRef<Operation *> srcOps,
-                                   ArrayRef<Operation *> dstOps,
-                                   DenseSet<Value> &producerConsumerMemrefs);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H

diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a03b439af339..438a468673b5 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -17,111 +17,6 @@ include "mlir/Pass/PassBase.td"
 
 def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
   let summary = "Fuse affine loop nests";
-  let description = [{
-    This pass performs fusion of loop nests using a slicing-based approach. It
-    combines two fusion strategies: producer-consumer fusion and sibling fusion.
-    Producer-consumer fusion is aimed at fusing pairs of loops where the first
-    one writes to a memref that the second reads. Sibling fusion targets pairs
-    of loops that share no dependences between them but that load from the same
-    memref. The fused loop nests, when possible, are rewritten to access
-    significantly smaller local buffers instead of the original memref's, and
-    the latter are often either completely optimized away or contracted. This
-    transformation leads to enhanced locality and lower memory footprint through
-    the elimination or contraction of temporaries/intermediate memref's. These
-    benefits are sometimes achieved at the expense of redundant computation
-    through a cost model that evaluates available choices such as the depth at
-    which a source slice should be materialized in the designation slice.
-
-    Example 1: Producer-consumer fusion.
-    Input:
-    ```mlir
-    func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
-      %0 = alloc() : memref<10xf32>
-      %1 = alloc() : memref<10xf32>
-      %cst = constant 0.000000e+00 : f32
-      affine.for %arg2 = 0 to 10 {
-        affine.store %cst, %0[%arg2] : memref<10xf32>
-        affine.store %cst, %1[%arg2] : memref<10xf32>
-      }
-      affine.for %arg2 = 0 to 10 {
-        %2 = affine.load %0[%arg2] : memref<10xf32>
-        %3 = addf %2, %2 : f32
-        affine.store %3, %arg0[%arg2] : memref<10xf32>
-      }
-      affine.for %arg2 = 0 to 10 {
-        %2 = affine.load %1[%arg2] : memref<10xf32>
-        %3 = mulf %2, %2 : f32
-        affine.store %3, %arg1[%arg2] : memref<10xf32>
-      }
-      return
-    }
-    ```
-    Output:
-    ```mlir
-    func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
-      %0 = alloc() : memref<1xf32>
-      %1 = alloc() : memref<1xf32>
-      %cst = constant 0.000000e+00 : f32
-      affine.for %arg2 = 0 to 10 {
-        affine.store %cst, %0[0] : memref<1xf32>
-        affine.store %cst, %1[0] : memref<1xf32>
-        %2 = affine.load %1[0] : memref<1xf32>
-        %3 = mulf %2, %2 : f32
-        affine.store %3, %arg1[%arg2] : memref<10xf32>
-        %4 = affine.load %0[0] : memref<1xf32>
-        %5 = addf %4, %4 : f32
-        affine.store %5, %arg0[%arg2] : memref<10xf32>
-      }
-      return
-    }
-    ```
-
-    Example 2: Sibling fusion.
-    Input:
-    ```mlir
-    func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
-                         %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
-                         %arg4: memref<10x10xf32>) {
-      affine.for %arg5 = 0 to 3 {
-        affine.for %arg6 = 0 to 3 {
-          %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
-          %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
-          %2 = mulf %0, %1 : f32
-          affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
-        }
-      }
-      affine.for %arg5 = 0 to 3 {
-        affine.for %arg6 = 0 to 3 {
-          %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
-          %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
-          %2 = addf %0, %1 : f32
-          affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32>
-        }
-      }
-      return
-    }
-    ```
-    Output:
-    ```mlir
-    func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
-                         %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
-                         %arg4: memref<10x10xf32>) {
-      affine.for %arg5 = 0 to 3 {
-        affine.for %arg6 = 0 to 3 {
-          %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
-          %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
-          %2 = mulf %0, %1 : f32
-          affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
-          %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
-          %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
-          %5 = addf %3, %4 : f32
-          affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32>
-        }
-      }
-      return
-    }
-    ```
-  }];
   let constructor = "mlir::createLoopFusionPass()";
   let options = [
     Option<"computeToleranceThreshold", "fusion-compute-tolerance", "double",

diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 81dc7855184e..12c90fbcfc54 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -708,70 +708,6 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
                               /*eq=*/false, /*lower=*/false);
 }
 
-/// Adds constraints (lower and upper bounds) for each loop in the loop nest
-/// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
-/// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
-/// the nest, sorted outer-to-inner. 'operands' contains the bound operands
-/// for a single bound map. All the bound maps will use the same bound
-/// operands. Note that some loops described by a computation slice might not
-/// exist yet in the IR so the Value attached to those dimension identifiers
-/// might be empty. For that reason, this method doesn't perform Value
-/// look-ups to retrieve the dimension identifier positions. Instead, it
-/// assumes the position of the dim identifiers in the constraint system is
-/// the same as the position of the loop in the loop nest.
-LogicalResult
-FlatAffineConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
-                                              ArrayRef<AffineMap> ubMaps,
-                                              ArrayRef<Value> operands) {
-  assert(lbMaps.size() == ubMaps.size());
-  assert(lbMaps.size() <= getNumDimIds());
-
-  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
-    AffineMap lbMap = lbMaps[i];
-    AffineMap ubMap = ubMaps[i];
-    assert(!lbMap || lbMap.getNumInputs() == operands.size());
-    assert(!ubMap || ubMap.getNumInputs() == operands.size());
-
-    // Check if this slice is just an equality along this dimension. If so,
-    // retrieve the existing loop it equates to and add it to the system.
-    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
-        ubMap.getNumResults() == 1 &&
-        lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
-        // The condition above will be true for maps describing a single
-        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
-        // Make sure we skip those cases by checking that the lb result is not
-        // just a constant.
-        !lbMap.getResult(0).isa<AffineConstantExpr>()) {
-      // Limited support: we expect the lb result to be just a loop dimension.
-      // Not supported otherwise for now.
-      AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
-      if (!result)
-        return failure();
-
-      AffineForOp loop =
-          getForInductionVarOwner(operands[result.getPosition()]);
-      if (!loop)
-        return failure();
-
-      if (failed(addAffineForOpDomain(loop)))
-        return failure();
-      continue;
-    }
-
-    // This slice refers to a loop that doesn't exist in the IR yet. Add its
-    // bounds to the system assuming its dimension identifier position is the
-    // same as the position of the loop in the loop nest.
-    if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false,
-                                             /*lower=*/true)))
-      return failure();
-
-    if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false,
-                                             /*lower=*/false)))
-      return failure();
-  }
-  return success();
-}
-
 void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
   // Create the base constraints from the integer set attached to ifOp.
   FlatAffineConstraints cst(ifOp.getIntegerSet());

diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 383a6587bbef..a1e7d1ffe844 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -12,8 +12,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Utils.h"
+
 #include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/PresburgerSet.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -127,128 +127,6 @@ void ComputationSliceState::dump() const {
   }
 }
 
-/// Fast check to determine if the computation slice is maximal. Returns true if
-/// each slice dimension maps to an existing dst dimension and both the src
-/// and the dst loops for those dimensions have the same bounds. Returns false
-/// if both the src and the dst loops don't have the same bounds. Returns
-/// llvm::None if none of the above can be proven.
-Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
-  assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() &&
-         "Unexpected number of lbs, ubs and ivs in slice");
-
-  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
-    AffineMap lbMap = lbs[i];
-    AffineMap ubMap = ubs[i];
-
-    // Check if this slice is just an equality along this dimension.
-    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
-        ubMap.getNumResults() != 1 ||
-        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
-        // The condition above will be true for maps describing a single
-        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
-        // Make sure we skip those cases by checking that the lb result is not
-        // just a constant.
-        lbMap.getResult(0).isa<AffineConstantExpr>())
-      return llvm::None;
-
-    // Limited support: we expect the lb result to be just a loop dimension for
-    // now.
-    AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
-    if (!result)
-      return llvm::None;
-
-    // Retrieve dst loop bounds.
-    AffineForOp dstLoop =
-        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
-    if (!dstLoop)
-      return llvm::None;
-    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
-    AffineMap dstUbMap = dstLoop.getUpperBoundMap();
-
-    // Retrieve src loop bounds.
-    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
-    assert(srcLoop && "Expected affine for");
-    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
-    AffineMap srcUbMap = srcLoop.getUpperBoundMap();
-
-    // Limited support: we expect simple src and dst loops with a single
-    // constant component per bound for now.
-    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
-        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
-      return llvm::None;
-
-    AffineExpr srcLbResult = srcLbMap.getResult(0);
-    AffineExpr dstLbResult = dstLbMap.getResult(0);
-    AffineExpr srcUbResult = srcUbMap.getResult(0);
-    AffineExpr dstUbResult = dstUbMap.getResult(0);
-    if (!srcLbResult.isa<AffineConstantExpr>() ||
-        !srcUbResult.isa<AffineConstantExpr>() ||
-        !dstLbResult.isa<AffineConstantExpr>() ||
-        !dstUbResult.isa<AffineConstantExpr>())
-      return llvm::None;
-
-    // Check if src and dst loop bounds are the same. If not, we can guarantee
-    // that the slice is not maximal.
-    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult)
-      return false;
-  }
-
-  return true;
-}
-
-/// Returns true if the computation slice encloses all the iterations of the
-/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
-/// cannot determine if the slice is maximal or not.
-Optional<bool> ComputationSliceState::isMaximal() const {
-  // Fast check to determine if the computation slice is maximal. If the result
-  // is inconclusive, we proceed with a more expensive analysis.
-  Optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
-  if (isMaximalFastCheck.hasValue())
-    return isMaximalFastCheck;
-
-  // Create constraints for the src loop nest being sliced.
-  FlatAffineConstraints srcConstraints;
-  srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0,
-                       /*numLocals=*/0, ivs);
-  for (Value iv : ivs) {
-    AffineForOp loop = getForInductionVarOwner(iv);
-    assert(loop && "Expected affine for");
-    if (failed(srcConstraints.addAffineForOpDomain(loop)))
-      return llvm::None;
-  }
-
-  // Create constraints for the slice using the dst loop nest information. We
-  // retrieve existing dst loops from the lbOperands.
-  SmallVector<Value, 8> consumerIVs;
-  for (Value lbOp : lbOperands[0])
-    if (getForInductionVarOwner(lbOp))
-      consumerIVs.push_back(lbOp);
-
-  // Add empty IV Values for those new loops that are not equalities and,
-  // therefore, are not yet materialized in the IR.
-  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
-    consumerIVs.push_back(Value());
-
-  FlatAffineConstraints sliceConstraints;
-  sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0,
-                         /*numLocals=*/0, consumerIVs);
-
-  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
-    return llvm::None;
-
-  if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds())
-    // Constraint dims are different. The integer set difference can't be
-    // computed so we don't know if the slice is maximal.
-    return llvm::None;
-
-  // Compute the difference between the src loop nest and the slice integer
-  // sets.
-  PresburgerSet srcSet(srcConstraints);
-  PresburgerSet sliceSet(sliceConstraints);
-  PresburgerSet diffSet = srcSet.subtract(sliceSet);
-  return diffSet.isIntegerEmpty();
-}
-
 unsigned MemRefRegion::getRank() const {
   return memref.getType().cast<MemRefType>().getRank();
 }

diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 6c56368ca6e1..6fe112b89baf 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -30,7 +30,6 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <iomanip>
-#include <set>
 #include <sstream>
 #define DEBUG_TYPE "affine-loop-fusion"
 
@@ -271,6 +270,64 @@ struct MemRefDependenceGraph {
     return false;
   }
 
+  // Returns the unique AffineWriteOpInterface in `node` that meets all the
+  // following:
+  //   *) store is the only one that writes to a function-local memref live out
+  //      of `node`,
+  //   *) store is not the source of a self-dependence on `node`.
+  // Otherwise, returns a null AffineWriteOpInterface.
+  AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
+    AffineWriteOpInterface uniqueStore;
+
+    // Return null if `node` doesn't have any outgoing edges.
+    auto outEdgeIt = outEdges.find(node->id);
+    if (outEdgeIt == outEdges.end())
+      return nullptr;
+
+    const auto &nodeOutEdges = outEdgeIt->second;
+    for (auto *op : node->stores) {
+      auto storeOp = cast<AffineWriteOpInterface>(op);
+      auto memref = storeOp.getMemRef();
+      // Skip this store if there are no dependences on its memref. This means
+      // that store either:
+      // *) writes to a memref that is only read within the same loop nest
+      //    (self-dependence edges are not represented in graph at the moment),
+      // *) writes to a function live out memref (function parameter), or
+      // *) is dead.
+      if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
+            return (edge.value != memref);
+          }))
+        continue;
+
+      if (uniqueStore)
+        // Found multiple stores to function-local live-out memrefs.
+        return nullptr;
+      // Found first store to function-local live-out memref.
+      uniqueStore = storeOp;
+    }
+
+    return uniqueStore;
+  }
+
+  // Returns true if node 'id' can be removed from the graph. Returns false
+  // otherwise. A node can be removed from the graph iff the following
+  // conditions are met:
+  // *) The node does not write to any memref which escapes (or is a
+  //    function/block argument).
+  // *) The node has no successors in the dependence graph.
+  bool canRemoveNode(unsigned id) {
+    if (writesToLiveInOrEscapingMemrefs(id))
+      return false;
+    Node *node = getNode(id);
+    for (auto *storeOpInst : node->stores) {
+      // Return false if there exist out edges from 'id' on 'memref'.
+      auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+      if (getOutEdgeCount(id, storeMemref) > 0)
+        return false;
+    }
+    return true;
+  }
+
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
@@ -438,49 +495,42 @@ struct MemRefDependenceGraph {
     return dstNodeInst;
   }
 
-  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
-  // taking into account that:
-  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
-  //   *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
-  //      private memref.
-  void updateEdges(unsigned srcId, unsigned dstId,
-                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
+  // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
+  // has been replaced in node at 'dstId' by a private memref depending
+  // on the value of 'createPrivateMemRef'.
+  void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef,
+                   bool createPrivateMemRef) {
     // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
     if (inEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
       for (auto &inEdge : oldInEdges) {
-        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
-        if (privateMemRefs.count(inEdge.value) == 0)
+        // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
+        if (inEdge.value != oldMemRef)
           addEdge(inEdge.id, dstId, inEdge.value);
       }
     }
     // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
-    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
     if (outEdges.count(srcId) > 0) {
       SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
       for (auto &outEdge : oldOutEdges) {
         // Remove any out edges from 'srcId' to 'dstId' across memrefs.
         if (outEdge.id == dstId)
           removeEdge(srcId, outEdge.id, outEdge.value);
-        else if (removeSrcId) {
-          addEdge(dstId, outEdge.id, outEdge.value);
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        }
       }
     }
     // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
     // replaced by a private memref). These edges could come from nodes
     // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
+    if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
       SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
       for (auto &inEdge : oldInEdges)
-        if (privateMemRefs.count(inEdge.value) > 0)
+        if (inEdge.value == oldMemRef)
           removeEdge(inEdge.id, dstId, inEdge.value);
     }
   }
 
   // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
-  // of sibling node 'sibId' into node 'dstId'.
+  // of sibling node 'sidId' into node 'dstId'.
   void updateEdges(unsigned sibId, unsigned dstId) {
     // For each edge in 'inEdges[sibId]':
     // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
@@ -574,132 +624,6 @@ struct MemRefDependenceGraph {
   void dump() const { print(llvm::errs()); }
 };
 
-/// Returns true if node 'srcId' can be removed after fusing it with node
-/// 'dstId'. The node can be removed if any of the following conditions are met:
-///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.
-///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
-///       and the fusion slice is maximal.
-///   3. 'srcId' has output dependences after fusion, the fusion slice is
-///      maximal and the fusion insertion point dominates all the dependences.
-static bool canRemoveSrcNodeAfterFusion(
-    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
-    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
-    MemRefDependenceGraph *mdg) {
-
-  Operation *dstNodeOp = mdg->getNode(dstId)->op;
-  bool hasOutDepsAfterFusion = false;
-
-  for (auto &outEdge : mdg->outEdges[srcId]) {
-    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
-    // Skip dependence with dstOp since it will be removed after fusion.
-    if (depNodeOp == dstNodeOp)
-      continue;
-
-    // Only fusion within the same block is supported. Use domination analysis
-    // when needed.
-    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
-      return false;
-
-    // Check if the insertion point of the fused loop dominates the dependence.
-    // Otherwise, the src loop can't be removed.
-    if (fusedLoopInsPoint != depNodeOp &&
-        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
-      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
-                                 "dominate dependence\n");
-      return false;
-    }
-
-    hasOutDepsAfterFusion = true;
-  }
-
-  // If src loop has dependences after fusion or it writes to an live-out or
-  // escaping memref, we can only remove it if the fusion slice is maximal so
-  // that all the dependences are preserved.
-  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
-    Optional<bool> isMaximal = fusionSlice.isMaximal();
-    if (!isMaximal.hasValue()) {
-      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
-                                 "if fusion is maximal\n");
-      return false;
-    }
-
-    if (!isMaximal.getValue()) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Src loop can't be removed: fusion is not maximal\n");
-      return false;
-    }
-  }
-
-  return true;
-}
-
-/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
-/// 'dstId'.
-// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
-static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
-                                  DenseSet<unsigned> &srcIdCandidates) {
-  // Skip if no input edges along which to fuse.
-  if (mdg->inEdges.count(dstId) == 0)
-    return;
-
-  // Gather memrefs from loads in 'dstId'.
-  auto *dstNode = mdg->getNode(dstId);
-  DenseSet<Value> consumedMemrefs;
-  for (Operation *load : dstNode->loads)
-    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
-
-  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
-  // to one of the consumed memrefs.
-  for (auto &srcEdge : mdg->inEdges[dstId]) {
-    auto *srcNode = mdg->getNode(srcEdge.id);
-    // Skip if 'srcNode' is not a loop nest.
-    if (!isa<AffineForOp>(srcNode->op))
-      continue;
-
-    if (any_of(srcNode->stores, [&](Operation *op) {
-          auto storeOp = cast<AffineWriteOpInterface>(op);
-          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
-        }))
-      srcIdCandidates.insert(srcNode->id);
-  }
-}
-
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between 'srcId' and 'dstId'.
-static void
-gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
-                              MemRefDependenceGraph *mdg,
-                              DenseSet<Value> &producerConsumerMemrefs) {
-  auto *dstNode = mdg->getNode(dstId);
-  auto *srcNode = mdg->getNode(srcId);
-  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
-                                producerConsumerMemrefs);
-}
-
-/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
-/// that escape the function. A memref escapes the function if either:
-///   1. It's a function argument, or
-///   2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
-void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
-                           DenseSet<Value> &escapingMemRefs) {
-  auto *node = mdg->getNode(id);
-  for (auto *storeOpInst : node->stores) {
-    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
-    if (escapingMemRefs.count(memref))
-      continue;
-    // Check if 'memref' escapes because it's a block argument.
-    if (memref.isa<BlockArgument>()) {
-      escapingMemRefs.insert(memref);
-      continue;
-    }
-    // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
-    // call op, etc.).
-    for (Operation *user : memref.getUsers())
-      if (!isMemRefDereferencingOp(*user))
-        escapingMemRefs.insert(memref);
-  }
-}
-
 } // end anonymous namespace
 
 // Initializes the data dependence graph by walking operations in 'f'.
@@ -707,7 +631,6 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
 // TODO: Add support for taking a Block arg to construct the
 // dependence graph at a different depth.
 bool MemRefDependenceGraph::init(FuncOp f) {
-  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
   DenseMap<Value, SetVector<unsigned>> memrefAccesses;
 
   // TODO: support multi-block functions.
@@ -763,12 +686,6 @@ bool MemRefDependenceGraph::init(FuncOp f) {
     }
   }
 
-#ifndef NDEBUG
-  for (auto &idAndNode : nodes)
-    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
-                            << *(idAndNode.second.op) << "\n");
-#endif
-
   // Add dependence edges between nodes which produce SSA values and their
   // users.
   for (auto &idAndNode : nodes) {
@@ -808,6 +725,22 @@ bool MemRefDependenceGraph::init(FuncOp f) {
   return true;
 }
 
+// Removes load operations from 'srcLoads' which operate on 'memref', and
+// adds them to 'dstLoads'.
+static void moveLoadsAccessingMemrefTo(Value memref,
+                                       SmallVectorImpl<Operation *> *srcLoads,
+                                       SmallVectorImpl<Operation *> *dstLoads) {
+  dstLoads->clear();
+  SmallVector<Operation *, 4> srcLoadsToKeep;
+  for (auto *load : *srcLoads) {
+    if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
+      dstLoads->push_back(load);
+    else
+      srcLoadsToKeep.push_back(load);
+  }
+  srcLoads->swap(srcLoadsToKeep);
+}
+
 // Sinks all sequential loops to the innermost levels (while preserving
 // relative order among them) and moves all parallel loops to the
 // outermost (while again preserving relative order among them).
@@ -999,6 +932,75 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
   return false;
 }
 
+// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
+// may write to multiple memrefs but it is required that only one of them,
+// 'srcLiveOutStoreOp', has output edges.
+// Returns true if 'dstNode's read/write region to 'memref' is a super set of
+// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
+// TODO: Generalize this to handle more live in/out cases.
+static bool
+canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+                               AffineWriteOpInterface srcLiveOutStoreOp,
+                               MemRefDependenceGraph *mdg) {
+  assert(srcLiveOutStoreOp && "Expected a valid store op");
+  auto *dstNode = mdg->getNode(dstId);
+  Value memref = srcLiveOutStoreOp.getMemRef();
+  // Return false if 'srcNode' has more than one output edge on 'memref'.
+  if (mdg->getOutEdgeCount(srcId, memref) > 1)
+    return false;
+
+  // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
+  MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
+  if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation\n.");
+    return false;
+  }
+  SmallVector<int64_t, 4> srcShape;
+  // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
+  // by 'srcStoreOp' at depth 'dstLoopDepth'.
+  Optional<int64_t> srcNumElements =
+      srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
+  if (!srcNumElements.hasValue())
+    return false;
+
+  // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
+  // TODO: Compute 'unionboundingbox' of all write regions (one for
+  // each store op in 'dstStoreOps').
+  SmallVector<Operation *, 2> dstStoreOps;
+  dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
+  SmallVector<Operation *, 2> dstLoadOps;
+  dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
+
+  auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
+  MemRefRegion dstRegion(dstOpInst->getLoc());
+  if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for dest operation\n.");
+    return false;
+  }
+  SmallVector<int64_t, 4> dstShape;
+  // Query 'dstRegion' for 'dstShape' and 'dstNumElements'.
+  // by 'dstOpInst' at depth 'dstLoopDepth'.
+  Optional<int64_t> dstNumElements =
+      dstRegion.getConstantBoundingSizeAndShape(&dstShape);
+  if (!dstNumElements.hasValue())
+    return false;
+
+  // Return false if write region is not a superset of 'srcNodes' write
+  // region to 'memref'.
+  // TODO: Check the shape and lower bounds here too.
+  if (srcNumElements != dstNumElements)
+    return false;
+
+  // Return false if 'memref' is used by a non-affine operation that is
+  // between node 'srcId' and node 'dstId'.
+  if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+    return false;
+
+  return true;
+}
+
 // Checks the profitability of fusing a backwards slice of the loop nest
 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
@@ -1027,6 +1029,9 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
 //    the largest computation slice at the maximal dst loop depth (closest to
 //    the load) to minimize reuse distance and potentially enable subsequent
 //    load/store forwarding.
+//    NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
+//    the same memref as is written by 'srcOpInst', then the union of slice
+//    loop bounds is used to compute the slice and associated slice cost.
 //    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
 //    nest, at which the src computation slice is inserted/fused.
 //    NOTE: We attempt to maximize the dst loop depth, but there are cases
@@ -1036,18 +1041,18 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
 // *) Compares the total cost of the unfused loop nests to the min cost fused
 //    loop nest computed in the previous step, and returns true if the latter
 //    is lower.
-// TODO: Extend profitability analysis to support scenarios with multiple
-// stores.
 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
-                               AffineForOp dstForOp,
+                               ArrayRef<Operation *> dstLoadOpInsts,
                                ArrayRef<ComputationSliceState> depthSliceUnions,
                                unsigned maxLegalFusionDepth,
                                unsigned *dstLoopDepth,
                                double computeToleranceThreshold) {
   LLVM_DEBUG({
     llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
-    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
-    llvm::dbgs() << dstForOp << "\n";
+    llvm::dbgs() << ' ' << *srcOpInst << " and destination op(s)\n";
+    for (auto dstOpInst : dstLoadOpInsts) {
+      llvm::dbgs() << " " << *dstOpInst << "\n";
+    };
   });
 
   if (maxLegalFusionDepth == 0) {
@@ -1065,8 +1070,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     return false;
 
   // Compute cost of dst loop nest.
+  SmallVector<AffineForOp, 4> dstLoopIVs;
+  getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
+
   LoopNestStats dstLoopNestStats;
-  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
+  if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
     return false;
 
   // Search for min cost value for 'dstLoopDepth'. At each value of
@@ -1100,19 +1108,18 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
 
   // Compute op instance count for the src loop nest.
-  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
+  uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
 
   // Evaluate all depth choices for materializing the slice in the destination
   // loop nest.
   for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
-    const ComputationSliceState &slice = depthSliceUnions[i - 1];
     // Skip slice union if it wasn't computed for this depth.
-    if (slice.isEmpty())
+    if (depthSliceUnions[i - 1].isEmpty())
       continue;
 
     int64_t fusedLoopNestComputeCost;
-    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
-                              dstLoopNestStats, slice,
+    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+                              dstLoopNestStats, depthSliceUnions[i - 1],
                               &fusedLoopNestComputeCost)) {
       LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
       continue;
@@ -1124,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
         1;
 
     // Determine what the slice write MemRefRegion would be, if the src loop
-    // nest slice 'slice' were to be inserted into the dst loop nest at loop
-    // depth 'i'.
+    // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst
+    // loop nest at loop depth 'i'.
     MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
     if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
-                                        &slice))) {
+                                        &depthSliceUnions[i - 1]))) {
       LLVM_DEBUG(llvm::dbgs()
                  << "Failed to compute slice write region at loopDepth: " << i
                  << "\n");
@@ -1211,7 +1218,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                    << "\n  fused loop nest compute cost: "
                    << minFusedLoopNestComputeCost << "\n");
 
-  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
+  auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
   auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
 
   Optional<double> storageReduction = None;
@@ -1315,6 +1322,8 @@ struct GreedyFusion {
   MemRefDependenceGraph *mdg;
   // Worklist of graph nodes visited during the fusion pass.
   SmallVector<unsigned, 8> worklist;
+  // Set of graph nodes which are present on the worklist.
+  llvm::SmallDenseSet<unsigned, 16> worklistSet;
   // Parameter for local buffer size threshold.
   unsigned localBufSizeThreshold;
   // Parameter for fast memory space.
@@ -1335,14 +1344,16 @@ struct GreedyFusion {
         fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
         computeToleranceThreshold(computeToleranceThreshold) {}
 
-  /// Initializes 'worklist' with nodes from 'mdg'.
+  // Initializes 'worklist' with nodes from 'mdg'
   void init() {
     // TODO: Add a priority queue for prioritizing nodes by different
     // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
     worklist.clear();
+    worklistSet.clear();
     for (auto &idAndNode : mdg->nodes) {
       const Node &node = idAndNode.second;
       worklist.push_back(node.id);
+      worklistSet.insert(node.id);
     }
   }
 
@@ -1361,11 +1372,11 @@ struct GreedyFusion {
   }
 
   void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
-    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
     init();
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
+      worklistSet.erase(dstId);
 
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
@@ -1375,97 +1386,114 @@ struct GreedyFusion {
       // Skip if 'dstNode' is not a loop nest.
       if (!isa<AffineForOp>(dstNode->op))
         continue;
-
-      LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
-
       // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
       // while preserving relative order. This can increase the maximum loop
       // depth at which we can fuse a slice of a producer loop nest into a
       // consumer loop nest.
       sinkSequentialLoops(dstNode);
-      auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
-
-      // Try to fuse 'dstNode' with candidate producer loops until a fixed point
-      // is reached. Fusing two loops may expose new fusion opportunities.
-      bool dstNodeChanged;
-      do {
-        // Gather src loop candidates for 'dstNode' and visit them in "quasi"
-        // reverse program order to minimize the number of iterations needed to
-        // reach the fixed point. Note that this is a best effort approach since
-        // 'getProducerCandidates' does not always guarantee that program order
-        // in 'srcIdCandidates'.
-        dstNodeChanged = false;
-        DenseSet<unsigned> srcIdCandidates;
-        getProducerCandidates(dstId, mdg, srcIdCandidates);
-
-        /// Visit candidates in reverse node id order. This order corresponds to
-        /// the reverse program order when the 'mdg' is created. However,
-        /// reverse program order is not guaranteed and must not be required.
-        /// Reverse program order won't be held if the 'mdg' is reused from a
-        /// previous fusion step or if the node creation order changes in the
-        /// future to support more advance cases.
-        SmallVector<unsigned, 16> sortedSrcIdCandidates;
-        sortedSrcIdCandidates.reserve(srcIdCandidates.size());
-        sortedSrcIdCandidates.append(srcIdCandidates.begin(),
-                                     srcIdCandidates.end());
-        llvm::sort(sortedSrcIdCandidates, std::greater<unsigned>());
-
-        for (unsigned srcId : sortedSrcIdCandidates) {
+
+      SmallVector<Operation *, 4> loads = dstNode->loads;
+      SmallVector<Operation *, 4> dstLoadOpInsts;
+      DenseSet<Value> visitedMemrefs;
+      while (!loads.empty()) {
+        // Get memref of load on top of the stack.
+        auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
+        if (visitedMemrefs.count(memref) > 0)
+          continue;
+        visitedMemrefs.insert(memref);
+        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
+        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
+        // Skip if no input edges along which to fuse.
+        if (mdg->inEdges.count(dstId) == 0)
+          continue;
+        // Iterate through in-edges for 'dstId' and src node id for any
+        // edges on 'memref'.
+        SmallVector<unsigned, 2> srcNodeIds;
+        for (auto &srcEdge : mdg->inEdges[dstId]) {
+          // Skip 'srcEdge' if not for 'memref'.
+          if (srcEdge.value != memref)
+            continue;
+          srcNodeIds.push_back(srcEdge.id);
+        }
+        for (unsigned srcId : srcNodeIds) {
+          // Skip if this node was removed (fused into another node).
+          if (mdg->nodes.count(srcId) == 0)
+            continue;
           // Get 'srcNode' from which to attempt fusion into 'dstNode'.
           auto *srcNode = mdg->getNode(srcId);
-          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
-          LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
-                                  << " for dst loop " << dstId << "\n");
-
-          DenseSet<Value> producerConsumerMemrefs;
-          gatherProducerConsumerMemrefs(srcId, dstId, mdg,
-                                        producerConsumerMemrefs);
-
-          // Skip if 'srcNode' out edge count on any memref is greater than
-          // 'maxSrcUserCount'.
-          if (any_of(producerConsumerMemrefs, [&](Value memref) {
-                return mdg->getOutEdgeCount(srcNode->id, memref) >
-                       maxSrcUserCount;
-              }))
+          // Skip if 'srcNode' is not a loop nest.
+          if (!isa<AffineForOp>(srcNode->op))
             continue;
+          // Skip if 'srcNode' has more than one live-out store to a
+          // function-local memref.
+          // TODO: Support more generic multi-output src loop nests
+          // fusion.
+          auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
+          if (!srcStoreOp) {
+            // Get the src store op at the deepest loop depth.
+            // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
+            // feasibility for loops with multiple stores.
+            unsigned maxLoopDepth = 0;
+            for (auto *op : srcNode->stores) {
+              auto storeOp = cast<AffineWriteOpInterface>(op);
+              if (storeOp.getMemRef() != memref) {
+                srcStoreOp = nullptr;
+                break;
+              }
+              unsigned loopDepth = getNestingDepth(storeOp);
+              if (loopDepth > maxLoopDepth) {
+                maxLoopDepth = loopDepth;
+                srcStoreOp = storeOp;
+              }
+            }
+            if (!srcStoreOp)
+              continue;
+          }
 
-          // Gather memrefs in 'srcNode' that are written and escape to the
-          // function (e.g., memref function arguments, returned memrefs,
-          // memrefs passed to function calls, etc.).
-          DenseSet<Value> srcEscapingMemRefs;
-          gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
-
-          // Skip if there are non-affine operations in between the 'srcNode'
-          // and 'dstNode' using their memrefs. If so, we wouldn't be able to
-          // compute a legal insertion point for now. 'srcNode' and 'dstNode'
-          // memrefs with non-affine operation users would be considered
-          // escaping memrefs so we can limit this check to only scenarios with
-          // escaping memrefs.
-          if (!srcEscapingMemRefs.empty() &&
-              hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
-            LLVM_DEBUG(
-                llvm::dbgs()
-                << "Can't fuse: non-affine users in between the loops\n.");
+          // Unique outgoing store found must write to 'memref' since 'memref'
+          // is the one that established the producer-consumer relationship
+          // between 'srcNode' and 'dstNode'.
+          assert(srcStoreOp.getMemRef() == memref &&
+                 "Found store to unexpected memref");
+
+          // Skip if 'srcNode' writes to any live in or escaping memrefs,
+          // and cannot be fused.
+          bool writesToLiveInOrOut =
+              mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
+          if (writesToLiveInOrOut &&
+              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
             continue;
+
+          // Don't create a private memref if 'writesToLiveInOrOut'.
+          bool createPrivateMemref = !writesToLiveInOrOut;
+          // Don't create a private memref if 'srcNode' has in edges on
+          // 'memref', or if 'dstNode' has out edges on 'memref'.
+          if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
+              mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
+            createPrivateMemref = false;
           }
 
+          // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
+          if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
+            continue;
+
           // Compute an operation list insertion point for the fused loop
           // nest which preserves dependences.
-          Operation *fusedLoopInsPoint =
+          Operation *insertPointInst =
               mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
-          if (fusedLoopInsPoint == nullptr)
+          if (insertPointInst == nullptr)
             continue;
 
-          // Compute the innermost common loop depth for dstNode
-          // producer-consumer loads/stores.
+          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+          auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
+          // Compute the innermost common loop depth for dstNode loads/stores.
           SmallVector<Operation *, 2> dstMemrefOps;
           for (Operation *op : dstNode->loads)
-            if (producerConsumerMemrefs.count(
-                    cast<AffineReadOpInterface>(op).getMemRef()) > 0)
+            if (cast<AffineReadOpInterface>(op).getMemRef() == memref)
               dstMemrefOps.push_back(op);
           for (Operation *op : dstNode->stores)
-            if (producerConsumerMemrefs.count(
-                    cast<AffineWriteOpInterface>(op).getMemRef()))
+            if (cast<AffineWriteOpInterface>(op).getMemRef() == memref)
               dstMemrefOps.push_back(op);
           unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
 
@@ -1474,7 +1502,7 @@ struct GreedyFusion {
           unsigned maxLegalFusionDepth = 0;
           SmallVector<ComputationSliceState, 8> depthSliceUnions;
           depthSliceUnions.resize(dstLoopDepthTest);
-          FusionStrategy strategy(FusionStrategy::ProducerConsumer);
+          FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref);
           for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
             FusionResult result = mlir::canFuseLoops(
                 srcAffineForOp, dstAffineForOp,
@@ -1484,82 +1512,27 @@ struct GreedyFusion {
               maxLegalFusionDepth = i;
           }
 
-          if (maxLegalFusionDepth == 0) {
-            LLVM_DEBUG(llvm::dbgs()
-                       << "Can't fuse: fusion is not legal at any depth\n");
+          // Skip if fusion is not feasible at any loop depths.
+          if (maxLegalFusionDepth == 0)
             continue;
-          }
 
           // Check if fusion would be profitable. We skip profitability analysis
           // for maximal fusion since we already know the maximal legal depth to
           // fuse.
           unsigned bestDstLoopDepth = maxLegalFusionDepth;
-          if (!maximalFusion) {
-            // Retrieve producer stores from the src loop.
-            SmallVector<Operation *, 2> producerStores;
-            for (Operation *op : srcNode->stores)
-              if (producerConsumerMemrefs.count(
-                      cast<AffineWriteOpInterface>(op).getMemRef()))
-                producerStores.push_back(op);
-
-            // TODO: Suppport multiple producer stores in profitability
-            // analysis. We limit profitability analysis to only scenarios with
-            // a single producer store for now. Note that some multi-store
-            // producer scenarios will still go through profitability analysis
-            // if only one of the stores is involved the producer-consumer
-            // relationship of the candidate loops.
-            assert(producerStores.size() > 0 && "Expected producer store");
-            if (producerStores.size() > 1)
-              LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
-                                         "supported for this case\n");
-            else if (!isFusionProfitable(producerStores[0], producerStores[0],
-                                         dstAffineForOp, depthSliceUnions,
-                                         maxLegalFusionDepth, &bestDstLoopDepth,
-                                         computeToleranceThreshold))
-              continue;
-          }
+          if (!maximalFusion &&
+              !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+                                  depthSliceUnions, maxLegalFusionDepth,
+                                  &bestDstLoopDepth, computeToleranceThreshold))
+            continue;
 
           assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
-          ComputationSliceState &bestSlice =
-              depthSliceUnions[bestDstLoopDepth - 1];
-          assert(!bestSlice.isEmpty() && "Missing slice union for depth");
-
-          // Determine if 'srcId' can be removed after fusion, taking into
-          // account remaining dependences, escaping memrefs and the fusion
-          // insertion point.
-          bool removeSrcNode = canRemoveSrcNodeAfterFusion(
-              srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
-              mdg);
-
-          DenseSet<Value> privateMemrefs;
-          for (Value memref : producerConsumerMemrefs) {
-            // Don't create a private memref if 'srcNode' writes to escaping
-            // memrefs.
-            if (srcEscapingMemRefs.count(memref) > 0)
-              continue;
-
-            // Don't create a private memref if 'srcNode' has in edges on
-            // 'memref' or 'dstNode' has out edges on 'memref'.
-            if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
-                mdg->getOutEdgeCount(dstId, memref) > 0)
-              continue;
-
-            // If 'srcNode' will be removed but it has out edges on 'memref' to
-            // nodes other than 'dstNode', we have to preserve dependences and
-            // cannot create a private memref.
-            if (removeSrcNode &&
-                any_of(mdg->outEdges[srcId], [&](const auto &edge) {
-                  return edge.value == memref && edge.id != dstId;
-                }))
-              continue;
-
-            // Create a private version of this memref.
-            privateMemrefs.insert(memref);
-          }
+          assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+                 "Missing slice union for depth");
 
           // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
-          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
-          dstNodeChanged = true;
+          fuseLoops(srcAffineForOp, dstAffineForOp,
+                    depthSliceUnions[bestDstLoopDepth - 1]);
 
           LLVM_DEBUG(llvm::dbgs()
                      << "Fused src loop " << srcId << " into dst loop " << dstId
@@ -1567,20 +1540,18 @@ struct GreedyFusion {
                      << dstAffineForOp << "\n");
 
           // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
-            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);
+          if (insertPointInst != dstAffineForOp.getOperation())
+            dstAffineForOp->moveBefore(insertPointInst);
 
           // Update edges between 'srcNode' and 'dstNode'.
-          mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
-                           removeSrcNode);
+          mdg->updateEdges(srcNode->id, dstNode->id, memref,
+                           createPrivateMemref);
 
           // Collect slice loop stats.
           LoopNestStateCollector dstForCollector;
           dstForCollector.collect(dstAffineForOp);
-          for (Value memref : privateMemrefs) {
+          if (createPrivateMemref) {
             // Create private memref for 'memref' in 'dstAffineForOp'.
-            // TODO: remove storesForMemref and move the code below to the
-            // loop-if.
             SmallVector<Operation *, 4> storesForMemref;
             for (auto *storeOpInst : dstForCollector.storeOpInsts) {
               if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
@@ -1592,6 +1563,7 @@ struct GreedyFusion {
             auto newMemRef = createPrivateMemRef(
                 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                 fastMemorySpace, localBufSizeThreshold);
+            visitedMemrefs.insert(newMemRef);
             // Create new node in dependence graph for 'newMemRef' alloc op.
             unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
             // Add edge from 'newMemRef' node to dstNode.
@@ -1602,21 +1574,58 @@ struct GreedyFusion {
           LoopNestStateCollector dstLoopCollector;
           dstLoopCollector.collect(dstAffineForOp.getOperation());
 
+          // Add new load ops to current Node load op list 'loads' to continue
+          // fusing based on new operands.
+          for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
+            // NOTE: Change 'loads' to a hash set in case efficiency is an
+            // issue. We still use a vector since it's expected to be small.
+            if (!llvm::is_contained(loads, loadOpInst))
+              loads.push_back(loadOpInst);
+          }
+          // Clear visited memrefs after fusion so that previously visited src
+          // nodes are considered for fusion again in the context of the new
+          // fused node.
+          // TODO: This shouldn't be necessary if we visited candidates in the
+          // dependence graph in post-order or once we fully support multi-store
+          // producers. Currently, in a multi-store producer scenario such as
+          // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing
+          // edges. However, after fusing B+C, A has a single outgoing edge and
+          // can be fused if we revisit it in the context of the new fused B+C
+          // node.
+          visitedMemrefs.clear();
+
           // Clear and add back loads and stores.
           mdg->clearNodeLoadAndStores(dstNode->id);
           mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                          dstLoopCollector.storeOpInsts);
-
-          if (removeSrcNode) {
-            LLVM_DEBUG(llvm::dbgs()
-                       << "Removing src loop " << srcId << " after fusion\n");
-            // srcNode is no longer valid after it is removed from mdg.
-            srcAffineForOp.erase();
-            mdg->removeNode(srcId);
-            srcNode = nullptr;
+          // Remove old src loop nest if it no longer has outgoing dependence
+          // edges, and if it does not write to a memref which escapes the
+          // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been
+          // fused into 'dstNode' and write region of 'dstNode' covers the write
+          // region of 'srcNode', and 'srcNode' has no other users so it is safe
+          // to remove.
+          if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
+            mdg->removeNode(srcNode->id);
+            srcNode->op->erase();
+          } else {
+            // Add remaining users of 'oldMemRef' back on the worklist (if not
+            // already there), as its replacement with a local/private memref
+            // has reduced dependences on 'oldMemRef' which may have created new
+            // fusion opportunities.
+            if (mdg->outEdges.count(srcNode->id) > 0) {
+              SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+                  mdg->outEdges[srcNode->id];
+              for (auto &outEdge : oldOutEdges) {
+                if (outEdge.value == memref &&
+                    worklistSet.count(outEdge.id) == 0) {
+                  worklist.push_back(outEdge.id);
+                  worklistSet.insert(outEdge.id);
+                }
+              }
+            }
           }
         }
-      } while (dstNodeChanged);
+      }
     }
   }
 
@@ -1627,6 +1636,7 @@ struct GreedyFusion {
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
+      worklistSet.erase(dstId);
 
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
@@ -1688,7 +1698,7 @@ struct GreedyFusion {
       SmallVector<ComputationSliceState, 8> depthSliceUnions;
       depthSliceUnions.resize(dstLoopDepthTest);
       unsigned maxLegalFusionDepth = 0;
-      FusionStrategy strategy(memref);
+      FusionStrategy strategy(FusionStrategy::Sibling, memref);
       for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
         FusionResult result = mlir::canFuseLoops(
             sibAffineForOp, dstAffineForOp,
@@ -1702,10 +1712,10 @@ struct GreedyFusion {
       if (maxLegalFusionDepth == 0)
         continue;
 
-      unsigned bestDstLoopDepth = maxLegalFusionDepth;
+      unsigned bestDstLoopDepth = dstLoopDepthTest;
       if (!maximalFusion) {
         // Check if fusion would be profitable.
-        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
+        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
                                 depthSliceUnions, maxLegalFusionDepth,
                                 &bestDstLoopDepth, computeToleranceThreshold))
           continue;

diff  --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 9749a8de2351..9759300f2e42 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -191,8 +191,11 @@ gatherLoadsAndStores(AffineForOp forOp,
 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
 // TODO: Generalize this check for sibling and more generic fusion scenarios.
 // TODO: Support forward slice fusion.
-static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
-                                ArrayRef<Operation *> dstOps) {
+static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps,
+                                FusionStrategy fusionStrategy) {
+  assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer &&
+         "Fusion strategy not supported");
+
   if (dstOps.empty())
     // Expected at least one memory operation.
     // TODO: Revisit this case with a specific example.
@@ -200,14 +203,15 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
 
   // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
   // that they are not considered for analysis.
-  DenseSet<Value> producerConsumerMemrefs;
-  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
+  // TODO: Currently, we pass the producer-consumer memref through
+  // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we
+  // generalize the algorithm.
   SmallVector<Operation *, 4> targetDstOps;
   for (Operation *dstOp : dstOps) {
     auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
     Value memref = loadOp ? loadOp.getMemRef()
                           : cast<AffineWriteOpInterface>(dstOp).getMemRef();
-    if (producerConsumerMemrefs.count(memref) > 0)
+    if (memref == fusionStrategy.memref)
       targetDstOps.push_back(dstOp);
   }
 
@@ -304,10 +308,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   // loop dependences.
   // TODO: Enable this check for sibling and more generic loop fusion
   // strategies.
-  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
+  if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) {
     // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
     assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
-    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
+    if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) {
       LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
       return FusionResult::FailFusionDependence;
     }
@@ -320,7 +324,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   // Filter out ops in 'opsA' to compute the slice union based on the
   // assumptions made by the fusion strategy.
   SmallVector<Operation *, 4> strategyOpsA;
-  switch (fusionStrategy.getStrategy()) {
+  switch (fusionStrategy.strategy) {
   case FusionStrategy::Generic:
     // Generic fusion. Take into account all the memory operations to compute
     // the slice union.
@@ -328,9 +332,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     break;
   case FusionStrategy::ProducerConsumer:
     // Producer-consumer fusion (AffineLoopFusion pass) only takes into
-    // account stores in 'srcForOp' to compute the slice union.
+    // account stores to 'memref' in 'srcForOp' to compute the slice union.
     for (Operation *op : opsA) {
-      if (isa<AffineWriteOpInterface>(op))
+      auto store = dyn_cast<AffineWriteOpInterface>(op);
+      if (store && store.getMemRef() == fusionStrategy.memref)
         strategyOpsA.push_back(op);
     }
     break;
@@ -339,7 +344,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     // to 'memref' in 'srcForOp' to compute the slice union.
     for (Operation *op : opsA) {
       auto load = dyn_cast<AffineReadOpInterface>(op);
-      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
+      if (load && load.getMemRef() == fusionStrategy.memref)
         strategyOpsA.push_back(op);
     }
     break;
@@ -623,23 +628,3 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
   return true;
 }
-
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between write ops in 'srcOps' and read ops in
-/// 'dstOps'.
-void mlir::gatherProducerConsumerMemrefs(
-    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
-    DenseSet<Value> &producerConsumerMemrefs) {
-  // Gather memrefs from stores in 'srcOps'.
-  DenseSet<Value> srcStoreMemRefs;
-  for (Operation *op : srcOps)
-    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
-      srcStoreMemRefs.insert(storeOp.getMemRef());
-
-  // Compute the intersection between memrefs from stores in 'srcOps' and
-  // memrefs from loads in 'dstOps'.
-  for (Operation *op : dstOps)
-    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
-      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
-        producerConsumerMemrefs.insert(loadOp.getMemRef());
-}

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index c1bccea4c9f5..a23f0e2ee430 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -364,8 +364,8 @@ func @should_fuse_and_move_to_preserve_war_dep() {
 
 // -----
 
-// CHECK-LABEL: func @should_fuse_if_top_level_access() {
-func @should_fuse_if_top_level_access() {
+// CHECK-LABEL: func @should_fuse_with_private_memref_if_top_level_access() {
+func @should_fuse_with_private_memref_if_top_level_access() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
 
@@ -378,45 +378,14 @@ func @should_fuse_if_top_level_access() {
 
   %c0 = constant 4 : index
   %v1 = affine.load %m[%c0] : memref<10xf32>
-  // Top-level load to '%m' should prevent creating a private memref but
-  // loop nests should be fused and '%i0' should be removed.
-  // CHECK:      %[[m:.*]] = alloc() : memref<10xf32>
-  // CHECK-NOT:  alloc
-
-  // CHECK:      affine.for %[[i1:.*]] = 0 to 10 {
-  // CHECK-NEXT:   affine.store %{{.*}}, %[[m]][%[[i1]]] : memref<10xf32>
-  // CHECK-NEXT:   affine.load %[[m]][%[[i1]]] : memref<10xf32>
-  // CHECK-NEXT: }
-  // CHECK:      affine.load %[[m]][%{{.*}}] : memref<10xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @should_fuse_but_not_remove_src() {
-func @should_fuse_but_not_remove_src() {
-  %m = alloc() : memref<100xf32>
-  %cf7 = constant 7.0 : f32
-
-  affine.for %i0 = 0 to 100 {
-    affine.store %cf7, %m[%i0] : memref<100xf32>
-  }
-  affine.for %i1 = 0 to 17 {
-    %v0 = affine.load %m[%i1] : memref<100xf32>
-  }
-  %v1 = affine.load %m[99] : memref<100xf32>
-
-  // Loop '%i0' and '%i1' should be fused but '%i0' shouldn't be removed to
-  // preserve the dependence with the top-level access.
-  // CHECK:      affine.for %{{.*}} = 0 to 100 {
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<100xf32>
+  // Top-level load to '%{{.*}}' should prevent fusion.
+  // CHECK:      affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT: }
-  // CHECK-NEXT: affine.for %{{.*}} = 0 to 17 {
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
   // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: }
-  // CHECK-NEXT: affine.load %{{.*}}[99] : memref<100xf32>
-  // CHECK-NEXT: return
   return
 }
 
@@ -1141,8 +1110,8 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() {
 
 // -----
 
-// CHECK-LABEL: func @should_fuse_live_out_arg_but_preserve_src_loop(%{{.*}}: memref<10xf32>) {
-func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) {
+// CHECK-LABEL: func @should_not_fuse_live_out_arg(%{{.*}}: memref<10xf32>) {
+func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) {
   %cf7 = constant 7.0 : f32
 
   affine.for %i0 = 0 to 10 {
@@ -1160,7 +1129,6 @@ func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) {
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 9 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return
@@ -1192,8 +1160,8 @@ func @should_fuse_live_out_arg(%arg0: memref<10xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32>
-func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> {
+// CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32>
+func @should_not_fuse_escaping_memref() -> memref<10xf32> {
   %cf7 = constant 7.0 : f32
   %m = alloc() : memref<10xf32>
   affine.for %i0 = 0 to 10 {
@@ -1202,21 +1170,19 @@ func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> {
   affine.for %i1 = 0 to 9 {
     %v0 = affine.load %m[%i1] : memref<10xf32>
   }
-  // This tests that the loop nest '%i0' should not be removed after fusion
-  // because it writes to memref '%m', which is returned by the function, and
-  // the '%i1' memory region does not cover '%i0' memory region.
-
+  // This tests that the loop nest '%{{.*}}' should not be removed after fusion
+  // because it writes to memref '%{{.*}}' which is returned by the function.
   // CHECK-DAG:   alloc() : memref<10xf32>
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 9 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return %{{.*}} : memref<10xf32>
   return %m : memref<10xf32>
 }
+
 // -----
 
 // This should fuse with the %in becoming a 1x1x1.
@@ -1264,7 +1230,7 @@ func @R3_to_R2_reshape() {
 
 // -----
 
-func @should_fuse_multi_output_producer() {
+func @should_not_fuse_multi_output_producer() {
   %a = alloc() : memref<10xf32>
   %b = alloc() : memref<10xf32>
 
@@ -1280,10 +1246,12 @@ func @should_fuse_multi_output_producer() {
   }
 
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return
   return
@@ -1536,8 +1504,8 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>
 
 // -----
 
-// CHECK-LABEL: func @should_fuse_only_two_loops_and_remove_producer() {
-func @should_fuse_only_two_loops_and_remove_producer() {
+// CHECK-LABEL: func @should_fuse_after_private_memref_creation() {
+func @should_fuse_after_private_memref_creation() {
   %a = alloc() : memref<10xf32>
   %b = alloc() : memref<10xf32>
 
@@ -1557,21 +1525,18 @@ func @should_fuse_only_two_loops_and_remove_producer() {
 
   // On the first visit to '%i2', the fusion algorithm can not fuse loop nest
   // '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on
-  // '%i1'. Then, '%i0' is fused into '%i1' and no private memref is created for
-  // memref '%a' to be able to remove '%i0' and still preserve the depencence on
-  // '%a' with '%i2'.
-  // TODO: Alternatively, we could fuse '%i0' into '%i1' with a private memref,
-  // the dependence between '%i0' and '%i1' on memref '%a' would no longer exist,
-  // and '%i0' could be fused into '%i2' as well. Note that this approach would
-  // duplicate the computation in loop nest '%i0' to loop nests '%i1' and '%i2',
-  // which would limit its profitability.
+  // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a
+  // private memref, the dependence between '%i0' and '%i1' on memref '%a' no
+  // longer exists, so '%i0' can now be fused into '%i2'.
+
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
   // CHECK-NEXT:  }
   // CHECK-NEXT:   return
@@ -2255,7 +2220,7 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10
     }
   }
 
-  // CHECK:        affine.for %{{.*}} = 0 to 1024 {
+  // CHECK:  affine.for %{{.*}} = 0 to 1024 {
   // CHECK-NEXT:     affine.for %{{.*}} = 0 to 1024 {
   // CHECK-NEXT:       affine.for %{{.*}} = 0 to 1024 {
   // CHECK-NEXT:         affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
@@ -2346,8 +2311,8 @@ func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref
   }
   // CHECK:      affine.for %[[i0:.*]] = 0 to 10 {
   // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[%[[i0]]] : memref<10xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   return
@@ -2408,11 +2373,12 @@ func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x
 
 // -----
 
-// Verify that 'fuseProducerConsumerNodes' fuse a producer loop with a store
-// that has multiple outgoing edges.
+// Verify that 'fuseProducerConsumerNodes' doesn't fuse a producer loop with
+// a store that has multiple outgoing edges. Sibling loop fusion should not fuse
+// any of these loops due to dependencies on external memref '%a'.
 
-// CHECK-LABEL: func @should_fuse_multi_outgoing_edge_store_producer
-func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) {
+// CHECK-LABEL: func @should_not_fuse_multi_outgoing_edge_store_producer1
+func @should_not_fuse_multi_outgoing_edge_store_producer1(%a : memref<1xf32>) {
   %cst = constant 0.000000e+00 : f32
   affine.for %arg0 = 0 to 1 {
     affine.store %cst, %a[%arg0] : memref<1xf32>
@@ -2425,12 +2391,9 @@ func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) {
   affine.for %arg0 = 0 to 1 {
     %0 = affine.load %a[%arg0] : memref<1xf32>
   }
-  // CHECK:      affine.for %{{.*}} = 0 to 1 {
-  // CHECK-NEXT:   affine.store
-  // CHECK-NEXT:   affine.load
-  // CHECK-NEXT:   affine.load
-  // CHECK-NEXT: }
-
+  // CHECK: affine.for %{{.*}} = 0 to 1
+  // CHECK: affine.for %{{.*}} = 0 to 1
+  // CHECK: affine.for %{{.*}} = 0 to 1
   return
 }
 
@@ -2700,109 +2663,3 @@ func @fuse_minor_affine_map(%in: memref<128xf32>, %out: memref<20x512xf32>) {
 // MAXIMAL:       affine.for
 // MAXIMAL-NEXT:    affine.for
 // MAXIMAL-NOT:   affine.for
-// MAXIMAL:       return
-
-// -----
-
-// CHECK-LABEL: func @should_fuse_multi_store_producer_and_privatize_memfefs
-func @should_fuse_multi_store_producer_and_privatize_memfefs() {
-  %a = alloc() : memref<10xf32>
-  %b = alloc() : memref<10xf32>
-  %c = alloc() : memref<10xf32>
-  %cst = constant 0.000000e+00 : f32
-  affine.for %arg0 = 0 to 10 {
-    affine.store %cst, %a[%arg0] : memref<10xf32>
-    affine.store %cst, %b[%arg0] : memref<10xf32>
-    affine.store %cst, %c[%arg0] : memref<10xf32>
-    %0 = affine.load %c[%arg0] : memref<10xf32>
-  }
-
-  affine.for %arg0 = 0 to 10 {
-    %0 = affine.load %a[%arg0] : memref<10xf32>
-  }
-
-  affine.for %arg0 = 0 to 10 {
-    %0 = affine.load %b[%arg0] : memref<10xf32>
-  }
-
-	// All the memrefs should be privatized except '%c', which is not involved in
-  // the producer-consumer fusion.
-  // CHECK:      affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
-  // CHECK-NEXT: }
-
-  return
-}
-
-// -----
-
-func @should_fuse_multi_store_producer_with_scaping_memrefs_and_remove_src(
-    %a : memref<10xf32>, %b : memref<10xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  affine.for %i0 = 0 to 10 {
-    affine.store %cst, %a[%i0] : memref<10xf32>
-    affine.store %cst, %b[%i0] : memref<10xf32>
-  }
-
-  affine.for %i1 = 0 to 10 {
-    %0 = affine.load %a[%i1] : memref<10xf32>
-  }
-
-  affine.for %i2 = 0 to 10 {
-    %0 = affine.load %b[%i2] : memref<10xf32>
-  }
-
-	// Producer loop '%i0' should be removed after fusion since fusion is maximal.
-  // No memref should be privatized since they escape the function.
-  // CHECK:       affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:  }
-  // CHECK-NOT:   affine.for
-
-  return
-}
-
-// -----
-
-func @should_fuse_multi_store_producer_with_scaping_memrefs_and_preserve_src(
-    %a : memref<10xf32>, %b : memref<10xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  affine.for %i0 = 0 to 10 {
-    affine.store %cst, %a[%i0] : memref<10xf32>
-    affine.store %cst, %b[%i0] : memref<10xf32>
-  }
-
-  affine.for %i1 = 0 to 5 {
-    %0 = affine.load %a[%i1] : memref<10xf32>
-  }
-
-  affine.for %i2 = 0 to 10 {
-    %0 = affine.load %b[%i2] : memref<10xf32>
-  }
-
-	// Loops '%i0' and '%i2' should be fused first and '%i0' should be removed
-  // since fusion is maximal. Then the fused loop and '%i1' should be fused
-  // and the fused loop shouldn't be removed since fusion is not maximal.
-  // CHECK:       affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:  }
-  // CHECK:       affine.for %{{.*}} = 0 to 5 {
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT:  }
-  // CHECK-NOT:   affine.for
-
-  return
-}


        


More information about the llvm-branch-commits mailing list