[Mlir-commits] [mlir] f8284d2 - [mlir][Linalg] Fuse sequence of Linalg operation (on buffers)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 19 19:03:22 PST 2020


Author: MaheshRavishankar
Date: 2020-11-19T19:03:06-08:00
New Revision: f8284d21a8e294d58a0acd4b8b2e906d7a9f110c

URL: https://github.com/llvm/llvm-project/commit/f8284d21a8e294d58a0acd4b8b2e906d7a9f110c
DIFF: https://github.com/llvm/llvm-project/commit/f8284d21a8e294d58a0acd4b8b2e906d7a9f110c.diff

LOG: [mlir][Linalg] Fuse sequence of Linalg operation (on buffers)

Enhance the tile+fuse logic to allow fusing a sequence of operations.

Differential Revision: https://reviews.llvm.org/D90991

Added: 
    mlir/test/Dialect/Linalg/fusion-sequence.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/fusion-pattern.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d531a1e343a..fac91ca0b256 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,14 +37,6 @@ struct TiledLinalgOp {
   SmallVector<Value, 4> tensorResults;
 };
 
-struct TiledAndFusedLinalgOps {
-  LinalgOp op;
-  SmallVector<LinalgOp, 1> fusedProducers;
-  SmallVector<LinalgOp, 1> originalProducers;
-  SmallVector<Operation *, 4> fusedLoops;
-  SmallVector<Operation *, 4> unfusedLoops;
-};
-
 /// Populates patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -73,14 +65,11 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
 Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                      const LinalgTilingOptions &options);
 
-/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
-/// three steps
-/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
-///   + fuse loops).
-/// - Tile just these loops of the consumer (root operation) and fuse with
-///   the producer.
-/// - Tile again the tiled consumer operation produced above to do rest of
-///   the tiling specified by the `tilingOptions`.
+/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
+/// proceeds as follows:
+/// - Find outer parallel loops in these ops that can be fused.
+/// - Tile fusable outer parallel loops of the last operation in the sequence.
+/// - Fuse the remaining operations with the tiled operation
 ///
 /// For example, consider the sequence of matmul below
 ///
@@ -107,36 +96,39 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
 ///     : memref<256x32xf32> to memref<16x32xf32, #map0>
 ///   %3 = subview %arg1[0, 0] [32, 32] [1, 1]
 ///     : memref<32x32xf32> to memref<32x32xf32, #map1>
+///   %4 = subview %arg3[0, 0] [32, 32] [1, 1]
+///     : memref<32x32xf32> to memref<32x32xf32, #map1>
 ///   linalg.matmul
 ///     ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
 ///     outs(%0 : memref<16x32xf32, #map0>)
-///   scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
-///   scf.for %arg7 = %c0 to %c32 step %c4 {
-///     %4 = subview %0[0, %arg7] [16, 4] [1, 1]
-///       : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
-///     %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
-///       : memref<32x32xf32> to memref<4x8xf32, #map0>
-///     %6 = subview %1[0, %arg6] [16, 8] [1, 1]
-///       : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
-///     linalg.matmul
-///       ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
-///       outs(%6 : memref<16x8xf32, #map0>)
-///     }
-///     scf.yield
-///   }
-///   scf.yield
+///   linalg.matmul
+///     ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
+///     outs(%1 : memref<16x8xf32, #map0>)
 /// }
 ///
-/// The following tiling options are handled differently in tile+fuse (compared
-/// to tile only)
+/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
+/// size of the former should be same as size of the latter. Based on how
+/// tile+fuse is implemented, the fused loops are generated based on the last
+/// operation in the sequence. For example, the tile sizes for the fused loops
+/// is obtained from `tilingOptions.back()`. The following tiling options are
+/// handled differently in tile+fuse (compared to tile only)
 /// - Interchange of the tiling loops is not supported right now.
-/// - Distribution is only done for the tile+fuse loops. The tiled loops
-///   generated by the second tiling is not distributed.
+/// - Only the fused loops are distributed.
+struct TiledAndFusedLinalgOps {
+  /// Operation obtained by tiling the last operation in sequence of `ops`
+  /// passed to `tileAndFuseLinalgOps`.
+  LinalgOp op;
+  /// The dimension of the loops that are fused.
+  std::set<unsigned> fusedLoopDims;
+  /// The generated fused operations (created within the fused loops).
+  SmallVector<LinalgOp, 1> fusedProducers;
+  /// The fused loop generated.
+  SmallVector<Operation *, 4> fusedLoops;
+};
 Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                      const LinalgDependenceGraph &dependenceGraph,
