[Mlir-commits] [mlir] 0caa82e - Revert "[mlir][Linalg] Fuse sequence of Linalg operation (on buffers)"

Mikhail Goncharov llvmlistbot at llvm.org
Fri Nov 20 04:13:12 PST 2020


Author: Mikhail Goncharov
Date: 2020-11-20T13:12:54+01:00
New Revision: 0caa82e2ac53b2ff475531086dfe648fb2d6158a

URL: https://github.com/llvm/llvm-project/commit/0caa82e2ac53b2ff475531086dfe648fb2d6158a
DIFF: https://github.com/llvm/llvm-project/commit/0caa82e2ac53b2ff475531086dfe648fb2d6158a.diff

LOG: Revert "[mlir][Linalg] Fuse sequence of Linalg operation (on buffers)"

This reverts commit f8284d21a8e294d58a0acd4b8b2e906d7a9f110c.

Revert "[mlir][Linalg] NFC: Expose some utility functions used for promotion."

This reverts commit 0c59f51592ef5c014352994369f5216c6376fae1.

Revert "Remove unused isZero function"

This reverts commit 0f9f0a4046e11c2b4c130640f343e3b2b5db08c1.

Change f8284d21 led to multiple failures in IREE compilation.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/fusion-pattern.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/test/Dialect/Linalg/fusion-sequence.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 87ff2a97d93f..8d531a1e343a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,6 +37,14 @@ struct TiledLinalgOp {
   SmallVector<Value, 4> tensorResults;
 };
 
+struct TiledAndFusedLinalgOps {
+  LinalgOp op;
+  SmallVector<LinalgOp, 1> fusedProducers;
+  SmallVector<LinalgOp, 1> originalProducers;
+  SmallVector<Operation *, 4> fusedLoops;
+  SmallVector<Operation *, 4> unfusedLoops;
+};
+
 /// Populates patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -65,11 +73,14 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
 Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                      const LinalgTilingOptions &options);
 
-/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
-/// proceeds as follows:
-/// - Find outer parallel loops in these ops that can be fused.
-/// - Tile fusable outer parallel loops of the last operation in the sequence.
-/// - Fuse the remaining operations with the tiled operation
+/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
+/// three steps
+/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
+///   + fuse loops).
+/// - Tile just these loops of the consumer (root operation) and fuse with
+///   the producer.
+/// - Tile again the tiled consumer operation produced above to do rest of
+///   the tiling specified by the `tilingOptions`.
 ///
 /// For example, consider the sequence of matmul below
 ///
@@ -96,39 +107,36 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
 ///     : memref<256x32xf32> to memref<16x32xf32, #map0>
 ///   %3 = subview %arg1[0, 0] [32, 32] [1, 1]
 ///     : memref<32x32xf32> to memref<32x32xf32, #map1>
-///   %4 = subview %arg3[0, 0] [32, 32] [1, 1]
-///     : memref<32x32xf32> to memref<32x32xf32, #map1>
 ///   linalg.matmul
 ///     ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
 ///     outs(%0 : memref<16x32xf32, #map0>)
-///   linalg.matmul
-///     ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
-///     outs(%1 : memref<16x8xf32, #map0>)
+///   scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
+///   scf.for %arg7 = %c0 to %c32 step %c4 {
+///     %4 = subview %0[0, %arg7] [16, 4] [1, 1]
+///       : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
+///     %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
+///       : memref<32x32xf32> to memref<4x8xf32, #map0>
+///     %6 = subview %1[0, %arg6] [16, 8] [1, 1]
+///       : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
+///     linalg.matmul
+///       ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
+///       outs(%6 : memref<16x8xf32, #map0>)
+///     }
+///     scf.yield
+///   }
+///   scf.yield
 /// }
 ///
-/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
-/// size of the former should be same as size of the latter. Based on how
-/// tile+fuse is implemented, the fused loops are generated based on the last
-/// operation in the sequence. For example, the tile sizes for the fused loops
-/// is obtained from `tilingOptions.back()`. The following tiling options are
-/// handled differently in tile+fuse (compared to tile only)
+/// The following tiling options are handled differently in tile+fuse (compared
+/// to tile only)
 /// - Interchange of the tiling loops is not supported right now.
-/// - Only the fused loops are distributed.
-struct TiledAndFusedLinalgOps {
-  /// Operation obtained by tiling the last operation in sequence of `ops`
-  /// passed to `tileAndFuseLinalgOps`.
-  LinalgOp op;
-  /// The dimension of the loops that are fused.
-  std::set<unsigned> fusedLoopDims;
-  /// The generated fused operations (created within the fused loops).
-  SmallVector<LinalgOp, 1> fusedProducers;
-  /// The fused loop generated.
-  SmallVector<Operation *, 4> fusedLoops;
-};
+/// - Distribution is only done for the tile+fuse loops. The tiled loops
+///   generated by the second tiling is not distributed.
 Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
                      const LinalgDependenceGraph &dependenceGraph,
-                     const LinalgTilingOptions &tilingOptions);
+                     const LinalgTilingOptions &tilingOptions,
+                     const LinalgFusionOptions &fusionOptions);
 
 /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
 /// This is an in-place transformation controlled by `interchangeVector`.
@@ -234,20 +242,6 @@ struct LinalgPromotionOptions {
   }
 };
 
