[Mlir-commits] [mlir] c694588 - [mlir][Linalg] Add pattern to tile and fuse Linalg operations on buffers.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 30 14:57:13 PDT 2020


Author: MaheshRavishankar
Date: 2020-09-30T14:56:58-07:00
New Revision: c694588fc52a8845174fee06ad0bcfa338e87816

URL: https://github.com/llvm/llvm-project/commit/c694588fc52a8845174fee06ad0bcfa338e87816
DIFF: https://github.com/llvm/llvm-project/commit/c694588fc52a8845174fee06ad0bcfa338e87816.diff

LOG: [mlir][Linalg] Add pattern to tile and fuse Linalg operations on buffers.

The pattern is structured similarly to other patterns like
LinalgTilingPattern. The fusion pattern takes options that allow you
to fuse with producers of multiple operands at once.
- The pattern fuses only at the level that is known to be legal, i.e
  if a reduction loop in the consumer is tiled, then fusion should
  happen "before" this loop. Some refactoring of the fusion code is
  needed to fuse only where it is legal.
- Since the fusion on buffers uses the LinalgDependenceGraph, which is
  not mutable in place, the fusion pattern keeps the original
  operations in the IR but tags them with a marker that can later be
  used to find the original operations.

This change also fixes an issue with tiling and
distribution/interchange where, if the tile size of a loop was 0, it
wasn't accounted for in these.

Differential Revision: https://reviews.llvm.org/D88435

Added: 
    mlir/test/Dialect/Linalg/fusion-pattern.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 23d296c392ff..f51f7b913027 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -459,6 +459,24 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
             }));
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position of the buffer in the inputs + outputs list
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getIndexOfInputAndOutputBuffer",
+      /*args=*/(ins "Value":$value),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        Optional<unsigned> inputIndex = getIndexOfInput(value);
+        if (inputIndex.hasValue()) return inputIndex.getValue();
+        Optional<unsigned> outputIndex = getIndexOfOutputBuffer(value);
+        if (outputIndex.hasValue()) {
+          return $_op.getNumInputs() + outputIndex.getValue();
+        }
+        return llvm::None;
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Other interface methods.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 00a094d72076..a7f8c31e2264 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -18,6 +18,7 @@
 namespace mlir {
 namespace linalg {
 
+struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
 //===----------------------------------------------------------------------===//
@@ -30,6 +31,14 @@ struct TiledLinalgOp {
   SmallVector<Operation *, 8> loops;
 };
 
+struct TiledAndFusedLinalgOps {
+  LinalgOp op;
+  SmallVector<LinalgOp, 1> fusedProducers;
+  SmallVector<LinalgOp, 1> originalProducers;
+  SmallVector<Operation *, 4> fusedLoops;
+  SmallVector<Operation *, 4> unfusedLoops;
+};
+
 /// Populates patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -53,6 +62,71 @@ void populateConvVectorizationPatterns(
 Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                      const LinalgTilingOptions &options);
 
+/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
+/// three steps
+/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
+///   + fuse loops).
+/// - Tile just these loops of the consumer (root operation) and fuse with
+///   the producer.
+/// - Tile again the tiled consumer operation produced above to do rest of
+///   the tiling specified by the `tilingOptions`.
+///
+/// For example, consider the sequence of matmul below
+///
+///   linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
+///                 outs(%arg2 : memref<256x32xf32>)
+///   linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
+///                 outs(%arg4 : memref<256x32xf32>)
+///
+/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
+/// matmuls row-wise. For example, the fused computation for the above is shown
+/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
+/// along the rows of the matrix. The entire rows of the first matmul operation
+/// need to be computed before they can be used for the second matmul. The
+/// second matmul is further tiled (similar to normal tiling).
+///
+/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
+/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
+///   %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
+///     : memref<256x32xf32> to memref<16x32xf32, #map0>
+///   %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
+///     : memref<256x32xf32> to memref<16x32xf32, #map0>
+///   %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
+///     : memref<256x32xf32> to memref<16x32xf32, #map0>
+///   %3 = subview %arg1[0, 0] [32, 32] [1, 1]
+///     : memref<32x32xf32> to memref<32x32xf32, #map1>
+///   linalg.matmul
+///     ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
+///     outs(%0 : memref<16x32xf32, #map0>)
+///   scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
+///   scf.for %arg7 = %c0 to %c32 step %c4 {
+///     %4 = subview %0[0, %arg7] [16, 4] [1, 1]
+///       : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
+///     %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
+///       : memref<32x32xf32> to memref<4x8xf32, #map0>
+///     %6 = subview %1[0, %arg6] [16, 8] [1, 1]
+///       : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
+///     linalg.matmul
+///       ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
+///       outs(%6 : memref<16x8xf32, #map0>)
+///     }
+///     scf.yield
+///   }
+///   scf.yield
+/// }
+///
+/// The following tiling options are handled differently in tile+fuse (compared
+/// to tile only)
+/// - Interchange of the tiling loops is not supported right now.
+/// - Distribution is only done for the tile+fuse loops. The tiled loops
+///   generated by the second tiling are not distributed.
+Optional<TiledAndFusedLinalgOps>
+tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+                     const LinalgDependenceGraph &dependenceGraph,
+                     const LinalgTilingOptions &tilingOptions,
+                     const LinalgFusionOptions &fusionOptions);
+
 /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
 /// This is an in-place transformation controlled by `interchangeVector`.
 /// An empty vector is interpreted as the identity permutation and the