-                     const LinalgTilingOptions &tilingOptions,
-                     const LinalgFusionOptions &fusionOptions);
+                     const LinalgTilingOptions &tilingOptions);
 
 /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
 /// This is an in-place transformation controlled by `interchangeVector`.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 969bea4a4549..02417bd78d9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -178,6 +178,9 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
     Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
     for (auto en2 : llvm::enumerate(map.getResults())) {
+      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+      if (!dimExpr)
+        continue;
       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
@@ -190,49 +193,18 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-///   1. Buffer case: `producerIdx` is the index of the buffer in
-///      `producer.getOutputBuffers()`.
-///   2. Tensor case: `producerIdx` is the index of the tensor in
-///      `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
-                     LinalgOp consumer, unsigned consumerIdx) {
-  Operation *shapeProducingOp =
-      consumer.getShapedOperand(consumerIdx).getDefiningOp();
-  assert((isa<SubViewOp>(shapeProducingOp) ||
-          isa<SubTensorOp>(shapeProducingOp)) &&
-         "SubviewOp or SubTensorOp expected");
-
-  // loopToOperandRangesMaps are permutations-only by construction:
-  //   we can always identify a data dimension with a (at least one) loop
-  //   dimension.
-  // TODO: extend this with range inference.
-  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
-  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
-                          << ", producer map: " << producerMap << "\n");
+/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
+/// provides the loop range information for the fused loops. The rest are
+/// obtained from the producer itself, since they are not tiled + fused.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
+                     const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
 
   unsigned nPar = producer.getNumParallelLoops();
   unsigned nRed = producer.getNumReductionLoops();
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
-
-  // Iterate over dimensions identified by the producer map for `producerIdx`.
-  // This defines a subset of the loop ranges that we need to complete later.
-  auto loc = consumer.getLoc();
-  for (auto en : llvm::enumerate(producerMap.getResults())) {
-    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] =
-        isa<SubViewOp>(shapeProducingOp)
-            ? cast<SubViewOp>(shapeProducingOp)
-                  .getOrCreateRanges(b, loc)[en.index()]
-            : cast<SubTensorOp>(shapeProducingOp)
-                  .getOrCreateRanges(b, loc)[en.index()];
-  }
+  for (auto fusedLoops : fusedLoopsAndRanges)
+    loopRanges[fusedLoops.first] = fusedLoops.second;
 
   // Iterate over all dimensions. For the dimensions not identified by the
   // producer map for `producerIdx`, we need to explicitly compute the shape
@@ -250,7 +222,45 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
     }
   }
 
-  return cloneWithLoopRanges(b, loc, producer, loopRanges);
+  return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
+}
+
+/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
+/// expected to be defined by a subview op or a subtensor op.
+static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
+                                      Value shapedOperand, unsigned dim) {
+  Operation *shapeProducingOp = shapedOperand.getDefiningOp();
+  if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
+    return subViewOp.getOrCreateRanges(b, loc)[dim];
+  if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
+    return subTensorOp.getOrCreateRanges(b, loc)[dim];
+  llvm_unreachable("SubviewOp or SubTensorOp expected");
+}
+
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
+/// 2 cases:
+///   1. Buffer case: `producerIdx` is the index of the buffer in
+///      `producer.getOutputBuffers()`.
+///   2. Tensor case: `producerIdx` is the index of the tensor in
+///      `producer.getResults()`.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+                     LinalgOp consumer, unsigned consumerIdx) {
+  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
+  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+                          << ", producer map: " << producerMap << "\n");
+  DenseMap<unsigned, Range> fusedLoopsAndRanges;
+  Location loc = consumer.getLoc();
+  Value shapedOperand = consumer.getShapedOperand(consumerIdx);
+  for (auto en : llvm::enumerate(producerMap.getResults())) {
+    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+    fusedLoopsAndRanges[posInProducerLoop] =
+        getRangeFromOperandShape(b, loc, shapedOperand, en.index());
+  }
+  return fuse(b, producer, fusedLoopsAndRanges);
 }
 
 // Encode structural fusion safety preconditions.
@@ -525,6 +535,69 @@ using FusableOpDependencesTy = llvm::MapVector<
     Operation *,
     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
 
+/// Returns the mapping from iterations in the consumer that write to the same
+/// location as the iterations in the producer. To do so use
+/// - indexing map of the fused view in the consumer : consumerIndexMap
+/// - indexing map of the fused view in the producer : producerIndexMap
+///     consumerLoopToProducerLoop =
+///       inverse(producerIndexMap).compose(consumerIndexMap)
+static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
+    LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
+  auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+  AffineMap producerIndexingMap =
+      producer.getIndexingMap(dependence.dependentOpView.operandIndex);
+  auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
+  AffineMap consumerIndexingMap =
+      consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
+
+  AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
+      producer.iterator_types().getValue(), producerIndexingMap);
+  if (!prunedProducerIndexingMap.isPermutation())
+    return None;
+
+  if (consumerIndexingMap.getNumResults() !=
+      prunedProducerIndexingMap.getNumResults())
+    return None;
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "\t producerMap : ";
+    producerIndexingMap.print(llvm::dbgs());
+    llvm::dbgs() << "  pruned : ";
+    prunedProducerIndexingMap.print(llvm::dbgs());
+    llvm::dbgs() << "\n";
+    llvm::dbgs() << "\t consumerMap : ";
+    consumerIndexingMap.print(llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
+  if (!invProducerIndexMap)
+    return None;
+
+  return invProducerIndexMap.compose(consumerIndexingMap);
+}
+
+/// Given a projected permutation `map`, returns true if the map changes the
+/// order in which the fused loop dimension appear.
+static bool doesTransposeAccess(AffineMap map,
+                                const std::set<unsigned> &fusableLoops) {
+  Optional<unsigned> lastFusableLoop;
+  for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
+         return expr.cast<AffineDimExpr>().getPosition();
+       })) {
+    if (!fusableLoops.count(pos))
+      continue;
+    if (!lastFusableLoop) {
+      lastFusableLoop = pos;
+      continue;
+    }
+    if (pos <= lastFusableLoop.getValue())
+      return true;
+    lastFusableLoop = pos;
+  }
+  return false;
+}
+
 /// Returns the positions of the loop in `op` that can be tiled based on the
 /// operations that are to be fused with it. For example, in a
 ///