-/// Creates a new buffer using the `allocationFn` provided. The size of this
-/// buffer is the smallest constant bounding size along each dimension that can
-/// be computed for the size of the result of `subView`. Returns the allocated
-/// buffer as `fullLocalView` and the view that matches the size of the result
-/// of subview operation as `partialLocalView`.
-struct PromotionInfo {
-  Value fullLocalView;
-  Value partialLocalView;
-};
-Optional<PromotionInfo>
-promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
-                          AllocBufferCallbackFn allocationFn,
-                          OperationFolder *folder = nullptr);
-
 /// Promotes the `subViews` into a new buffer allocated at the insertion point
 /// `b`. Promotion occurs in 3 steps:
 ///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1eaf8b0e709c..f5669e383368 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
-#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
 
 using mlir::edsc::intrinsics::AffineIndexedValue;
@@ -83,13 +82,6 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
 bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                    Value consumedView, LinalgOp producer);
 
-using FusableOpDependencesTy = llvm::MapVector<
-    Operation *,
-    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
-FusableOpDependencesTy
-findAllFusableDependences(ArrayRef<LinalgOp> ops,
-                          const LinalgDependenceGraph &dependenceGraph);
-
 /// Fuses producer into consumer if the producer is structurally feasible and
 /// the fusion would not violate dependencies.
 /// Implements the fusion part of the "tileAndFuse on buffers"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 45a68fcba4a2..969bea4a4549 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -178,9 +178,6 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
     Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
     for (auto en2 : llvm::enumerate(map.getResults())) {
-      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
-      if (!dimExpr)
-        continue;
       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
@@ -193,18 +190,49 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
-/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
-/// provides the loop range information for the fused loops. The rest are
-/// obtained from the producer itself, since they are not tiled + fused.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
-                     const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
+/// 2 cases:
+///   1. Buffer case: `producerIdx` is the index of the buffer in
+///      `producer.getOutputBuffers()`.
+///   2. Tensor case: `producerIdx` is the index of the tensor in
+///      `producer.getResults()`.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+                     LinalgOp consumer, unsigned consumerIdx) {
+  Operation *shapeProducingOp =
+      consumer.getShapedOperand(consumerIdx).getDefiningOp();
+  assert((isa<SubViewOp>(shapeProducingOp) ||
+          isa<SubTensorOp>(shapeProducingOp)) &&
+         "SubviewOp or SubTensorOp expected");
+
+  // loopToOperandRangesMaps are permutations-only by construction:
+  //   we can always identify a data dimension with a (at least one) loop
+  //   dimension.
+  // TODO: extend this with range inference.
+  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
+  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+                          << ", producer map: " << producerMap << "\n");
 
   unsigned nPar = producer.getNumParallelLoops();
   unsigned nRed = producer.getNumReductionLoops();
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
-  for (auto fusedLoops : fusedLoopsAndRanges)
-    loopRanges[fusedLoops.first] = fusedLoops.second;
+
+  // Iterate over dimensions identified by the producer map for `producerIdx`.
+  // This defines a subset of the loop ranges that we need to complete later.
+  auto loc = consumer.getLoc();
+  for (auto en : llvm::enumerate(producerMap.getResults())) {
+    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+    loopRanges[posInProducerLoop] =
+        isa<SubViewOp>(shapeProducingOp)
+            ? cast<SubViewOp>(shapeProducingOp)
+                  .getOrCreateRanges(b, loc)[en.index()]
+            : cast<SubTensorOp>(shapeProducingOp)
+                  .getOrCreateRanges(b, loc)[en.index()];
+  }
 
   // Iterate over all dimensions. For the dimensions not identified by the
   // producer map for `producerIdx`, we need to explicitly compute the shape
@@ -222,45 +250,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
     }
   }
 
-  return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
-}
-
-/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
-/// expected to be defined by a subview op or a subtensor op.
-static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
-                                      Value shapedOperand, unsigned dim) {
-  Operation *shapeProducingOp = shapedOperand.getDefiningOp();
-  if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
-    return subViewOp.getOrCreateRanges(b, loc)[dim];
-  if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
-    return subTensorOp.getOrCreateRanges(b, loc)[dim];
-  llvm_unreachable("SubviewOp or SubTensorOp expected");
-}
-
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-///   1. Buffer case: `producerIdx` is the index of the buffer in
-///      `producer.getOutputBuffers()`.
-///   2. Tensor case: `producerIdx` is the index of the tensor in
-///      `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
-                     LinalgOp consumer, unsigned consumerIdx) {
-  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
-  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
-                          << ", producer map: " << producerMap << "\n");
-  DenseMap<unsigned, Range> fusedLoopsAndRanges;
-  Location loc = consumer.getLoc();
-  Value shapedOperand = consumer.getShapedOperand(consumerIdx);
-  for (auto en : llvm::enumerate(producerMap.getResults())) {
-    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    fusedLoopsAndRanges[posInProducerLoop] =
-        getRangeFromOperandShape(b, loc, shapedOperand, en.index());
-  }
-  return fuse(b, producer, fusedLoopsAndRanges);
+  return cloneWithLoopRanges(b, loc, producer, loopRanges);
 }
 
 // Encode structural fusion safety preconditions.
@@ -531,68 +521,9 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
   return getProjectedMap(map, projectedDims);
 }
 