@@ -323,6 +397,63 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
   }
 };
 
+struct LinalgFusionOptions {
+  /// Optional list of operands indices to use for fusion. When unspecified,
+  /// only one fusion is done, i.e., the pattern returns after the first fusion.
+  Optional<DenseSet<unsigned>> indicesToFuse = None;
+  LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
+    indicesToFuse = DenseSet<unsigned>();
+    indicesToFuse->insert(operands.begin(), operands.end());
+    return *this;
+  }
+};
+
+struct LinalgBaseTileAndFusePattern : public RewritePattern {
+  LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
+                               const LinalgDependenceGraph &dependenceGraph,
+                               LinalgTilingOptions tilingOptions,
+                               LinalgFusionOptions fusionOptions,
+                               LinalgMarker marker = LinalgMarker(),
+                               LinalgMarker fusedOpMarker = LinalgMarker(),
+                               LinalgMarker originalOpMarker = LinalgMarker(),
+                               PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Dependence graph needed for fusion.
+  const LinalgDependenceGraph &dependenceGraph;
+  /// Options to control tiling.
+  LinalgTilingOptions tilingOptions;
+  /// Options to control fusion.
+  LinalgFusionOptions fusionOptions;
+  /// Marker to control application of the pattern.
+  LinalgMarker marker;
+  /// Marker set on the fused op after tile and fuse.
+  LinalgMarker fusedOpMarker;
+  /// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
+  /// to build the dependence graph change then the dependenceGraph needs to be
+  /// recomputed right now. To not invalidate the dependenceGraph as
+  /// transformation happens, the original producer can be tagged with a marker
+  /// that can be later used to delete the original operations.
+  LinalgMarker originalOpMarker;
+};
+
+template <typename OpTy>
+struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
+  LinalgTileAndFusePattern(MLIRContext *context,
+                           const LinalgDependenceGraph &dependenceGraph,
+                           LinalgTilingOptions tilingOptions,
+                           LinalgFusionOptions fusionOptions,
+                           LinalgMarker marker = LinalgMarker(),
+                           LinalgMarker fusedOpMarker = LinalgMarker(),
+                           LinalgMarker originalOpMarker = LinalgMarker(),
+                           PatternBenefit benefit = 1)
+      : LinalgBaseTileAndFusePattern(
+            OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
+            fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
+};
+
 ///
 /// Linalg interchange patterns.
 ///

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index aca5a981b003..76ce4eb30e7f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_UTILS_H_
 
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index dfc977daa207..8dadfe63e659 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -154,9 +155,9 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
   llvm_unreachable("Expect to be able to extract a view defining loop range");
 }
 
-static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
-                     unsigned consumerIdx, unsigned producerIdx,
-                     OperationFolder *folder) {
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+                     LinalgOp consumer, unsigned consumerIdx,
+                     OperationFolder *folder = nullptr) {
   assert(producer.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   assert(consumer.hasBufferSemantics() &&
@@ -174,9 +175,7 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   //   we can always identify a data dimension with a (at least one) loop
   //   dimension.
   AffineMap producerMap =
-      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
-          .cast<AffineMapAttr>()
-          .getValue();
+      producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue();
   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                     << ", producer map: " << producerMap << "\n");
 
@@ -185,10 +184,9 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
 
-  OpBuilder b(consumer.getOperation());
-  auto loc = consumer.getLoc();
   // Iterate over dimensions identified by the producer map for `producerIdx`.
   // This defines a subset of the loop ranges that we need to complete later.
+  auto loc = consumer.getLoc();
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
     loopRanges[posInProducerLoop] =
@@ -319,71 +317,380 @@ static bool isSameSubView(Value a, Value b) {
   return true;
 }
 
-static Optional<FusionInfo>
-fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
-                  const LinalgDependenceGraph &graph, OperationFolder *folder,
-                  LinalgDependenceGraph::DependenceType depType) {
-  assert(consumer.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
-                    << *consumer.getOperation());
-  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
-    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
-                      << *dependence.dependentOpView.op << "\n");
-    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
-
-    // Check that the dependence is indeed on the input `consumerIdx` view.
-    auto consumedView = dependence.indexingView;
-    if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
-      continue;
-
-    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
-    // whether it is a strict subview of the producer view.
-    auto producedView = dependence.dependentOpView.view;
-    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
-    // `consumerIdx` and `producerIdx` exist by construction.
-    LLVM_DEBUG(dbgs() << "\n"
-                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
-                      << "producer: " << *producer.getOperation() << " view: "
-                      << producedView << " output index: " << producerIdx);
-
-    // Must be a subview or a slice to guarantee there are loops we can fuse
-    // into.
-    auto subView = consumedView.getDefiningOp<SubViewOp>();
-    auto slice = consumedView.getDefiningOp<SliceOp>();
-    if (!subView && !slice) {
-      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
-      continue;
-    }
+static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
+                    const LinalgDependenceGraph &dependenceGraph) {
+  // Only consider RAW and WAW atm.
+  for (auto depType : {
+           LinalgDependenceGraph::DependenceType::RAW,
+           LinalgDependenceGraph::DependenceType::WAW,
+       }) {
+    for (auto dependence :
+         dependenceGraph.getDependencesInto(consumer, depType)) {
+      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
 
-    // Simple fusability checks.
-    if (!isFusableInto(graph, consumer, consumedView, producer))
-      continue;
+      // Check that the dependence is indeed on the input `consumerIdx` view.
+      auto consumedView = dependence.indexingView;
+      if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
+        continue;
 
-    // Fuse `producer` just before `consumer`.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(consumer.getOperation());
-    ScopedContext scope(b, consumer.getLoc());
-    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
-    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
-                              producerIdx, folder);
+      // Consumer consumes this view, `isStructurallyFusableProducer` also
+      // checks whether it is a strict subview of the producer view.
+      auto producedView = dependence.dependentOpView.view;
+      auto producerIdx =
+          producer.getIndexOfOutputBuffer(producedView).getValue();
+      // `consumerIdx` and `producerIdx` exist by construction.
+      LLVM_DEBUG(dbgs() << "\n"
+                        << LinalgDependenceGraph::getDependenceTypeStr(depType)
+                        << "producer: " << *producer.getOperation() << " view: "
+                        << producedView << " output index: " << producerIdx);
+      (void)producerIdx;
+
+      // Simple fusability checks.
+      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
+        continue;
 
-    return FusionInfo{producer, fusedProducer};
+      return dependence;
+    }
   }