@@ -538,13 +611,7 @@ using FusableOpDependencesTy = llvm::MapVector<
 /// 2. Of the parallel loops only some can be fused. Only those loops can be
 ///    fused such where the fusable loops iteration space only touches one tile
 ///    of the fused operation. This is because the producer (which is writing
-///    the fused subview) has update semantics. To compute this,
-///    a. Find the mapping from iterations in the consumer that write to the
-///       same location as the iterations in the producer. To do so use
-///       - indexing map of the fused view in the consumer : consumerIndexMap
-///       - indexing map of the fused view in the producer : producerIndexMap
-///       consumerLoopToProducerLoop =
-///         inverse(producerIndexMap).compose(consumerIndexMap)
+///    the fused subview) has update semantics.
 ///
 /// Since an inverse computation is needed, we need to consider the projection
 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
@@ -582,8 +649,9 @@ using FusableOpDependencesTy = llvm::MapVector<
 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
 ///   Fused dimensions : j
 static std::set<unsigned>
-collectTileAndFuseLoops(LinalgOp op,
-                        const FusableOpDependencesTy &fusableDependences) {
+collectFusableLoops(ArrayRef<LinalgOp> ops,
+                    const FusableOpDependencesTy &fusableDependences) {
+  assert(!ops.empty());
   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
     return linalgOp.iterator_types()
         .getValue()
@@ -594,88 +662,57 @@ collectTileAndFuseLoops(LinalgOp op,
         .size();
   };
 
-  LLVM_DEBUG({
-    llvm::dbgs() << "Op : ";
-    op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-    llvm::dbgs() << "\n";
-  });
-
-  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
-  for (auto dependence : fusableDependences) {
-    linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
+  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
+  for (auto op : ops.drop_back()) {
     numOuterParallelLoops =
-        std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
+        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
   }
 
   std::set<unsigned> fusableLoops;
   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
   fusableLoops.insert(range.begin(), range.end());
-  for (auto dependence : fusableDependences) {
-    LLVM_DEBUG({
-      llvm::dbgs() << "\t fusable :";
-      for (unsigned i : fusableLoops)
-        llvm::dbgs() << " " << i;
-      llvm::dbgs() << "\n";
-    });
-    linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
-
-    assert(!dependence.second.empty() &&
-           "unexpected producer but not dependences");
-    AffineMap producerIndexingMap = producer.getIndexingMap(
-        dependence.second.front().dependentOpView.operandIndex);
-    AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
-        producer.iterator_types().getValue(), producerIndexingMap);
-    if (!prunedProducerIndexingMap.isPermutation())
-      return {};
-
-    AffineMap consumerIndexingMap = op.getIndexingMap(
-        dependence.second.front().indexingOpView.operandIndex);
-    if (consumerIndexingMap.getNumResults() !=
-        prunedProducerIndexingMap.getNumResults())
-      return {};
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "\t producerMap : ";
-      producerIndexingMap.print(llvm::dbgs());
-      llvm::dbgs() << "  pruned : ";
-      prunedProducerIndexingMap.print(llvm::dbgs());
-      llvm::dbgs() << "\n";
-      llvm::dbgs() << "\t consumerMap : ";
-      consumerIndexingMap.print(llvm::dbgs());
-      llvm::dbgs() << "\n";
-    });
-
-    AffineMap invProducerIndexMap =
-        inversePermutation(prunedProducerIndexingMap);
-    if (!invProducerIndexMap)
-      return {};
-
-    AffineMap consumerLoopToProducerLoop =
-        invProducerIndexMap.compose(consumerIndexingMap);
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
-      consumerLoopToProducerLoop.print(llvm::dbgs());
-    });
-
-    std::set<unsigned> candidates;
-    for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
-      AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
-      if (!dimExpr)
-        continue;
-      unsigned position = dimExpr.getPosition();
-      if (fusableLoops.count(position))
-        candidates.insert(position);
+
+  for (auto op : reverse(ops)) {
+    for (auto dependence : fusableDependences.lookup(op)) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "\t fusable :";
+        for (unsigned i : fusableLoops)
+          llvm::dbgs() << " " << i;
+        llvm::dbgs() << "\n";
+      });
+
+      Optional<AffineMap> consumerLoopToProducerLoop =
+          getConsumerLoopToProducerLoopMap(dependence);
+      if (!consumerLoopToProducerLoop) {
+        op.emitRemark("failed to get map from consumer loop to producer loop");
+        return {};
+      }
+      // todo: This condition is only an implementation limitation. When fusing
+      // the operation, if the accesses in the producer/consumer are transposes
+      // of each other, the loop bounds for the tiled producer can be
+      // manipulated accordingly. This requires some additional bookkeeping in
+      // the implementation of tile+fuse that is defered to later.
+      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
+        op.emitRemark("unhandled fusion when fusion requires permutation");
+        return {};
+      }
+
+      std::set<unsigned> candidates;
+      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
+        unsigned position = expr.cast<AffineDimExpr>().getPosition();
+        if (fusableLoops.count(position))
+          candidates.insert(position);
+      }
+      LLVM_DEBUG({
+        llvm::dbgs() << "\t candidates :";
+        for (unsigned i : candidates)
+          llvm::dbgs() << " " << i;
+        llvm::dbgs() << "\n";
+      });
+      if (candidates.empty())
+        return {};
+      std::swap(candidates, fusableLoops);
     }