-/// Returns the mapping from iterations in the consumer that write to the same
-/// location as the iterations in the producer. To do so use
-/// - indexing map of the fused view in the consumer : consumerIndexMap
-/// - indexing map of the fused view in the producer : producerIndexMap
-///     consumerLoopToProducerLoop =
-///       inverse(producerIndexMap).compose(consumerIndexMap)
-static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
-    LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
-  auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
-  AffineMap producerIndexingMap =
-      producer.getIndexingMap(dependence.dependentOpView.operandIndex);
-  auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
-  AffineMap consumerIndexingMap =
-      consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
-
-  AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
-      producer.iterator_types().getValue(), producerIndexingMap);
-  if (!prunedProducerIndexingMap.isPermutation())
-    return None;
-
-  if (consumerIndexingMap.getNumResults() !=
-      prunedProducerIndexingMap.getNumResults())
-    return None;
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "\t producerMap : ";
-    producerIndexingMap.print(llvm::dbgs());
-    llvm::dbgs() << "  pruned : ";
-    prunedProducerIndexingMap.print(llvm::dbgs());
-    llvm::dbgs() << "\n";
-    llvm::dbgs() << "\t consumerMap : ";
-    consumerIndexingMap.print(llvm::dbgs());
-    llvm::dbgs() << "\n";
-  });
-
-  AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
-  if (!invProducerIndexMap)
-    return None;
-
-  return invProducerIndexMap.compose(consumerIndexingMap);
-}
-
-/// Given a projected permutation `map`, returns true if the map changes the
-/// order in which the fused loop dimension appear.
-static bool doesTransposeAccess(AffineMap map,
-                                const std::set<unsigned> &fusableLoops) {
-  Optional<unsigned> lastFusableLoop;
-  for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
-         return expr.cast<AffineDimExpr>().getPosition();
-       })) {
-    if (!fusableLoops.count(pos))
-      continue;
-    if (!lastFusableLoop) {
-      lastFusableLoop = pos;
-      continue;
-    }
-    if (pos <= lastFusableLoop.getValue())
-      return true;
-    lastFusableLoop = pos;
-  }
-  return false;
-}
+using FusableOpDependencesTy = llvm::MapVector<
+    Operation *,
+    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
 
 /// Returns the positions of the loop in `op` that can be tiled based on the
 /// operations that are to be fused with it. For example, in a
@@ -607,7 +538,13 @@ static bool doesTransposeAccess(AffineMap map,
 /// 2. Of the parallel loops only some can be fused. Only those loops can be
 ///    fused such where the fusable loops iteration space only touches one tile
 ///    of the fused operation. This is because the producer (which is writing
-///    the fused subview) has update semantics.
+///    the fused subview) has update semantics. To compute this,
+///    a. Find the mapping from iterations in the consumer that write to the
+///       same location as the iterations in the producer. To do so use
+///       - indexing map of the fused view in the consumer : consumerIndexMap
+///       - indexing map of the fused view in the producer : producerIndexMap
+///       consumerLoopToProducerLoop =
+///         inverse(producerIndexMap).compose(consumerIndexMap)
 ///
 /// Since an inverse computation is needed, we need to consider the projection
 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
@@ -645,9 +582,8 @@ static bool doesTransposeAccess(AffineMap map,
 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
 ///   Fused dimensions : j
 static std::set<unsigned>
-collectFusableLoops(ArrayRef<LinalgOp> ops,
-                    const FusableOpDependencesTy &fusableDependences) {
-  assert(!ops.empty());
+collectTileAndFuseLoops(LinalgOp op,
+                        const FusableOpDependencesTy &fusableDependences) {
   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
     return linalgOp.iterator_types()
         .getValue()
@@ -658,245 +594,289 @@ collectFusableLoops(ArrayRef<LinalgOp> ops,
         .size();
   };
 
-  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
-  for (auto op : ops.drop_back()) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Op : ";
+    op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n";
+  });
+
+  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
+  for (auto dependence : fusableDependences) {
+    linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
     numOuterParallelLoops =
-        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
+        std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
   }
 
   std::set<unsigned> fusableLoops;
   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
   fusableLoops.insert(range.begin(), range.end());
-
-  for (auto op : reverse(ops)) {
-    for (auto dependence : fusableDependences.lookup(op)) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "\t fusable :";
-        for (unsigned i : fusableLoops)
-          llvm::dbgs() << " " << i;
-        llvm::dbgs() << "\n";
-      });
-
-      Optional<AffineMap> consumerLoopToProducerLoop =
-          getConsumerLoopToProducerLoopMap(dependence);
-      if (!consumerLoopToProducerLoop) {
-        op.emitRemark("failed to get map from consumer loop to producer loop");
-        return {};
-      }
-      // todo: This condition is only an implementation limitation. When fusing
-      // the operation, if the accesses in the producer/consumer are transposes
-      // of each other, the loop bounds for the tiled producer can be
-      // manipulated accordingly. This requires some additional bookkeeping in
-      // the implementation of tile+fuse that is defered to later.
-      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
-        op.emitRemark("unhandled fusion when fusion requires permutation");
-        return {};
-      }
-
-      std::set<unsigned> candidates;
-      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
-        unsigned position = expr.cast<AffineDimExpr>().getPosition();
-        if (fusableLoops.count(position))
-          candidates.insert(position);
-      }
-      LLVM_DEBUG({
-        llvm::dbgs() << "\t candidates :";
-        for (unsigned i : candidates)
-          llvm::dbgs() << " " << i;
-        llvm::dbgs() << "\n";
-      });
-      if (candidates.empty())
-        return {};
-      std::swap(candidates, fusableLoops);
+  for (auto dependence : fusableDependences) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t fusable :";
+      for (unsigned i : fusableLoops)
+        llvm::dbgs() << " " << i;
+      llvm::dbgs() << "\n";
+    });
+    linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
+
+    assert(!dependence.second.empty() &&
+           "unexpected producer but not dependences");
+    AffineMap producerIndexingMap = producer.getIndexingMap(
+        dependence.second.front().dependentOpView.operandIndex);
+    AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
+        producer.iterator_types().getValue(), producerIndexingMap);
+    if (!prunedProducerIndexingMap.isPermutation())
+      return {};
+
+    AffineMap consumerIndexingMap = op.getIndexingMap(
+        dependence.second.front().indexingOpView.operandIndex);
+    if (consumerIndexingMap.getNumResults() !=
+        prunedProducerIndexingMap.getNumResults())
+      return {};
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t producerMap : ";
+      producerIndexingMap.print(llvm::dbgs());
+      llvm::dbgs() << "  pruned : ";
+      prunedProducerIndexingMap.print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+      llvm::dbgs() << "\t consumerMap : ";
+      consumerIndexingMap.print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
+
+    AffineMap invProducerIndexMap =
+        inversePermutation(prunedProducerIndexingMap);
+    if (!invProducerIndexMap)
+      return {};
+
+    AffineMap consumerLoopToProducerLoop =
+        invProducerIndexMap.compose(consumerIndexingMap);
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
+      consumerLoopToProducerLoop.print(llvm::dbgs());
+    });
+
+    std::set<unsigned> candidates;
+    for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
+      AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
+      if (!dimExpr)
+        continue;
+      unsigned position = dimExpr.getPosition();
+      if (fusableLoops.count(position))
+        candidates.insert(position);
     }
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t candidates :";
+      for (unsigned i : candidates)
+        llvm::dbgs() << " " << i;
+      llvm::dbgs() << "\n";
+    });
+    if (candidates.empty())
+      return {};
+    std::swap(candidates, fusableLoops);
   }
 
   return fusableLoops;
 }
 