-  return llvm::None;
+  return {};
 }
 
-// Only consider RAW and WAW atm.
 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
     const LinalgDependenceGraph &graph, OperationFolder *folder) {
-  for (auto dep : {
-           LinalgDependenceGraph::DependenceType::RAW,
-           LinalgDependenceGraph::DependenceType::WAW,
-       }) {
-    if (auto res =
-            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
-      return res;
+  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
+      findFusableProducer(consumer, consumerIdx, graph);
+  if (!fusableDependence)
+    return {};
+
+  LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+  Value producerView = fusableDependence->dependentOpView.view;
+  Value consumerView = fusableDependence->indexingView;
+
+  // Must be a subview or a slice to guarantee there are loops we can fuse
+  // into.
+  auto subView = consumerView.getDefiningOp<SubViewOp>();
+  auto slice = consumerView.getDefiningOp<SliceOp>();
+  if (!subView && !slice) {
+    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
+    return {};
+  }
+
+  // Fuse `producer` just before `consumer`.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(consumer.getOperation());
+  ScopedContext scope(b, consumer.getLoc());
+  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+  Optional<unsigned> producerIdxOpt =
+      producerOp.getIndexOfInputAndOutputBuffer(producerView);
+  assert(producerIdxOpt.hasValue() && "incorrect operand index");
+  unsigned producerIdx = producerIdxOpt.getValue();
+
+  auto fusedProducer =
+      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+  return FusionInfo{producerOp, fusedProducer};
+}
+
+/// Returns the positions of the loop in `op` that can be tiled based on the
+/// operations that are to be fused with it. For example, in a
+///
+///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
+///
+/// if the producer of %a needs to be fused with this op, only the `i` loop of
+/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
+/// fused, then no loops can be tiled while fusing.
+static DenseSet<unsigned> collectTileAndFuseLoops(
+    LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
+                     fusableDependences) {
+  // 1. Only parallel loops can be used for tile + fuse. Find the number of
+  // common outer parallel loops between the op and its producers being fused.
+  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
+    return linalgOp.iterator_types()
+        .getValue()
+        .take_while([](Attribute attr) -> bool {
+          return attr.cast<StringAttr>().getValue() ==
+                 getParallelIteratorTypeName();
+        })
+        .size();
+  };
+
+  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
+  for (auto dependence : fusableDependences) {
+    numOuterParallelLoops =
+        std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
+                                            dependence.dependentOpView.op)));
+  }
+
+  // Need to compute what tiled loops can be "fused". Given the precondition
+  // that the indexing map for the producer view is a projected permutation, we
+  // can assert that the producer iterates over the dimensions of the "fused
+  // view" only once. To be used as a fused loop the producer should use this
+  // loop to access the fused view. For example, consider
+  //
+  // ```
+  //   linalg.add ins(%a, %b) outs(%c)
+  //   linalg.matmul ins(%d, %c) outs(%e)
+  // ```
+  //
+  // if `linalg.add` has the semantics of `c = a + b`, then the following
+  // tile+fuse code is correct.
+  //
+  // ```
+  // for j ... += TSj
+  //   %sa = subview %a[0, %j][...]
+  //   %sb = subview %b[0, %j][...]
+  //   %sc = subview %c[0, %j][...]
+  //   %sd = subview %d[0, 0][...]
+  //   %se = subview %e[0, %j][...]
+  //   linalg.add ins(%sa, %sb) outs(%sc)
+  //   linalg.matmul ins(%sd, %sc) outs(%se)
+  // ```
+  //
+  // On the other hand tiling along i would be incorrect
+  //
+  // ```
+  // for %i .. += TSi
+  //   %sa = subview %a[%i, 0][...]
+  //   %sb = subview %b[%i, 0][...]
+  //   %sc = subview %c[%i, 0][...]
+  //   %sc2 = subview %c[0, 0][...]
+  //   %sd = subview %d[%i, 0][...]
+  //   %se = subview %e[%i, 0][...]
+  //   linalg.add ins(%sa, %sb) outs(%sc)
+  //   linalg.matmul ins(%sd, %sc2) outs(%se)
+  // ```
+  //
+  // The write to the subview `%sc` in `linalg.add` is performed after the read
+  // from it using `%sc2` violating the RAW dependence of the original code. To
+  // find such loops indexing map of the fused view in the consumer op is
+  // used. For the above example, this indexing map is
+  //
+  //   affine_map<(d0, d1, d2) -> (d2, d1)>
+  //
+  // Since d0 is not in the result expressions of this map, it is not treated as
+  // tile + fuse loop, (but d1 is).
+  //
+  // TODO: The above is probably restrictive and there might be a generalization
+  // of these that might allow for more fusion opportunities. Explore based on
+  // needs.
+  SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
+  for (auto dependence : fusableDependences) {
+    unsigned consumerIdx =
+        op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue();
+    AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
+    // Previously asserted that the consumerAccess map is a projected
+    // permutation, so all results are known to be AffineDimExprs. To remove
+    // this restriction walk the expression to find which dimensions of the
+    // consumer loop appear in the `consumerAccess`.
+    DenseSet<unsigned> positions;
+    for (auto expr : consumerAccess.getResults())
+      positions.insert(expr.cast<AffineDimExpr>().getPosition());
+    commonTilableLoops.emplace_back(std::move(positions));
+  }
+
+  // 2. Of the outer parallel loops, only the loops that were computed above as
+  // tilable for all the fused dependences can be used to tile and fuse.
+  DenseSet<unsigned> tilableParallelLoops;
+  for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
+    if (llvm::all_of(commonTilableLoops,
+                     [&](const DenseSet<unsigned> &tilableLoops) {
+                       return tilableLoops.count(index);
+                     }))
+      tilableParallelLoops.insert(index);
+  }
+  return tilableParallelLoops;
+}
+
+/// Find all dependences that are fusable.
+static Optional<
+    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
+findAllFusableDependences(LinalgOp op,
+                          const LinalgDependenceGraph &dependenceGraph,
+                          const LinalgFusionOptions &fusionOptions) {
+  SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
+      fusableDependences;
+  for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
+    if (fusionOptions.indicesToFuse &&
+        !fusionOptions.indicesToFuse->count(operand.index()))
+      continue;
+    Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+        fusableDependence =
+            findFusableProducer(op, operand.index(), dependenceGraph);
+    if (!fusableDependence)
+      continue;
+    // Make sure that the indexing map of the view used for fusion in the
+    // producer is a projected permutation.
+    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+    Value producerView = fusableDependence->dependentOpView.view;
+    unsigned producerIdx =
+        producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue();
+    AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+    if (!producerMap.isProjectedPermutation()) {
+      op.emitError("unhandled non permutation indexing map for fused view in "
+                   "producer for operand at index ")
+          << operand.index();
+      return llvm::None;
+    }
+    Value consumerView = fusableDependence->indexingView;
+    unsigned consumerIdx =
+        op.getIndexOfInputAndOutputBuffer(consumerView).getValue();
+    if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
+      op.emitError(
+          "unhandled case where indexing map for fused view in the consumer is "
+          "not a projected permuration while fusing at index ")
+          << operand.index();
+      return llvm::None;
+    }
+    fusableDependences.push_back(*fusableDependence);
+    if (!fusionOptions.indicesToFuse)
+      break;
+  }
+  return fusableDependences;
+}
+
+static bool isZero(Value v) {
+  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
+    return cst.getValue() == 0;
+  return false;
+}
+
+template <typename LoopType>
+static Optional<TiledAndFusedLinalgOps>
+tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
+                         const LinalgDependenceGraph &dependenceGraph,
+                         const LinalgTilingOptions &tilingOptions,
+                         const LinalgFusionOptions &fusionOptions) {
+  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+  // Some of the tiling options might not be supportable with tile and fuse.
+  // TODO: Support interchange with tile + fuse.
+  if (!tilingOptions.interchangeVector.empty()) {
+    op.emitError("unable to handle tile and fuse with interchange");
+    return llvm::None;
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+  ScopedContext scope(rewriter, op.getLoc());
+
+  // Find all the producers.
+  Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
+      fusableDependencesOpt =
+          findAllFusableDependences(op, dependenceGraph, fusionOptions);
+  if (!fusableDependencesOpt)
+    return llvm::None;
+  ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
+      *fusableDependencesOpt);
+
+  // Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle instead of
+  // adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumLoops();
+  SmallVector<Value, 4> tileSizeVector =
+      tilingOptions.tileSizeComputationFunction(rewriter, op);
+  if (tileSizeVector.size() < nLoops) {
+    auto zero = std_constant_index(0);
+    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
+  }
+
+  TiledAndFusedLinalgOps ret;
+
+  // Find the loops that can be tiled and fused.
+  DenseSet<unsigned> tileFuseLoops =
+      collectTileAndFuseLoops(op, fusableDependences);
+
+  // If there are no fusable dependences or there are no tile+fusable loops,
+  // just return.
+  if (fusableDependences.empty() || tileFuseLoops.empty()) {
+    return llvm::None;
+  }
+
+  // Get the tile sizes for the first and second tiling steps. For the first
+  // step the tile sizes are set to zero for the loops that aren't
+  // fused. Similarly for the second step, the tile sizes are set to zero for
+  // the loops that are fused. For example, for the following input
+  //
+  // ```
+  //   linalg.add ins(%a, %b) outs(%c)
+  //   linalg.matmul ins(%d, %c) outs(%e)
+  // ```
+  //
+  // if the tile sizes of the `{i, j, k}` loops were given as `{ti, tj, tk}`
+  // respectively, then since only `j` can be tiled and fused, the tile sizes
+  // would be `{0, tj, 0}` for the first tiling that tiles just the fusable
+  // loops. The second tiling would use tile sizes of `{ti, 0, tk}` to tile
+  // the tiled matmul generated by the first tiling step.
+  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
+  for (auto tileSize : enumerate(tileSizeVector)) {
+    auto zero = std_constant_index(0);
+    if (tileFuseLoops.count(tileSize.index())) {
+      tileAndFuseSizes.push_back(tileSize.value());
+      tileSizes.push_back(zero);
+    } else {
+      tileSizes.push_back(tileSize.value());
+      tileAndFuseSizes.push_back(zero);
+    }
+  }
+
+  // Tile for the loops that can be fused.
+  LinalgTilingOptions firstTilingOptions = tilingOptions;
+  firstTilingOptions.setTileSizes(tileAndFuseSizes);
+  Optional<TiledLinalgOp> firstTiledOp =
+      tileLinalgOp(rewriter, op, firstTilingOptions);
+  if (!firstTiledOp)
+    return llvm::None;
+  ret.op = firstTiledOp->op;
+  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
+
+  rewriter.setInsertionPoint(ret.op);
+  // Fuse the operands.
+  for (auto producer : enumerate(fusableDependences)) {
+    LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
+    unsigned producerIdx = producerOp
+                               .getIndexOfInputAndOutputBuffer(
+                                   producer.value().dependentOpView.view)
+                               .getValue();
+    unsigned consumerIdx =
+        op.getIndexOfInputAndOutputBuffer(producer.value().indexingView)
+            .getValue();
+    LinalgOp fusedOp =
+        fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
+    ret.fusedProducers.push_back(fusedOp);
+    ret.originalProducers.push_back(producerOp);
+  }
+
+  if (!llvm::all_of(tileSizes, isZero)) {
+    // Tile the remaining loops of the root operation.
+    LinalgTilingOptions secondTilingOptions = tilingOptions;
+    // The distribution is done only for the tile+fused loops.
+    secondTilingOptions.distribution = llvm::None;
+    secondTilingOptions.setTileSizes(tileSizes);
+    Optional<TiledLinalgOp> secondTiledOp =
+        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
+    if (!secondTiledOp)
+      return llvm::None;
+    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
+                            secondTiledOp->loops.end());
+    rewriter.eraseOp(ret.op);
+    ret.op = secondTiledOp->op;
+  }
+
+  return ret;
+}
+
+Optional<TiledAndFusedLinalgOps>
+mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+                                   const LinalgDependenceGraph &dependenceGraph,
+                                   const LinalgTilingOptions &tilingOptions,
+                                   const LinalgFusionOptions &fusionOptions) {
+  switch (tilingOptions.loopType) {
+  case LinalgTilingLoopType::Loops:
+    return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
+                                                tilingOptions, fusionOptions);
+  case LinalgTilingLoopType::ParallelLoops:
+    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
+        rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+  default:;
   }
   return llvm::None;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 3db801bc2d57..68d69549611c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -318,25 +318,10 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
 }
 
 template <typename LoopTy>