-    LLVM_DEBUG({
-      llvm::dbgs() << "\t candidates :";
-      for (unsigned i : candidates)
-        llvm::dbgs() << " " << i;
-      llvm::dbgs() << "\n";
-    });
-    if (candidates.empty())
-      return {};
-    std::swap(candidates, fusableLoops);
   }
 
   return fusableLoops;
@@ -683,60 +720,69 @@ collectTileAndFuseLoops(LinalgOp op,
 
 /// Find all dependences that are to be fusable.
 static FusableOpDependencesTy
-findAllFusableDependences(LinalgOp op,
-                          const LinalgDependenceGraph &dependenceGraph,
-                          const LinalgFusionOptions &fusionOptions) {
+findAllFusableDependences(ArrayRef<LinalgOp> ops,
+                          const LinalgDependenceGraph &dependenceGraph) {
   FusableOpDependencesTy fusableDependences;
   // TODO: Currently fusion would not be legal if the fusable dependence is to
   // the same producer but different indexing map in the consumer. Fix this, but
   // in the meanwhile disallow such a fusion.
   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
-  for (auto operandIndex : fusionOptions.indicesToFuse) {
-    auto fusableDependence =
-        findFusableProducer(op, operandIndex, dependenceGraph);
-    if (!fusableDependence)
-      return FusableOpDependencesTy{};
-    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
-    // Do not fuse dependences that are to operations not in the same basic
-    // block. This avoid moving fused operations across loops that might
-    // themselves carry dependency making the fusion illegal.
-    if (producerOp.getOperation()->getBlock() !=
-        op.getOperation()->getBlock()) {
-      op.emitRemark("unhandled fusion of ops in different basic blocks");
-      return FusableOpDependencesTy{};
-    }
-    // Make sure that the indexing map of the view used for fusion in the
-    // producer is a projected permutation.
-    unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
-    AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
-    if (!producerMap.isProjectedPermutation()) {
-      op.emitRemark("unhandled non permutation indexing map for fused view in "
-                    "producer for operand at index ")
-          << operandIndex;
-      return FusableOpDependencesTy{};
-    }
+  for (LinalgOp op : reverse(ops)) {
+    for (auto operandIndex :
+         llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
+      Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+          fusableDependence =
+              findFusableProducer(op, operandIndex, dependenceGraph);
+      if (!fusableDependence)
+        continue;
+      LinalgOp producerOp =
+          cast<LinalgOp>(fusableDependence->dependentOpView.op);
+      // Do not fuse dependences that are to operations not in the same basic
+      // block. This avoid moving fused operations across loops that might
+      // themselves carry dependency making the fusion illegal.
+      if (producerOp.getOperation()->getBlock() !=
+          op.getOperation()->getBlock()) {
+        op.emitRemark("unhandled fusion of ops in different basic blocks");
+        return FusableOpDependencesTy{};
+      }
+      // Make sure that the indexing map of the view used for fusion in the
+      // producer is a projected permutation.
+      unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
+      AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+      if (!producerMap.isProjectedPermutation()) {
+        op.emitRemark(
+            "unhandled non permutation indexing map for fused view in "
+            "producer for operand at index ")
+            << operandIndex;
+        return FusableOpDependencesTy{};
+      }
 
-    unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
-    AffineMap consumerMap = op.getIndexingMap(consumerIdx);
-    if (!consumerMap.isProjectedPermutation()) {
-      op.emitRemark(
-          "unhandled case where indexing map for fused view in the consumer is "
-          "not a projected permutation while fusing at index ")
-          << operandIndex;
-      return FusableOpDependencesTy{};
-    }
+      unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
+      AffineMap consumerMap = op.getIndexingMap(consumerIdx);
+      if (!consumerMap.isProjectedPermutation()) {
+        op.emitRemark(
+            "unhandled case where indexing map for fused view in the consumer "
+            "is "
+            "not a projected permuration while fusing at index ")
+            << operandIndex;
+        return FusableOpDependencesTy{};
+      }
 
-    // Check if the producer is already a fusion candidate. Cannot fuse this
-    // dependence if it has a different indexing map when used in the consumer.
-    if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
-        fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
-      op.emitRemark("unhandled fusion to the same producer but with different "
-                    "indexing maps");
-      return FusableOpDependencesTy{};
-    }
-    fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+      // Check if the producer is already a fusion candidate. Cannot fuse this
+      // dependence if it has a different indexing map when used in the
+      // consumer.
+      if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
+          fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
+        op.emitRemark(
+            "unhandled fusion to the same producer but with different "
+            "indexing maps");
+        return FusableOpDependencesTy{};
+      }
+      fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
 
-    fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
+      fusableDependences[producerOp.getOperation()].push_back(
+          *fusableDependence);
+    }
   }
   return fusableDependences;
 }
@@ -747,136 +793,120 @@ static bool isZero(Value v) {
   return false;
 }
 
+/// Tile the fused loops in the root operation, by setting the tile sizes for
+/// all other loops to zero (those will be tiled later).
+static Optional<TiledLinalgOp> tileRootOperation(
+    OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
+    const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
+  SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
+  auto zero = std_constant_index(0);
+  for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
+    if (!fusedLoops.count(i))
+      tileSizes[i] = zero;
+  LinalgTilingOptions tileFusedLoopsOptions = options;
+  tileFusedLoopsOptions.setTileSizes(tileSizes);
+  return tileLinalgOp(builder, op, tileFusedLoopsOptions);
+}
+
+/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
+/// to be a tiled operation such that it is valid to fuse all operations in
+/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
+/// `tiledOp`.
+static SmallVector<LinalgOp, 1>
+fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
+               ArrayRef<LinalgOp> fusionCandidates,
+               const FusableOpDependencesTy &fusableDependences,
+               const std::set<unsigned> &fusedLoops) {
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(tiledOp);
+  DenseMap<unsigned, Range> fusedLoopsAndRanges;
+  for (unsigned loop : fusedLoops) {
+    ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop);
+    fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
+        builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
+  }
+  SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
+  for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
+    LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
+    fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
+    builder.setInsertionPoint(fusedOp);
+  }
+  return fusedOps;
+}
+
 template <typename LoopType>
 static Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
+tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                          const LinalgDependenceGraph &dependenceGraph,
-                         const LinalgTilingOptions &tilingOptions,
-                         const LinalgFusionOptions &fusionOptions) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
-  // Some of the tiling options might not be supportable with tile and fuse.
-  // TODO: Support interchange with tile + fuse.
+                         const LinalgTilingOptions &tilingOptions) {
+  if (ops.empty())
+    return llvm::None;
+  LinalgOp rootOp = ops.back();
+  for (auto op : enumerate(ops)) {
+    // TODO: Nothing in the fusion of sequence of ops is specific to
+    // buffers. This check can be removed after it is tested on tensors.
+    LinalgOp linalgOp = op.value();
+    if (!linalgOp.hasBufferSemantics()) {
+      linalgOp.emitError("tile and fuse only tested for buffer operation");
+      return llvm::None;
+    }
+  }
+  // TODO: Support interchange with tile + fuse. This might actually help do
+  // better fusion.
   if (!tilingOptions.interchangeVector.empty()) {
-    op.emitError("unable to handle tile and fuse with interchange");
+    rootOp.emitError("unable to handle tile and fuse with interchange");
     return llvm::None;
   }
 
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(op);
-  ScopedContext scope(rewriter, op.getLoc());
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(rootOp);
+  ScopedContext scope(builder, rootOp.getLoc());
 
   // Find all the producers.
   FusableOpDependencesTy fusableDependences =
-      findAllFusableDependences(op, dependenceGraph, fusionOptions);
+      findAllFusableDependences(ops, dependenceGraph);
   if (fusableDependences.empty())
     return llvm::None;
 
-  // Enforce the convention that "tiling by zero" skips tiling a particular
-  // dimension. This convention is significantly simpler to handle instead of
-  // adjusting affine maps to account for missing dimensions.
-  auto nLoops = op.getNumLoops();
-  SmallVector<Value, 4> tileSizeVector =
-      tilingOptions.tileSizeComputationFunction(rewriter, op);
-  if (tileSizeVector.size() < nLoops) {
-    auto zero = std_constant_index(0);
-    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
-  }
-
   TiledAndFusedLinalgOps ret;
-
   // Find the loops that can be tiled and fused.
-  std::set<unsigned> tileFuseLoops =
-      collectTileAndFuseLoops(op, fusableDependences);
+  ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
 
   // If there are no fusable dependences or there are no tile+fusable loops,
   // just return.
-  if (tileFuseLoops.empty()) {
+  if (ret.fusedLoopDims.empty()) {
     return llvm::None;
   }
 
-  // Get the tile sizes for the first and second tiling steps. For the first
-  // step the tile size are set to zero for the loops that arent
-  // fused. Similarly for the second step, the tile sizes are set to zero for
-  // the loops that are fused. For example, if for the following input
-  //
-  // ```
-  //   linalg.add ins(%a, %b) outs(%c)
-  //   linalg.matmul ins(%d, %c) outs(%e)
-  // ```
-  //
-  // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
-  // respectively, and since only `j` can be tiled and fused. The tile sizes
-  // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
-  // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
-  // the tiled matmul generated by the first tiling step.
-  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
-  for (auto tileSize : enumerate(tileSizeVector)) {
-    auto zero = std_constant_index(0);
-    if (tileFuseLoops.count(tileSize.index())) {
-      tileAndFuseSizes.push_back(tileSize.value());
-      tileSizes.push_back(zero);
-    } else {
-      tileSizes.push_back(tileSize.value());
-      tileAndFuseSizes.push_back(zero);
-    }
-  }
-
-  // Tile for the loops that can be fused.
-  LinalgTilingOptions firstTilingOptions = tilingOptions;
-  firstTilingOptions.setTileSizes(tileAndFuseSizes);
-  Optional<TiledLinalgOp> firstTiledOp =
-      tileLinalgOp(rewriter, op, firstTilingOptions);
-  if (!firstTiledOp)
+  // Tile the fused loops in the last operation in the list.
+  SmallVector<Value, 4> tileSizeVector =
+      tilingOptions.tileSizeComputationFunction(builder, rootOp);
+  Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
+      builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
+  if (!tiledRootOp) {
+    rootOp.emitError("failed to tile the fused loops");
     return llvm::None;
-  ret.op = firstTiledOp->op;
-  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
-
-  rewriter.setInsertionPoint(ret.op);
-  // Fuse the operands.
-  for (auto dependence : fusableDependences) {
-    LinalgOp producerOp = cast<LinalgOp>(dependence.first);
-    unsigned producerIdx =
-        dependence.second.front().dependentOpView.operandIndex;
-    unsigned consumerIdx =
-        dependence.second.front().indexingOpView.operandIndex;
-    LinalgOp fusedOp = fuse(rewriter, producerOp,
-                            producerOp.getOutputIndex(producerIdx).getValue(),
-                            ret.op, consumerIdx);
-    ret.fusedProducers.push_back(fusedOp);
-    ret.originalProducers.push_back(producerOp);
-  }
-
-  if (!llvm::all_of(tileSizes, isZero)) {
-    // Tile the remaining loops of the root operation.
-    LinalgTilingOptions secondTilingOptions = tilingOptions;
-    // The distribution is done only for the tile+fused loops.
-    secondTilingOptions.distribution = llvm::None;
-    secondTilingOptions.setTileSizes(tileSizes);
-    Optional<TiledLinalgOp> secondTiledOp =
-        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
-    if (!secondTiledOp)
-      return llvm::None;
-    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
-                            secondTiledOp->loops.end());
-    rewriter.eraseOp(ret.op);
-    ret.op = secondTiledOp->op;
   }
+  ret.op = tiledRootOp->op;
+  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
 
+  // Fuse the other operations into the fused inter-tile loops produced above.
+  ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
+                                      fusableDependences, ret.fusedLoopDims);
   return ret;
 }
 
 Optional<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                                    const LinalgDependenceGraph &dependenceGraph,