-/// Find all dependences that are fusable.
-FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
-    ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
+/// Find all dependences that are to be fusable.
+static FusableOpDependencesTy
+findAllFusableDependences(LinalgOp op,
+                          const LinalgDependenceGraph &dependenceGraph,
+                          const LinalgFusionOptions &fusionOptions) {
   FusableOpDependencesTy fusableDependences;
   // TODO: Currently fusion would not be legal if the fusable dependence is to
   // the same producer but different indexing map in the consumer. Fix this, but
   // in the meanwhile disallow such a fusion.
   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
-  for (LinalgOp op : reverse(ops)) {
-    for (auto operandIndex :
-         llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
-      Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence =
-              findFusableProducer(op, operandIndex, dependenceGraph);
-      if (!fusableDependence)
-        continue;
-      LinalgOp producerOp =
-          cast<LinalgOp>(fusableDependence->dependentOpView.op);
-      // Do not fuse dependences that are to operations not in the same basic
-      // block. This avoid moving fused operations across loops that might
-      // themselves carry dependency making the fusion illegal.
-      if (producerOp.getOperation()->getBlock() !=
-          op.getOperation()->getBlock()) {
-        op.emitRemark("unhandled fusion of ops in different basic blocks");
-        return FusableOpDependencesTy{};
-      }
-      // Make sure that the indexing map of the view used for fusion in the
-      // producer is a projected permutation.
-      unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
-      AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
-      if (!producerMap.isProjectedPermutation()) {
-        op.emitRemark(
-            "unhandled non permutation indexing map for fused view in "
-            "producer for operand at index ")
-            << operandIndex;
-        return FusableOpDependencesTy{};
-      }
-
-      unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
-      AffineMap consumerMap = op.getIndexingMap(consumerIdx);
-      if (!consumerMap.isProjectedPermutation()) {
-        op.emitRemark(
-            "unhandled case where indexing map for fused view in the consumer "
-            "is "
-            "not a projected permuration while fusing at index ")
-            << operandIndex;
-        return FusableOpDependencesTy{};
-      }
+  for (auto operandIndex : fusionOptions.indicesToFuse) {
+    auto fusableDependence =
+        findFusableProducer(op, operandIndex, dependenceGraph);
+    if (!fusableDependence)
+      return FusableOpDependencesTy{};
+    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+    // Do not fuse dependences that are to operations not in the same basic
+    // block. This avoid moving fused operations across loops that might
+    // themselves carry dependency making the fusion illegal.
+    if (producerOp.getOperation()->getBlock() !=
+        op.getOperation()->getBlock()) {
+      op.emitRemark("unhandled fusion of ops in different basic blocks");
+      return FusableOpDependencesTy{};
+    }
+    // Make sure that the indexing map of the view used for fusion in the
+    // producer is a projected permutation.
+    unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
+    AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+    if (!producerMap.isProjectedPermutation()) {
+      op.emitRemark("unhandled non permutation indexing map for fused view in "
+                    "producer for operand at index ")
+          << operandIndex;
+      return FusableOpDependencesTy{};
+    }
 
-      // Check if the producer is already a fusion candidate. Cannot fuse this
-      // dependence if it has a different indexing map when used in the
-      // consumer.
-      if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
-          fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
-        op.emitRemark(
-            "unhandled fusion to the same producer but with different "
-            "indexing maps");
-        return FusableOpDependencesTy{};
-      }
-      fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+    unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
+    AffineMap consumerMap = op.getIndexingMap(consumerIdx);
+    if (!consumerMap.isProjectedPermutation()) {
+      op.emitRemark(
+          "unhandled case where indexing map for fused view in the consumer is "
+          "not a projected permutation while fusing at index ")
+          << operandIndex;
+      return FusableOpDependencesTy{};
+    }
 
-      fusableDependences[producerOp.getOperation()].push_back(
-          *fusableDependence);
+    // Check if the producer is already a fusion candidate. Cannot fuse this
+    // dependence if it has a different indexing map when used in the consumer.
+    if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
+        fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
+      op.emitRemark("unhandled fusion to the same producer but with different "
+                    "indexing maps");
+      return FusableOpDependencesTy{};
     }
+    fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+
+    fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
   }
   return fusableDependences;
 }
 