-Optional<TiledLinalgOp> static tileLinalgOpImpl(
-    OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(op);
-  ScopedContext scope(b, op.getLoc());
-
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
-  // 1. Enforce the convention that "tiling by zero" skips tiling a particular
-  // dimension. This convention is significantly simpler to handle instead of
-  // adjusting affine maps to account for missing dimensions.
+static Optional<TiledLinalgOp>
+tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
+                 const LinalgTilingOptions &options) {
   auto nLoops = op.getNumLoops();
-  SmallVector<Value, 4> tileSizeVector =
-      options.tileSizeComputationFunction(b, op);
-  if (tileSizeVector.size() < nLoops) {
-    auto zero = std_constant_index(0);
-    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
-  }
-
-  ArrayRef<Value> tileSizes = tileSizeVector;
   // Initial tile sizes may be too big, only take the first nLoops.
   tileSizes = tileSizes.take_front(nLoops);
 
@@ -350,17 +335,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
       return llvm::None;
   }
 
-  // If interchangeVector is empty, use the identity. Build the permutation map
-  // otherwise.
-  auto invPermutationMap =
-      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
-  if (!options.interchangeVector.empty())
-    invPermutationMap = inversePermutation(AffineMap::getPermutationMap(
-        options.interchangeVector, b.getContext()));
-  if (!invPermutationMap)
-    return llvm::None;
-
-  // 2. Build the tiled loop ranges.
+  // 1. Build the tiled loop ranges.
   auto allViewSizes = getViewSizes(b, op);
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (asserted in the inverse calculation).
@@ -374,17 +349,39 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
   SmallVector<SubViewOp::Range, 4> loopRanges;
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
-      b, scope.getLocation(), viewSizesToLoopsMap, allViewSizes, tileSizes);
-  if (!options.interchangeVector.empty())
-    applyPermutationToVector(loopRanges, options.interchangeVector);
+      b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);
+  SmallVector<Attribute, 4> iteratorTypes;
+  for (auto attr :
+       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
+    if (loopIndexToRangeIndex.count(attr.index()))
+      iteratorTypes.push_back(attr.value());
+  }
+  // If interchangeVector is empty, use the identity. Build the permutation map
+  // otherwise.
+  auto invPermutationMap =
+      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
+  if (!options.interchangeVector.empty()) {
+    // Based on the pruned iterations (due to zero tile size), recompute the
+    // interchange vector.
+    SmallVector<unsigned, 4> interchangeVector;
+    interchangeVector.reserve(options.interchangeVector.size());
+    for (auto pos : options.interchangeVector) {
+      auto it = loopIndexToRangeIndex.find(pos);
+      if (it == loopIndexToRangeIndex.end())
+        continue;
+      interchangeVector.push_back(it->second);
+    }
+    invPermutationMap = inversePermutation(
+        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
+    if (!invPermutationMap)
+      return llvm::None;
+    applyPermutationToVector(loopRanges, interchangeVector);
+    applyPermutationToVector(iteratorTypes, interchangeVector);
+  }
 
-  // 3. Create the tiled loops.
+  // 2. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs;
-  SmallVector<Attribute, 4> iteratorTypes =
-      llvm::to_vector<4>(op.iterator_types().cast<ArrayAttr>().getValue());
-  if (!options.interchangeVector.empty())
-    applyPermutationToVector(iteratorTypes, options.interchangeVector);
   GenerateLoopNest<LoopTy>::doit(
       loopRanges, /*iterArgInitValues*/ {}, iteratorTypes,
       [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
@@ -410,10 +407,10 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
       },
       options.distribution);
 
-  // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
+  // 3. Transform index arguments of `linalg.generic` w.r.t. the tiling.
   transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
 
-  // 5. Gather the newly created loops and return them with the new op.
+  // 4. Gather the newly created loops and return them with the new op.
   SmallVector<Operation *, 8> loops;
   loops.reserve(ivs.size());
   for (auto iv : ivs) {
@@ -429,14 +426,38 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
   return TiledLinalgOp{res, loops};
 }
 
+template <typename LoopTy>
+Optional<TiledLinalgOp> static tileLinalgOpImpl(
+    OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  ScopedContext scope(b, op.getLoc());
+
+  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+  // Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle than
+  // adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumLoops();
+  SmallVector<Value, 4> tileSizeVector =
+      options.tileSizeComputationFunction(b, op);
+  if (tileSizeVector.size() < nLoops) {
+    auto zero = std_constant_index(0);
+    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
+  }
+
+  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
+}
+
 Optional<TiledLinalgOp>
 mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                            const LinalgTilingOptions &options) {
-  if (options.loopType == LinalgTilingLoopType::Loops)
+  switch (options.loopType) {
+  case LinalgTilingLoopType::Loops:
     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
-  if (options.loopType == LinalgTilingLoopType::ParallelLoops)
+  case LinalgTilingLoopType::ParallelLoops:
     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
-  // TODO: Impl tiling to affine loops when it makes sense.
+  default:;
+  }
   return llvm::None;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c1aad620fe08..56652cbcb527 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -129,6 +129,43 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
   return success();
 }
 
+mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
+    StringRef opName, MLIRContext *context,
+    const LinalgDependenceGraph &dependenceGraph,
+    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
+    LinalgMarker marker, LinalgMarker fusedOpMarker,
+    LinalgMarker originalOpMarker, PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context),
+      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
+      fusionOptions(fusionOptions), marker(marker),
+      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
+
+LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  if (!linalgOp.hasBufferSemantics())
+    return failure();
+
+  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
+      rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+  if (!tiledAndFusedOps)
+    return failure();
+  marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
+  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
+    fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
+  }
+  for (auto origProducerOp : tiledAndFusedOps->originalProducers)
+    originalOpMarker.replaceLinalgMarker(rewriter,
+                                         origProducerOp.getOperation());
+  rewriter.updateRootInPlace(
+      op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
+  return success();
+}
+
 /// Linalg base interchange pattern.
 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
     StringRef opName, MLIRContext *context,

diff  --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
new file mode 100644
index 000000000000..61e5b746deac
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -0,0 +1,297 @@
+// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
+
+module {
+  func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                     %arg2: memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    linalg.fill(%arg2, %cst) : memref<?x?xf32>, f32
+    linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
+      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+    return
+  }
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//      CHECK: func @basic_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//  CHECK-DAG:   %[[CST:.+]] = constant 0.0{{.*}} : f32
+//  CHECK-DAG:   linalg.fill(%[[ARG2]], %[[CST]])
+// CHECK-SAME:   __internal_linalg_transform__ = "after_basic_fusion_original"
+//  CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[N:.+]] = dim %[[ARG1]], %[[C1]]
+//      CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) =
+// CHECK-SAME:     to (%[[M]], %[[N]])
+// CHECK-SAME:     step (%[[C32]], %[[C64]]) {
+//      CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[K:.+]] = dim %[[ARG0]], %[[C1]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
+//      CHECK:     %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
+//      CHECK:     %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG1]][0, %[[IV1]]]
+// CHECK-SAME:       %[[K_2]], %[[TILE_N]]
+//      CHECK:     %[[M_2:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+//      CHECK:     %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK-SAME:       [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:     linalg.fill(%[[SV3]], %[[CST]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_basic_fusion_producer"
+//      CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
+//      CHECK:       %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
+//      CHECK:       %[[SV4:.+]] = subview %[[SV1]][0, %[[IV2]]]
+// CHECK-SAME:         [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:       %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]]
+//      CHECK:       %[[SV5:.+]] = subview %[[SV2]][%[[IV2]], 0]
+// CHECK-SAME:         [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:       linalg.matmul
+// CHECK-SAME:         __internal_linalg_transform__ = "after_basic_fusion"
+// CHECK-SAME:         ins(%[[SV4]], %[[SV5]]
+// CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:         outs(%[[SV3]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_basic_fusion_original"
+
+// -----
+
+module {
+  func @rhs_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                              %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    linalg.copy(%arg1, %arg2) : memref<?x?xf32>, memref<?x?xf32>
+    linalg.fill(%arg3, %cst) : memref<?x?xf32>, f32
+    linalg.matmul {__internal_linalg_transform__ = "rhs_fusion"}
+      ins(%arg0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg3 : memref<?x?xf32>)
+    return
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//      CHECK: func @rhs_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//  CHECK-DAG:   %[[CST:.+]] = constant 0.0{{.*}} : f32
+//  CHECK-DAG:   linalg.copy(%[[ARG1]], %[[ARG2]])
+// CHECK-SAME:   __internal_linalg_transform__ = "after_rhs_fusion_original"
+//  CHECK-DAG:   %[[N:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:   scf.parallel (%[[IV0:.+]]) =
+// CHECK-SAME:     (%[[C0]]) to (%[[N]]) step (%[[C64]]) {
+//      CHECK:     %[[K:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:     %[[TILE_N:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N]]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG2]][0, %[[IV0]]]
+// CHECK-SAME:       [%[[K]], %[[TILE_N]]]
+//      CHECK:     %[[M:.+]] = dim %[[ARG3]], %[[C0]]
+//      CHECK:     %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
+//      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
+// CHECK-SAME:       [%[[M]], %[[TILE_N_2]]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
+// CHECK-SAME:       [%[[K]], %[[TILE_N]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_rhs_fusion_producer"
+//  CHECK-NOT:     linalg.fill
+//  CHECK-DAG:     %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:     %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
+//      CHECK:     scf.parallel (%[[IV1:.+]]) =
+// CHECK-SAME:       (%[[C0]]) to (%[[M_2]]) step (%[[C32]]) {
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K_2]] step %[[C16]] {
+//      CHECK:         %[[TILE_M:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[M_2]]]
+//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]]
+//      CHECK:         %[[SV4:.+]] = subview %[[ARG0]][%[[IV1]], %[[IV2]]]
+// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
+//      CHECK:         %[[SV5:.+]] = subview %[[SV1]][%[[IV2]], 0]
+// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:         %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[M]]]
+//      CHECK:         %[[SV6:.+]] = subview %[[SV2]][%[[IV1]], 0]
+// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:           __internal_linalg_transform__ = "after_rhs_fusion"
+// CHECK-SAME:           ins(%[[SV4]], %[[SV5]]
+// CHECK-SAME:             : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:           outs(%[[SV6]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:       }
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_rhs_fusion_original"
+
+
+// -----
+
+module {
+  func @two_operand_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                              %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    linalg.copy(%arg0, %arg1) : memref<?x?xf32>, memref<?x?xf32>
+    linalg.fill(%arg3, %cst) : memref<?x?xf32>, f32
+    linalg.matmul {__internal_linalg_transform__ = "two_operand_fusion"}
+      ins(%arg1, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg3 : memref<?x?xf32>)
+    return
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//      CHECK: func @two_operand_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//  CHECK-DAG:   %[[CST:.+]] = constant 0.0{{.*}} : f32
+//      CHECK:   linalg.copy(%[[ARG0]], %[[ARG1]])
+// CHECK-SAME:     __internal_linalg_transform__ = "after_two_operand_fusion_original"
+//      CHECK:   linalg.fill(%[[ARG3]], %[[CST]])
+// CHECK-SAME:     __internal_linalg_transform__ = "after_two_operand_fusion_original"
+//  CHECK-DAG:   %[[M:.+]] = dim %[[ARG1]], %[[C0]]
+//      CHECK:   scf.parallel (%[[IV0:.+]]) =
+// CHECK-SAME:     (%[[C0]]) to (%[[M]]) step (%[[C32]]) {
+//      CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[K:.+]] = dim %[[ARG1]], %[[C1]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG1]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
+//      CHECK:     %[[M_2:.+]] = dim %[[ARG3]], %[[C0]]
+//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+//      CHECK:     %[[N:.+]] = dim %[[ARG3]], %[[C1]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
+//      CHECK:     linalg.fill(%[[SV2]], %[[CST]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
+//  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:     scf.parallel (%[[IV1:.+]]) =
+// CHECK-SAME:       (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
+//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]]
+//      CHECK:         %[[SV4:.+]] = subview %[[SV1]][0, %[[IV2]]]
+// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:         %[[K_2:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]]
+//      CHECK:         %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
+//      CHECK:         %[[SV5:.+]] = subview %[[ARG2]][%[[IV2]], %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:         %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
+//      CHECK:         %[[SV6:.+]] = subview %[[SV2]][0, %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:           __internal_linalg_transform__ = "after_two_operand_fusion"
+// CHECK-SAME:           ins(%[[SV4]], %[[SV5]]
+// CHECK-SAME:             : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:           outs(%[[SV6]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:       }
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_two_operand_fusion_original"
+
+// -----
+
+module {
+  func @matmul_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                      %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
+                      %arg4: memref<?x?xf32>) {
+    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+    linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+      ins(%arg2, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg4 : memref<?x?xf32>)
+    return
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//      CHECK: func @matmul_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_lhs_fusion_original"
+//  CHECK-DAG:   %[[M:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:   scf.parallel (%[[IV0:.+]]) =
+// CHECK-SAME:     (%[[C0]]) to (%[[M]]) step (%[[C32]]) {
+//      CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[K2:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K2]]]
+//      CHECK:     %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
+//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+//      CHECK:     %[[N:.+]] = dim %[[ARG4]], %[[C1]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
+//      CHECK:     %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K1]]]
+//      CHECK:     %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
+//      CHECK:     linalg.matmul
+// CHECK-SAME:         __internal_linalg_transform__ = "after_lhs_fusion_producer"
+// CHECK-SAME:         ins(%[[SV3]], %[[SV4]]
+// CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:         outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
+//  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
+//      CHECK:     scf.parallel (%[[IV1:.+]]) =
+// CHECK-SAME:       (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] {
+//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]]
+//      CHECK:         %[[SV6:.+]] = subview %[[SV1]][0, %[[IV2]]]
+// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:         %[[K_2:.+]] = dim %[[ARG3]], %[[C0]]
+//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]]
+//      CHECK:         %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
+//      CHECK:         %[[SV7:.+]] = subview %[[ARG3]][%[[IV2]], %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:         %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
+//      CHECK:         %[[SV8:.+]] = subview %[[SV2]][0, %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:           __internal_linalg_transform__ = "after_lhs_fusion"
+// CHECK-SAME:           ins(%[[SV6]], %[[SV7]]
+// CHECK-SAME:             : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:           outs(%[[SV8]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:       }
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_lhs_fusion_original"

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 3c82554fa13a..5bf606209ec2 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_library(MLIRTestTransforms
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
   TestInlining.cpp
+  TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
   TestLinalgTransforms.cpp
   TestLiveness.cpp

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
new file mode 100644
index 000000000000..9a376c548900
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -0,0 +1,112 @@
+//===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Linalg fusion patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct TestLinalgFusionTransforms
+    : public PassWrapper<TestLinalgFusionTransforms, FunctionPass> {
+  TestLinalgFusionTransforms() = default;
+  TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
+                    StandardOpsDialect>();
+  }
+
+  void runOnFunction() override;
+};
+} // namespace
+
+static void fillFusionPatterns(MLIRContext *context,
+                               const LinalgDependenceGraph &dependenceGraph,
+                               OwningRewritePatternList &patterns) {
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions(),
+      LinalgMarker(Identifier::get("basic_fusion", context),
+                   Identifier::get("after_basic_fusion", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_basic_fusion_producer", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_basic_fusion_original", context)));
+
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions().setIndicesToFuse({0}),
+      LinalgMarker(Identifier::get("lhs_fusion", context),
+                   Identifier::get("after_lhs_fusion", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_lhs_fusion_producer", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_lhs_fusion_original", context)));
+
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions().setIndicesToFuse({1}),
+      LinalgMarker(Identifier::get("rhs_fusion", context),
+                   Identifier::get("after_rhs_fusion", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_rhs_fusion_producer", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_rhs_fusion_original", context)));
+
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions().setIndicesToFuse({0, 2}),
+      LinalgMarker(Identifier::get("two_operand_fusion", context),
+                   Identifier::get("after_two_operand_fusion", context)),
+      LinalgMarker(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_two_operand_fusion_producer", context)),
+      LinalgMarker(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_two_operand_fusion_original", context)));
+}
+
+static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
+  OwningRewritePatternList fusionPatterns;
+  Aliases alias;
+  LinalgDependenceGraph dependenceGraph =
+      LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
+  fillFusionPatterns(context, dependenceGraph, fusionPatterns);
+  applyPatternsAndFoldGreedily(funcOp, fusionPatterns);
+}
+
+void TestLinalgFusionTransforms::runOnFunction() {
+  applyFusionPatterns(&getContext(), getFunction());
+}
+
+namespace mlir {
+void registerTestLinalgFusionTransforms() {
+  PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
+      "test-linalg-fusion-transform-patterns",
+      "Test Linalg fusion transformation patterns by applying them greedily.");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index aed8b0ae818b..0389c70be3d6 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -58,6 +58,7 @@ void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestInterfaces();
+void registerTestLinalgFusionTransforms();
 void registerTestLinalgHoisting();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
@@ -114,6 +115,7 @@ void registerTestPasses() {
   registerTestExpandTanhPass();
   registerTestGpuMemoryPromotionPass();
   registerTestInterfaces();
+  registerTestLinalgFusionTransforms();
   registerTestLinalgHoisting();
   registerTestLinalgTransforms();
   registerTestLivenessPass();


        


More information about the Mlir-commits mailing list