-                                   const LinalgTilingOptions &tilingOptions,
-                                   const LinalgFusionOptions &fusionOptions) {
+                                   const LinalgTilingOptions &tilingOptions) {
   switch (tilingOptions.loopType) {
   case LinalgTilingLoopType::Loops:
-    return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
-                                                tilingOptions, fusionOptions);
+    return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
+                                                tilingOptions);
   case LinalgTilingLoopType::ParallelLoops:
     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
-        rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+        builder, ops, dependenceGraph, tilingOptions);
   default:;
   }
   return llvm::None;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 836cc28e0a47..a855c07cb8d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -165,17 +165,69 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
   if (!linalgOp.hasBufferSemantics())
     return failure();
 
+  DenseSet<Operation *> producers;
+  producers.insert(linalgOp);
+  for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
+    if (!fusionOptions.indicesToFuse.count(
+            dependence.indexingOpView.operandIndex))
+      continue;
+    if (isa<LinalgOp>(dependence.dependentOpView.op))
+      producers.insert(dependence.dependentOpView.op);
+  }
+
+  SmallVector<LinalgOp, 1> fusionOps;
+  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
+       ++it) {
+    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
+    if (producerLinalgOp && producers.count(producerLinalgOp))
+      fusionOps.push_back(producerLinalgOp);
+  }
+  fusionOps.push_back(linalgOp);
+
+  SmallVector<Value, 4> tileSizes =
+      tilingOptions.tileSizeComputationFunction(rewriter, op);
+  LinalgTilingOptions instanceTilingOptions = tilingOptions;
+  instanceTilingOptions.setTileSizes(tileSizes);
   Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