-/// Tile the fused loops in the root operation, by setting the tile sizes for
-/// all other loops to zero (those will be tiled later).
-static Optional<TiledLinalgOp> tileRootOperation(
-    OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
-    const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
-  SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
-  auto zero = std_constant_index(0);
-  for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
-    if (!fusedLoops.count(i))
-      tileSizes[i] = zero;
-  LinalgTilingOptions tileFusedLoopsOptions = options;
-  tileFusedLoopsOptions.setTileSizes(tileSizes);
-  return tileLinalgOp(builder, op, tileFusedLoopsOptions);
-}
-
-/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
-/// to be a tiled operation such that it is valid to fuse all operations in
-/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
-/// `tiledOp`.
-static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
-               ArrayRef<LinalgOp> fusionCandidates,
-               const FusableOpDependencesTy &fusableDependences,
-               const std::set<unsigned> &fusedLoops) {
-  OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPoint(tiledOp);
-  DenseMap<unsigned, Range> fusedLoopsAndRanges;
-  for (unsigned loop : fusedLoops) {
-    ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop);
-    fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
-        builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
-  }
-
-  SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
-  for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
-    LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
-    fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
-    builder.setInsertionPoint(fusedOp);
-  }
-  return fusedOps;
+static bool isZero(Value v) {
+  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
+    return cst.getValue() == 0;
+  return false;
 }
 
 template <typename LoopType>
 static Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
                          const LinalgDependenceGraph &dependenceGraph,
-                         const LinalgTilingOptions &tilingOptions) {
-  if (ops.empty())
-    return llvm::None;
-  LinalgOp rootOp = ops.back();
-  for (auto op : enumerate(ops)) {
-    // TODO: Nothing in the fusion of sequence of ops is specific to
-    // buffers. This check can be removed after it is tested on tensors.
-    LinalgOp linalgOp = op.value();
-    if (!linalgOp.hasBufferSemantics()) {
-      linalgOp.emitError("tile and fuse only tested for buffer operation");
-      return llvm::None;
-    }
-  }
-  // TODO: Support interchange with tile + fuse. This might actually help do
-  // better fusion.
+                         const LinalgTilingOptions &tilingOptions,
+                         const LinalgFusionOptions &fusionOptions) {
+  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+  // Some of the tiling options might not be supportable with tile and fuse.
+  // TODO: Support interchange with tile + fuse.
   if (!tilingOptions.interchangeVector.empty()) {
-    rootOp.emitError("unable to handle tile and fuse with interchange");
+    op.emitError("unable to handle tile and fuse with interchange");
     return llvm::None;
   }
 
-  OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPoint(rootOp);
-  ScopedContext scope(builder, rootOp.getLoc());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+  ScopedContext scope(rewriter, op.getLoc());
 
   // Find all the producers.
   FusableOpDependencesTy fusableDependences =
-      findAllFusableDependences(ops, dependenceGraph);
+      findAllFusableDependences(op, dependenceGraph, fusionOptions);
   if (fusableDependences.empty())
     return llvm::None;
 
+  // Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle instead of
+  // adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumLoops();
+  SmallVector<Value, 4> tileSizeVector =
+      tilingOptions.tileSizeComputationFunction(rewriter, op);
+  if (tileSizeVector.size() < nLoops) {
+    auto zero = std_constant_index(0);
+    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
+  }
+
   TiledAndFusedLinalgOps ret;
+
   // Find the loops that can be tiled and fused.
-  ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
+  std::set<unsigned> tileFuseLoops =
+      collectTileAndFuseLoops(op, fusableDependences);
 
   // If there are no fusable dependences or there are no tile+fusable loops,
   // just return.
-  if (ret.fusedLoopDims.empty()) {
+  if (tileFuseLoops.empty()) {
     return llvm::None;
   }
 
-  // Tile the fused loops in the last operation in the list.
-  SmallVector<Value, 4> tileSizeVector =
-      tilingOptions.tileSizeComputationFunction(builder, rootOp);
-  Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
-      builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
-  if (!tiledRootOp) {
-    rootOp.emitError("failed to tile the fused loops");
+  // Get the tile sizes for the first and second tiling steps. For the first
+  // step the tile size are set to zero for the loops that arent
+  // fused. Similarly for the second step, the tile sizes are set to zero for
+  // the loops that are fused. For example, if for the following input
+  //
+  // ```
+  //   linalg.add ins(%a, %b) outs(%c)
+  //   linalg.matmul ins(%d, %c) outs(%e)
+  // ```
+  //
+  // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
+  // respectively, and since only `j` can be tiled and fused. The tile sizes
+  // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
+  // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
+  // the tiled matmul generated by the first tiling step.
+  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
+  for (auto tileSize : enumerate(tileSizeVector)) {
+    auto zero = std_constant_index(0);
+    if (tileFuseLoops.count(tileSize.index())) {
+      tileAndFuseSizes.push_back(tileSize.value());
+      tileSizes.push_back(zero);
+    } else {
+      tileSizes.push_back(tileSize.value());
+      tileAndFuseSizes.push_back(zero);
+    }
+  }
+
+  // Tile for the loops that can be fused.
+  LinalgTilingOptions firstTilingOptions = tilingOptions;
+  firstTilingOptions.setTileSizes(tileAndFuseSizes);
+  Optional<TiledLinalgOp> firstTiledOp =
+      tileLinalgOp(rewriter, op, firstTilingOptions);
+  if (!firstTiledOp)
     return llvm::None;
+  ret.op = firstTiledOp->op;
+  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
+
+  rewriter.setInsertionPoint(ret.op);
+  // Fuse the operands.
+  for (auto dependence : fusableDependences) {
+    LinalgOp producerOp = cast<LinalgOp>(dependence.first);
+    unsigned producerIdx =
+        dependence.second.front().dependentOpView.operandIndex;
+    unsigned consumerIdx =
+        dependence.second.front().indexingOpView.operandIndex;
+    LinalgOp fusedOp = fuse(rewriter, producerOp,
+                            producerOp.getOutputIndex(producerIdx).getValue(),
+                            ret.op, consumerIdx);
+    ret.fusedProducers.push_back(fusedOp);
+    ret.originalProducers.push_back(producerOp);
+  }
+
+  if (!llvm::all_of(tileSizes, isZero)) {
+    // Tile the remaining loops of the root operation.
+    LinalgTilingOptions secondTilingOptions = tilingOptions;
+    // The distribution is done only for the tile+fused loops.
+    secondTilingOptions.distribution = llvm::None;
+    secondTilingOptions.setTileSizes(tileSizes);
+    Optional<TiledLinalgOp> secondTiledOp =
+        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
+    if (!secondTiledOp)
+      return llvm::None;
+    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
+                            secondTiledOp->loops.end());
+    rewriter.eraseOp(ret.op);
+    ret.op = secondTiledOp->op;
   }
-  ret.op = tiledRootOp->op;
-  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
 
-  // Fuse the other operations into the fused inter-tile loops produced above.
-  ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
-                                      fusableDependences, ret.fusedLoopDims);
   return ret;
 }
 
 Optional<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
                                    const LinalgDependenceGraph &dependenceGraph,
-                                   const LinalgTilingOptions &tilingOptions) {
+                                   const LinalgTilingOptions &tilingOptions,
+                                   const LinalgFusionOptions &fusionOptions) {
   switch (tilingOptions.loopType) {
   case LinalgTilingLoopType::Loops:
-    return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
-                                                tilingOptions);
+    return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
+                                                tilingOptions, fusionOptions);
   case LinalgTilingLoopType::ParallelLoops:
     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
-        builder, ops, dependenceGraph, tilingOptions);
+        rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
   default:;
   }
   return llvm::None;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index a824f6eb620f..e002336ed1c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -166,6 +166,11 @@ struct LinalgOpInstancePromotionOptions {
   /// Alignment of promoted buffer.
   Optional<unsigned> alignment;
 };
+
+struct PromotionInfo {
+  Value fullLocalView;
+  Value partialLocalView;
+};
 } // namespace
 
 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
@@ -228,10 +233,10 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
 // To account for general boundary effects, padding must be performed on the
 // boundary tiles. For now this is done with an unconditional `fill` op followed
 // by a partial `copy` op.
-Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
-    OpBuilder &b, Location loc, SubViewOp subView,
-    AllocBufferCallbackFn allocationFn, OperationFolder *folder) {
-  ScopedContext scopedContext(b, loc);
+static Optional<PromotionInfo>
+promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
+                          LinalgOpInstancePromotionOptions const &options,
+                          OperationFolder *folder) {
   auto viewType = subView.getType();
   auto rank = viewType.getRank();
   SmallVector<Value, 4> fullSizes, partialSizes;
@@ -249,7 +254,8 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
   SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
   // If a callback is not specified, then use the default implementation for
   // allocating the promoted buffer.
-  Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, folder);
+  Optional<Value> fullLocalView =
+      options.allocationFn(b, subView, fullSizes, folder);
   if (!fullLocalView)
     return {};
   auto zero = folded_std_constant_index(folder, 0);