-      rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
   if (!tiledAndFusedOps)
     return failure();
+
+  // Tile the unfused loops;
+  SmallVector<Value, 4> unfusedLoopTileSizes;
+  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
+  for (auto tileSize : enumerate(tileSizes)) {
+    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
+      unfusedLoopTileSizes.push_back(zero);
+    else
+      unfusedLoopTileSizes.push_back(tileSize.value());
+  }
+  // Tile the loop only if there is a non-zero tile size.
+  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
+    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
+  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
+        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
+          return cst.getValue() != 0;
+        return true;
+      })) {
+    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
+    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
+    Optional<TiledLinalgOp> unfusedTiledOp =
+        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
+    if (!unfusedTiledOp)
+      return failure();
+    rewriter.eraseOp(tiledAndFusedOps->op);
+    tiledAndFusedOps->op = unfusedTiledOp->op;
+  }
+
   marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
   for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
     fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
   }
-  for (auto origProducerOp : tiledAndFusedOps->originalProducers)
+  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
     originalOpMarker.replaceLinalgMarker(rewriter,
                                          origProducerOp.getOperation());
+  }
   rewriter.updateRootInPlace(
       op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
   return success();

diff  --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index 2ddc66651db2..fa471811ef4e 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -47,7 +47,9 @@ module {
 //      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[TILE_N_2]]]
-//      CHECK:     linalg.fill(%[[SV3]], %[[CST]])
+//      CHECK:     %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK-SAME:       [%[[TILE_M]], %[[TILE_N]]]
+//      CHECK:     linalg.fill(%[[SV3_2]], %[[CST]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_basic_fusion_producer"
 //      CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
 //      CHECK:       %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
@@ -109,9 +111,12 @@ module {
 //      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
 //      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
 // CHECK-SAME:       [%[[M]], %[[TILE_N_2]]]
+//      CHECK:     %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
-// CHECK-SAME:       [%[[K]], %[[TILE_N]]]
-//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME:       [%[[K_2]], %[[TILE_N]]]
+//      CHECK:     %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]]
+// CHECK-SAME:       [%[[K_2]], %[[TILE_N]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_rhs_fusion_producer"
 //  CHECK-NOT:     linalg.fill
 //  CHECK-DAG:     %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
@@ -186,11 +191,16 @@ module {
 //      CHECK:     %[[N:.+]] = dim %[[ARG3]], %[[C1]]
 //      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
+//      CHECK:     %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[N]]]
+//      CHECK:     %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
-//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME:       [%[[TILE_M]], %[[K_2]]]
+//      CHECK:     %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K_2]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
-//      CHECK:     linalg.fill(%[[SV2]], %[[CST]])
+//      CHECK:     linalg.fill(%[[SV2_2]], %[[CST]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
 //  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
 //      CHECK:     scf.parallel (%[[IV1:.+]]) =
@@ -261,15 +271,18 @@ module {
 //      CHECK:     %[[N:.+]] = dim %[[ARG4]], %[[C1]]
 //      CHECK:     %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
+//      CHECK:     %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]]
 //      CHECK:     %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
 //      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M]], %[[K1]]]
-//      CHECK:     %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
+//      CHECK:     %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
+//      CHECK:     %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K2_2]]]
 //      CHECK:     linalg.matmul
 // CHECK-SAME:         __internal_linalg_transform__ = "after_lhs_fusion_producer"
 // CHECK-SAME:         ins(%[[SV3]], %[[SV4]]
 // CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME:         outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:         outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
 //  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
 //      CHECK:     scf.parallel (%[[IV1:.+]]) =
 // CHECK-SAME:       (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {

diff  --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
new file mode 100644
index 000000000000..a02c878ef341
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
+
+module {
+  func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                        %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %d0 = dim %arg0, %c0 : memref<?x?xf32>
+    %d1 = dim %arg1, %c1 : memref<?x?xf32>
+    %0 = alloc(%d0, %d1) : memref<?x?xf32>
+    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
+    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%0 : memref<?x?xf32>)
+    linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
+      outs(%arg3 : memref<?x?xf32>) {
+      ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
+        %5 = addf %arg4, %arg5 : f32
+        linalg.yield %5 : f32
+      }
+    return
+  }
+}
+
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//       CHECK: func @three_op_fusion
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//       CHECK:   %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
+//       CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
+//   CHECK-DAG:     %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
+//   CHECK-DAG:     %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
+//   CHECK-DAG:     %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
+//   CHECK-DAG:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+//   CHECK-DAG:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
+//       CHECK:     linalg.fill(%[[SV_TEMP]], %{{.+}})
+//       CHECK:     linalg.matmul
+//  CHECK-SAME:       ins(%[[SV_ARG0]], %[[SV_ARG1]]
+//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
+//  CHECK-SAME:       outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
+//       CHECK:     linalg.generic
+//  CHECK-SAME:       ins(%[[SV_TEMP]], %[[SV_ARG2]]
+//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
+//  CHECK-SAME:       outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
+//       CHECK:     scf.yield
+//       CHECK:   }
+
+// -----
+
+module {
+  func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
+			   %arg4: memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %m = dim %arg0, %c0 : memref<?x?xf32>
+    %n1 = dim %arg1, %c1 : memref<?x?xf32>
+    %n2 = dim %arg2, %c1 : memref<?x?xf32>
+    %n3 = dim %arg3, %c1 : memref<?x?xf32>
+    %0 = alloc(%m, %n1) : memref<?x?xf32>
+    %1 = alloc(%m, %n2) : memref<?x?xf32>
+    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
+    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%0 : memref<?x?xf32>)
+    linalg.fill(%1, %cst) : memref<?x?xf32>, f32
+    linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%1 : memref<?x?xf32>)
+    linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
+    linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg4 : memref<?x?xf32>)
+    return
+  }
+}
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//       CHECK: func @sequence_of_matmul
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//   CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
+//       CHECK:   %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
+//       CHECK:   %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
+//       CHECK:   scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
+//  CHECK-SAME:     step (%[[C16]]) {
+//       CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//       CHECK:     %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M]], %[[N2]]]
+//       CHECK:     %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
+//       CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+//       CHECK:     %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
+//       CHECK:     %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M_2]], %[[N3]]]
+//       CHECK:     %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M]], %[[N3]]]
+//       CHECK:     %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M]], %[[N1]]]
+//       CHECK:     %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
+//       CHECK:     %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
+//       CHECK:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M:.+]], %[[N0]]]
+//       CHECK:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
+//       CHECK:     linalg.fill(%[[SV_ALLOC1]], %{{.+}})
+//       CHECK:     linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
+//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+//  CHECK-SAME:        outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
+//       CHECK:     linalg.fill(%[[SV_ALLOC2]], %{{.+}})
+//       CHECK:     linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
+//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+//  CHECK-SAME:        outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
+//       CHECK:     linalg.fill(%[[SV_ARG4_2]], %{{.+}})
+//       CHECK:     linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
+//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
+//  CHECK-SAME:        outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
+//       CHECK:     scf.yield
+//       CHECK:   }
+

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index eb9e3a533138..5289b2d1055f 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -197,6 +197,44 @@ struct TestLinalgGreedyFusion
     }
   }
 };
+
+/// Pass to test tile and fuse of sequence of operations. Intended only for
+/// testing.
+struct TestLinalgTileAndFuseSequencePass
+    : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
+  TestLinalgTileAndFuseSequencePass() = default;
+  TestLinalgTileAndFuseSequencePass(
+      const TestLinalgTileAndFuseSequencePass &pass){};
+
+  ListOption<int64_t> tileSizes{
+      *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
+  }
+
+  void runOnFunction() override {
+    FuncOp funcOp = getOperation();
+    auto &blocks = funcOp.getBody().getBlocks();
+    if (!llvm::hasSingleElement(blocks)) {
+      return;
+    }
+    SmallVector<LinalgOp, 2> linalgOps =
+        llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
+    Aliases aliases;
+    LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
+    OpBuilder builder(funcOp.getContext());
+    Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
+        builder, linalgOps, dependenceGraph,
+        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
+            LinalgTilingLoopType::ParallelLoops));
+    if (!tileAndFuseOps)
+      return signalPassFailure();
+    for (auto op : linalgOps)
+      op.erase();
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -211,5 +249,12 @@ void registerTestLinalgGreedyFusion() {
       "test-linalg-greedy-fusion",
       "Test Linalg fusion by applying a greedy test transformation.");
 }
+void registerTestLinalgTileAndFuseSequencePass() {
+  PassRegistration<TestLinalgTileAndFuseSequencePass>
+      testTileAndFuseSequencePass(
+          "test-linalg-tile-and-fuse",
+          "Test Linalg tiling and fusion of a sequence of Linalg operations.");
+}
+
 } // namespace test
 } // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 4771b11b20e4..a0e36cf82534 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,6 +74,7 @@ void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
+void registerTestLinalgTileAndFuseSequencePass();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -141,6 +142,7 @@ void registerTestPasses() {
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
+  test::registerTestLinalgTileAndFuseSequencePass();
   test::registerTestLinalgTransforms();
   test::registerTestLivenessPass();
   test::registerTestLoopFusion();


        


More information about the Mlir-commits mailing list