@@ -273,8 +279,8 @@ promoteSubViews(OpBuilder &b, Location loc,
 
   for (auto v : options.subViews) {
     SubViewOp subView = cast<SubViewOp>(v.second.getDefiningOp());
-    Optional<PromotionInfo> promotionInfo = promoteSubviewAsNewBuffer(
-        b, loc, subView, options.allocationFn, folder);
+    Optional<PromotionInfo> promotionInfo =
+        promoteSubviewAsNewBuffer(b, loc, subView, options, folder);
     if (!promotionInfo)
       return {};
     promotionInfoMap[v.first] = *promotionInfo;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a855c07cb8d4..836cc28e0a47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -165,69 +165,17 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
   if (!linalgOp.hasBufferSemantics())
     return failure();
 
-  DenseSet<Operation *> producers;
-  producers.insert(linalgOp);
-  for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
-    if (!fusionOptions.indicesToFuse.count(
-            dependence.indexingOpView.operandIndex))
-      continue;
-    if (isa<LinalgOp>(dependence.dependentOpView.op))
-      producers.insert(dependence.dependentOpView.op);
-  }
-
-  SmallVector<LinalgOp, 1> fusionOps;
-  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
-       ++it) {
-    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
-    if (producerLinalgOp && producers.count(producerLinalgOp))
-      fusionOps.push_back(producerLinalgOp);
-  }
-  fusionOps.push_back(linalgOp);
-
-  SmallVector<Value, 4> tileSizes =
-      tilingOptions.tileSizeComputationFunction(rewriter, op);
-  LinalgTilingOptions instanceTilingOptions = tilingOptions;
-  instanceTilingOptions.setTileSizes(tileSizes);
   Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
-      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
+      rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
   if (!tiledAndFusedOps)
     return failure();
-
-  // Tile the unfused loops;
-  SmallVector<Value, 4> unfusedLoopTileSizes;
-  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
-  for (auto tileSize : enumerate(tileSizes)) {
-    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
-      unfusedLoopTileSizes.push_back(zero);
-    else
-      unfusedLoopTileSizes.push_back(tileSize.value());
-  }
-  // Tile the loop only if there is a non-zero tile size.
-  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
-    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
-  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
-        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
-          return cst.getValue() != 0;
-        return true;
-      })) {
-    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
-    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
-    Optional<TiledLinalgOp> unfusedTiledOp =
-        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
-    if (!unfusedTiledOp)
-      return failure();
-    rewriter.eraseOp(tiledAndFusedOps->op);
-    tiledAndFusedOps->op = unfusedTiledOp->op;
-  }
-
   marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
   for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
     fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
   }
-  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
+  for (auto origProducerOp : tiledAndFusedOps->originalProducers)
     originalOpMarker.replaceLinalgMarker(rewriter,
                                          origProducerOp.getOperation());
-  }
   rewriter.updateRootInPlace(
       op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
   return success();

diff  --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index fa471811ef4e..2ddc66651db2 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -47,9 +47,7 @@ module {
 //      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[TILE_N_2]]]
-//      CHECK:     %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME:       [%[[TILE_M]], %[[TILE_N]]]
-//      CHECK:     linalg.fill(%[[SV3_2]], %[[CST]])
+//      CHECK:     linalg.fill(%[[SV3]], %[[CST]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_basic_fusion_producer"
 //      CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
 //      CHECK:       %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
@@ -111,12 +109,9 @@ module {
 //      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
 //      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
 // CHECK-SAME:       [%[[M]], %[[TILE_N_2]]]
-//      CHECK:     %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
-// CHECK-SAME:       [%[[K_2]], %[[TILE_N]]]
-//      CHECK:     %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]]
-// CHECK-SAME:       [%[[K_2]], %[[TILE_N]]]
-//      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
+// CHECK-SAME:       [%[[K]], %[[TILE_N]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_rhs_fusion_producer"
 //  CHECK-NOT:     linalg.fill
 //  CHECK-DAG:     %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
@@ -191,16 +186,11 @@ module {
 //      CHECK:     %[[N:.+]] = dim %[[ARG3]], %[[C1]]
 //      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
-//      CHECK:     %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[N]]]
-//      CHECK:     %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K_2]]]
-//      CHECK:     %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K_2]]]
-//      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
+// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
-//      CHECK:     linalg.fill(%[[SV2_2]], %[[CST]])
+//      CHECK:     linalg.fill(%[[SV2]], %[[CST]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
 //  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
 //      CHECK:     scf.parallel (%[[IV1:.+]]) =
@@ -271,18 +261,15 @@ module {
 //      CHECK:     %[[N:.+]] = dim %[[ARG4]], %[[C1]]
 //      CHECK:     %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
-//      CHECK:     %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]]
 //      CHECK:     %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M]], %[[K1]]]
-//      CHECK:     %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
-//      CHECK:     %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K2_2]]]
+//      CHECK:     %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
 //      CHECK:     linalg.matmul
 // CHECK-SAME:         __internal_linalg_transform__ = "after_lhs_fusion_producer"
 // CHECK-SAME:         ins(%[[SV3]], %[[SV4]]
 // CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME:         outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:         outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
 //  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
 //      CHECK:     scf.parallel (%[[IV1:.+]]) =
 // CHECK-SAME:       (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {

diff  --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
deleted file mode 100644
index a02c878ef341..000000000000
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ /dev/null
@@ -1,133 +0,0 @@
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
-
-module {
-  func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
-                        %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
-    %cst = constant 0.000000e+00 : f32
-    %c0 = constant 0 : index
-    %c1 = constant 1 : index
-    %d0 = dim %arg0, %c0 : memref<?x?xf32>
-    %d1 = dim %arg1, %c1 : memref<?x?xf32>
-    %0 = alloc(%d0, %d1) : memref<?x?xf32>
-    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%0 : memref<?x?xf32>)
-    linalg.generic
-      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>],
-       iterator_types = ["parallel", "parallel"]}
-      ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
-      outs(%arg3 : memref<?x?xf32>) {
-      ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
-        %5 = addf %arg4, %arg5 : f32
-        linalg.yield %5 : f32
-      }
-    return
-  }
-}
-
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-//       CHECK: func @three_op_fusion
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
-//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//       CHECK:   %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-//       CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-//   CHECK-DAG:     %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-//   CHECK-DAG:     %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
-//   CHECK-DAG:     %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
-//   CHECK-DAG:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-//   CHECK-DAG:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
-//       CHECK:     linalg.fill(%[[SV_TEMP]], %{{.+}})
-//       CHECK:     linalg.matmul
-//  CHECK-SAME:       ins(%[[SV_ARG0]], %[[SV_ARG1]]
-//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
-//  CHECK-SAME:       outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
-//       CHECK:     linalg.generic
-//  CHECK-SAME:       ins(%[[SV_TEMP]], %[[SV_ARG2]]
-//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
-//  CHECK-SAME:       outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
-//       CHECK:     scf.yield
-//       CHECK:   }
-
-// -----
-
-module {
-  func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
-                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
-			   %arg4: memref<?x?xf32>) {
-    %cst = constant 0.000000e+00 : f32
-    %c0 = constant 0 : index
-    %c1 = constant 1 : index
-    %m = dim %arg0, %c0 : memref<?x?xf32>
-    %n1 = dim %arg1, %c1 : memref<?x?xf32>
-    %n2 = dim %arg2, %c1 : memref<?x?xf32>
-    %n3 = dim %arg3, %c1 : memref<?x?xf32>
-    %0 = alloc(%m, %n1) : memref<?x?xf32>
-    %1 = alloc(%m, %n2) : memref<?x?xf32>
-    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%0 : memref<?x?xf32>)
-    linalg.fill(%1, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%1 : memref<?x?xf32>)
-    linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%arg4 : memref<?x?xf32>)
-    return
-  }
-}
-
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-//       CHECK: func @sequence_of_matmul
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[C16:.+]] = constant 16 : index
-//   CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
-//   CHECK-DAG:   %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
-//       CHECK:   %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
-//       CHECK:   %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
-//       CHECK:   scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
-//  CHECK-SAME:     step (%[[C16]]) {
-//       CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-//       CHECK:     %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M]], %[[N2]]]
-//       CHECK:     %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
-//       CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
-//       CHECK:     %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
-//       CHECK:     %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M_2]], %[[N3]]]
-//       CHECK:     %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M]], %[[N3]]]
-//       CHECK:     %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M]], %[[N1]]]
-//       CHECK:     %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
-//       CHECK:     %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
-//       CHECK:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M:.+]], %[[N0]]]
-//       CHECK:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
-//       CHECK:     linalg.fill(%[[SV_ALLOC1]], %{{.+}})
-//       CHECK:     linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
-//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-//  CHECK-SAME:        outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
-//       CHECK:     linalg.fill(%[[SV_ALLOC2]], %{{.+}})
-//       CHECK:     linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
-//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-//  CHECK-SAME:        outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
-//       CHECK:     linalg.fill(%[[SV_ARG4_2]], %{{.+}})
-//       CHECK:     linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
-//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-//  CHECK-SAME:        outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
-//       CHECK:     scf.yield
-//       CHECK:   }
-

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 5289b2d1055f..eb9e3a533138 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -197,44 +197,6 @@ struct TestLinalgGreedyFusion
     }
   }
 };
-
-/// Pass to test tile and fuse of sequence of operations. Intended only for
-/// testing.
-struct TestLinalgTileAndFuseSequencePass
-    : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
-  TestLinalgTileAndFuseSequencePass() = default;
-  TestLinalgTileAndFuseSequencePass(
-      const TestLinalgTileAndFuseSequencePass &pass){};
-
-  ListOption<int64_t> tileSizes{
-      *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
-  }
-
-  void runOnFunction() override {
-    FuncOp funcOp = getOperation();
-    auto &blocks = funcOp.getBody().getBlocks();
-    if (!llvm::hasSingleElement(blocks)) {
-      return;
-    }
-    SmallVector<LinalgOp, 2> linalgOps =
-        llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
-    Aliases aliases;
-    LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
-    OpBuilder builder(funcOp.getContext());
-    Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
-        builder, linalgOps, dependenceGraph,
-        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
-            LinalgTilingLoopType::ParallelLoops));
-    if (!tileAndFuseOps)
-      return signalPassFailure();
-    for (auto op : linalgOps)
-      op.erase();
-  }
-};
 } // namespace
 
 namespace mlir {
@@ -249,12 +211,5 @@ void registerTestLinalgGreedyFusion() {
       "test-linalg-greedy-fusion",
       "Test Linalg fusion by applying a greedy test transformation.");
 }
-void registerTestLinalgTileAndFuseSequencePass() {
-  PassRegistration<TestLinalgTileAndFuseSequencePass>
-      testTileAndFuseSequencePass(
-          "test-linalg-tile-and-fuse",
-          "Test Linalg tiling and fusion of a sequence of Linalg operations.");
-}
-
 } // namespace test
 } // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a0e36cf82534..4771b11b20e4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,7 +74,6 @@ void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
-void registerTestLinalgTileAndFuseSequencePass();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -142,7 +141,6 @@ void registerTestPasses() {
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
-  test::registerTestLinalgTileAndFuseSequencePass();
   test::registerTestLinalgTransforms();
   test::registerTestLivenessPass();
   test::registerTestLoopFusion();


        


More information about the Mlir-commits mailing list