[Mlir-commits] [mlir] 5ca2085 - [mlir][Linalg] Improve the logic to perform tile and fuse with better dependence tracking.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 12 00:25:46 PST 2020


Author: MaheshRavishankar
Date: 2020-11-12T00:25:24-08:00
New Revision: 5ca20851e44c906a446e1860f01ee5b0f6f795a6

URL: https://github.com/llvm/llvm-project/commit/5ca20851e44c906a446e1860f01ee5b0f6f795a6
DIFF: https://github.com/llvm/llvm-project/commit/5ca20851e44c906a446e1860f01ee5b0f6f795a6.diff

LOG: [mlir][Linalg] Improve the logic to perform tile and fuse with better dependence tracking.

This change does two main things:
1) An operation might have multiple dependences to the same
   producer. Not tracking them correctly can result in incorrect code
   generation with fusion. To rectify this, the dependence tracking
   also needs to record the operand number in the consumer.
2) Improve the logic used to find the fused loops, making it easier to
   follow. The only constraint for fusion is that linalg ops (on
   buffers) have update semantics for the result. Fusion must be
   such that each iteration of the fused loop (which is also a
   tiled loop) touches only one (disjoint) tile of the output. This
   could be relaxed by allowing recomputation, which is the default
   when operands are tensors, or could be made legal in the future by
   promoting the fused view.

Differential Revision: https://reviews.llvm.org/D90579

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Linalg/fusion-pattern.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 60d157b91c73..372f6c4e01a1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -45,7 +45,7 @@ class LinalgDependenceGraph {
 public:
   struct LinalgOpView {
     Operation *op;
-    Value view;
+    unsigned operandIndex;
   };
   struct LinalgDependenceGraphElem {
     // dependentOpView may be either:
@@ -55,7 +55,7 @@ class LinalgDependenceGraph {
     // View in the op that is used to index in the graph:
     //   1. src in the case of dependencesFromDstGraphs.
     //   2. dst in the case of dependencesIntoGraphs.
-    Value indexingView;
+    LinalgOpView indexingOpView;
   };
   using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
   using DependenceGraph = DenseMap<Operation *, LinalgDependences>;

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 6646964a983e..ec7167485104 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -555,7 +555,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the position of the shaped operand in the operand list.
+        Return the first position of the shaped operand in the operand list.
       }],
       /*retTy=*/"Optional<unsigned>",
       /*methodName=*/"getIndexOfShapedOperand",
@@ -573,6 +573,67 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return llvm::None;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the operand index given the input index. Returns None
+        of the input index is invalid.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getOperandIndexForInputIndex",
+      /*args=*/(ins "unsigned":$input_index),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (input_index >= $_op.getNumInputs())
+          return llvm::None;
+        return input_index;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the operand index given the output index. Returns None
+        of the output index is invalid.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getOperandIndexForOutputIndex",
+      /*args=*/(ins "unsigned":$output_index),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (output_index >= $_op.getNumOutputs())
+          return llvm::None;
+        return output_index + $_op.getNumInputs();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the input index given the operand index. Return None
+        if the operand index doesnt corresponding to an input.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getInputIndex",
+      /*args=*/(ins "unsigned":$operand_index),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+         if (operand_index >= $_op.getNumInputs())
+           return llvm::None;
+         return operand_index;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the output index given the operand index. Return None
+        if the operand index doesnt corresponding to an output.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getOutputIndex",
+      /*args=*/(ins "unsigned":$operand_index),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+         if (operand_index < $_op.getNumInputs() ||
+             operand_index >= $_op.getNumInputs() + $_op.getNumOutputs())
+           return llvm::None;
+         return operand_index - $_op.getNumInputs();
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Other interface methods.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e34150d26594..54357940b250 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/Bufferize.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 class BufferizeTypeConverter;
@@ -429,12 +430,10 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
 };
 
 struct LinalgFusionOptions {
-  /// Optional list of operands indices to use for fusion. When unspecified,
-  /// only one fusion is done, i.e., the pattern returns after the first fusion.
-  Optional<DenseSet<unsigned>> indicesToFuse = None;
+  /// List of operands indices to use for fusion.
+  llvm::SmallSet<unsigned, 1> indicesToFuse = {};
   LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
-    indicesToFuse = DenseSet<unsigned>();
-    indicesToFuse->insert(operands.begin(), operands.end());
+    indicesToFuse.insert(operands.begin(), operands.end());
     return *this;
   }
 };

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index dd4960a02c5c..c450024dcb57 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -323,6 +323,9 @@ AffineMap inversePermutation(AffineMap map);
 /// ```
 AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
 
+AffineMap getProjectedMap(AffineMap map,
+                          ArrayRef<unsigned> projectedDimensions);
+
 inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
   return os;

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index cddca0de8343..01e167d1f0aa 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -108,12 +108,14 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
 void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
                                               LinalgOpView indexingOpView,
                                               LinalgOpView dependentOpView) {
-  LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t"
-                    << *indexingOpView.op << " -> " << *dependentOpView.op);
+  LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
+                    << *indexingOpView.op << ", " << indexingOpView.operandIndex
+                    << ") -> \n\t\t(" << *dependentOpView.op << ", "
+                    << dependentOpView.operandIndex << ")");
   dependencesFromGraphs[dt][indexingOpView.op].push_back(
-      LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
+      LinalgDependenceGraphElem{dependentOpView, indexingOpView});
   dependencesIntoGraphs[dt][dependentOpView.op].push_back(
-      LinalgDependenceGraphElem{indexingOpView, dependentOpView.view});
+      LinalgDependenceGraphElem{indexingOpView, dependentOpView});
 }
 
 LinalgDependenceGraph::dependence_range
@@ -147,39 +149,55 @@ LinalgDependenceGraph::getDependencesInto(
 }
 
 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
-  for (auto srcView : src.getOutputBuffers()) { // W
+  for (auto srcView : llvm::enumerate(src.getOutputBuffers())) { // W
+    unsigned srcIndex =
+        src.getOperandIndexForOutputIndex(srcView.index()).getValue();
     // RAW graph
-    for (auto dstView : dst.getInputBuffers()) { // R
-      if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
+    for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
+      if (aliases.alias(srcView.value(),
+                        dstView.value())) { // if alias, fill RAW
+        unsigned dstIndex =
+            dst.getOperandIndexForInputIndex(dstView.index()).getValue();
         addDependenceElem(DependenceType::RAW,
-                          LinalgOpView{src.getOperation(), srcView},
-                          LinalgOpView{dst.getOperation(), dstView});
+                          LinalgOpView{src.getOperation(), srcIndex},
+                          LinalgOpView{dst.getOperation(), dstIndex});
       }
     }
     // WAW graph
-    for (auto dstView : dst.getOutputBuffers()) { // W
-      if (aliases.alias(srcView, dstView)) {      // if alias, fill WAW
+    for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
+      if (aliases.alias(srcView.value(),
+                        dstView.value())) { // if alias, fill WAW
+        unsigned dstIndex =
+            dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
         addDependenceElem(DependenceType::WAW,
-                          LinalgOpView{src.getOperation(), srcView},
-                          LinalgOpView{dst.getOperation(), dstView});
+                          LinalgOpView{src.getOperation(), srcIndex},
+                          LinalgOpView{dst.getOperation(), dstIndex});
       }
     }
   }
-  for (auto srcView : src.getInputBuffers()) { // R
+  for (auto srcView : llvm::enumerate(src.getInputBuffers())) { // R
+    unsigned srcIndex =
+        src.getOperandIndexForInputIndex(srcView.index()).getValue();
     // RAR graph
-    for (auto dstView : dst.getInputBuffers()) { // R
-      if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
+    for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
+      if (aliases.alias(srcView.value(),
+                        dstView.value())) { // if alias, fill RAR
+        unsigned dstIndex =
+            dst.getOperandIndexForInputIndex(dstView.index()).getValue();
         addDependenceElem(DependenceType::RAR,
-                          LinalgOpView{src.getOperation(), srcView},
-                          LinalgOpView{dst.getOperation(), dstView});
+                          LinalgOpView{src.getOperation(), srcIndex},
+                          LinalgOpView{dst.getOperation(), dstIndex});
       }
     }
     // WAR graph
-    for (auto dstView : dst.getOutputBuffers()) { // W
-      if (aliases.alias(srcView, dstView)) {      // if alias, fill WAR
+    for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
+      if (aliases.alias(srcView.value(),
+                        dstView.value())) { // if alias, fill WAR
+        unsigned dstIndex =
+            dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
         addDependenceElem(DependenceType::WAR,
-                          LinalgOpView{src.getOperation(), srcView},
-                          LinalgOpView{dst.getOperation(), dstView});
+                          LinalgOpView{src.getOperation(), srcIndex},
+                          LinalgOpView{dst.getOperation(), dstIndex});
       }
     }
   }
@@ -227,12 +245,16 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
       // Skip if not interleaved.
       if (interimPos >= dstPos || interimPos <= srcPos)
         continue;
-      if (view && !aliases.alias(view, dependence.indexingView))
+      linalg::LinalgOp consumer =
+          cast<linalg::LinalgOp>(dependence.indexingOpView.op);
+      Value consumerView =
+          consumer.getShapedOperand(dependence.indexingOpView.operandIndex);
+      if (view && !aliases.alias(view, consumerView))
         continue;
       auto *op = dependence.dependentOpView.op;
       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
                         << getDependenceTypeStr(dt) << ": " << *src << " -> "
-                        << *op << " on " << dependence.indexingView);
+                        << *op << " on " << consumerView);
       res.push_back(op);
     }
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index ac35d87a8413..969bea4a4549 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -24,10 +24,12 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
+#include <set>
+
 #define DEBUG_TYPE "linalg-fusion"
 
 using namespace mlir;
@@ -95,8 +97,8 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
   for (auto en : llvm::enumerate(op.getShapedOperands())) {
     unsigned shapedOperandIdx = en.index();
     AffineMap map = op.getIndexingMap(shapedOperandIdx);
-    LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx
-                      << " with indexingMap: " << map << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
+                            << " with indexingMap: " << map << "\n");
     SmallVector<Value, 4> offsets, sizes, strides;
     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
     Value shape = en.value();
@@ -169,16 +171,18 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
   for (auto en : llvm::enumerate(ios)) {
     unsigned idx = en.index();
     auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
-    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "getShapeDefiningLoopRange map: " << map << "\n");
     Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
     for (auto en2 : llvm::enumerate(map.getResults())) {
       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
-        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: "
-                          << loopDepth << "\n");
-        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape
-                          << "\n");
+        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
+                                << loopDepth << "\n");
+        LLVM_DEBUG(llvm::dbgs()
+                   << "getShapeDefiningLoopRange shape: " << shape << "\n");
         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
       }
     }
@@ -209,8 +213,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
   //   dimension.
   // TODO: extend this with range inference.
   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
-  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
-                    << ", producer map: " << producerMap << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+                          << ", producer map: " << producerMap << "\n");
 
   unsigned nPar = producer.getNumParallelLoops();
   unsigned nRed = producer.getNumReductionLoops();
@@ -258,7 +262,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
   assert(consumer.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   if (producer.getNumOutputs() != 1) {
-    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
+    LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
     return false;
   }
   // Only fuse when the producer block dominates.
@@ -266,7 +270,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
   if (!dom.dominates(producer.getOperation()->getBlock(),
                      consumer.getOperation()->getBlock())) {
     LLVM_DEBUG(
-        dbgs()
+        llvm::dbgs()
         << "\nNot structurally fusable (producer block does not dominate)");
     return false;
   }
@@ -284,14 +288,14 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
   // Make some simple structural checks that alleviate the need for more
   // complex analyses.
   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
-    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
-                      << *producer.getOperation());
+    LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
+                            << *producer.getOperation());
     return false;
   }
   // Check for any interleaved write to consumedView.
   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
-    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
-                      << *producer.getOperation());
+    LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
+                            << *producer.getOperation());
     return false;
   }
   return true;
@@ -309,8 +313,9 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
   // Check for any fusion-preventing dependence to any shape read/written that
   // would violate dependences.
   if (!graph.findCoveringDependences(producer, consumer).empty()) {
-    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
-                      << *producer.getOperation());
+    LLVM_DEBUG(llvm::dbgs()
+               << "\n***Not fusable due to an interleaved dependence:\t"
+               << *producer.getOperation());
     return false;
   }
   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
@@ -360,26 +365,33 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
            LinalgDependenceGraph::DependenceType::RAW,
            LinalgDependenceGraph::DependenceType::WAW,
        }) {
-    for (auto dependence :
-         dependenceGraph.getDependencesInto(consumer, depType)) {
+    for (auto dependence : llvm::make_filter_range(
+             dependenceGraph.getDependencesInto(consumer, depType),
+             [consumerIdx](
+                 LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
+               return elem.indexingOpView.operandIndex == consumerIdx;
+             })) {
       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
 
       // Check that the dependence is indeed on the input `consumerIdx` view.
-      auto consumedView = dependence.indexingView;
+      auto consumedView =
+          consumer.getBuffer(dependence.indexingOpView.operandIndex);
       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
         continue;
 
       // Consumer consumes this view, `isStructurallyFusableProducer` also
       // checks whether it is a strict subview of the producer view.
-      auto producedView = dependence.dependentOpView.view;
-      auto producerIdx =
-          producer.getIndexOfOutputBuffer(producedView).getValue();
-      // `consumerIdx` and `producerIdx` exist by construction.
-      LLVM_DEBUG(dbgs() << "\n"
-                        << LinalgDependenceGraph::getDependenceTypeStr(depType)
-                        << "producer: " << *producer.getOperation() << " view: "
-                        << producedView << " output index: " << producerIdx);
-      (void)producerIdx;
+      auto producedView =
+          producer.getBuffer(dependence.dependentOpView.operandIndex);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\n"
+                 << LinalgDependenceGraph::getDependenceTypeStr(depType)
+                 << "producer: " << *producer.getOperation()
+                 << " view: " << producedView << " output index: "
+                 << dependence.dependentOpView.operandIndex -
+                        producer.getNumInputs()
+                 << "\n");
+      (void)producedView;
 
       // Simple fusability checks.
       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
@@ -406,15 +418,16 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
       producerOp.getOperation()->getBlock())
     return {};
 
-  Value producerView = fusableDependence->dependentOpView.view;
-  Value consumerView = fusableDependence->indexingView;
+  unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
+                         producerOp.getNumInputs();
+  Value consumerView = consumer.getShapedOperand(consumerIdx);
 
   // Must be a subview or a slice to guarantee there are loops we can fuse
   // into.
   auto subView = consumerView.getDefiningOp<SubViewOp>();
   auto slice = consumerView.getDefiningOp<SliceOp>();
   if (!subView && !slice) {
-    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
+    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
     return {};
   }
 
@@ -422,11 +435,7 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumer.getOperation());
   ScopedContext scope(b, consumer.getLoc());
-  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
-  Optional<unsigned> producerIdxOpt =
-      producerOp.getIndexOfOutputBuffer(producerView);
-  assert(producerIdxOpt.hasValue() && "incorrect operand index");
-  unsigned producerIdx = producerIdxOpt.getValue();
+  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
 
   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
   return FusionInfo{producerOp, fusedProducer};
@@ -470,7 +479,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
   // Must be a subtensor to guarantee there are loops we can fuse into.
   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
   if (!subTensor || !producerOp) {
-    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
+    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
     return {};
   }
 
@@ -483,7 +492,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumer.getOperation());
   ScopedContext scope(b, consumer.getLoc());
-  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
   LinalgOp fusedProducer =
       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
 
@@ -501,6 +510,21 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
   return FusionInfo{producerOp, fusedProducer};
 }
 
+/// Prune all dimensions that are of reduction iterator type from `map`.
+static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
+                                           AffineMap map) {
+  SmallVector<unsigned, 2> projectedDims;
+  for (auto attr : llvm::enumerate(iteratorTypes)) {
+    if (!isParallelIterator(attr.value()))
+      projectedDims.push_back(attr.index());
+  }
+  return getProjectedMap(map, projectedDims);
+}
+
+using FusableOpDependencesTy = llvm::MapVector<
+    Operation *,
+    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
+
 /// Returns the positions of the loop in `op` that can be tiled based on the
 /// operations that are to be fused with it. For example, in a
 ///
@@ -508,12 +532,58 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
 ///
 /// if the producer of %a needs to be fused with this op, only the `i` loop of
 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
-/// fused, then no loops can be tiled while fusing.
-static DenseSet<unsigned> collectTileAndFuseLoops(
-    LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
-                     fusableDependences) {
-  // 1. Only parallel loops can be used for tile + fuse. Find the number of
-  // common outer parallel loops between the op and its producers being fused.
+/// fused, then no loops can be tiled while fusing. The conditions used are:
+/// 1. Only parallel loops can be used for tile + fuse. Find the number of
+///    common outer parallel loops between the op and its producers being fused.
+/// 2. Of the parallel loops, only some can be fused: only those loops for
+///    which the fusable loop's iteration space touches just one tile
+///    of the fused operation. This is because the producer (which is writing
+///    the fused subview) has update semantics. To compute this,
+///    a. Find the mapping from iterations in the consumer that write to the
+///       same location as the iterations in the producer. To do so use
+///       - indexing map of the fused view in the consumer : consumerIndexMap
+///       - indexing map of the fused view in the producer : producerIndexMap
+///       consumerLoopToProducerLoop =
+///         inverse(producerIndexMap).compose(consumerIndexMap)
+///
+/// Since an inverse computation is needed, we need to consider the projection
+/// of the producerIndexMap w.r.t. the parallel loops. The actual fusable loops
+/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
+/// parallel loops and appear in the result of the map.
+///
+/// Example 1:
+///   linalg.fill(%c, %cst)
+///   linalg.matmul ins(%a, %b) outs(%c)
+///     Number of parallel loops : 2
+///     producerIndexMap = affine_map<(i, j) ->(i , j)>
+///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
+///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
+///     Fused dimensions : i, j
+///
+/// Example 2:
+///   linalg.matmul ins(%a, %b) outs(%c)
+///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
+///                   iterator_types = ["parallel", "parallel"]}
+///     ins(%c) ...
+///
+///     Number of parallel loops = 2:
+///     producerIndexMap (projected to parallel loops) =
+///       affine_map<(i, j) -> (i, j)>
+///     consumerLoopToProducerLoop = affine_map<(i, j) -> (j, i)>
+///     Fused dimensions : i, j
+///
+/// Example 3:
+///   linalg.copy(%s, %b)
+///   linalg.matmul ins(%a, %b) outs(%c)
+///
+///   Number of parallel loops = 2
+///   producerIndexMap : affine_map<(i, j) -> (i, j)>
+///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)>
+///     submap with only parallel loops = affine_map<(i, j) -> (j)>
+///   Fused dimensions : j
+static std::set<unsigned>
+collectTileAndFuseLoops(LinalgOp op,
+                        const FusableOpDependencesTy &fusableDependences) {
   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
     return linalgOp.iterator_types()
         .getValue()
@@ -524,135 +594,149 @@ static DenseSet<unsigned> collectTileAndFuseLoops(
         .size();
   };
 
+  LLVM_DEBUG({
+    llvm::dbgs() << "Op : ";
+    op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n";
+  });
+
   size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
   for (auto dependence : fusableDependences) {
+    linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
     numOuterParallelLoops =
-        std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
-                                            dependence.dependentOpView.op)));
+        std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
   }
 
-  // Need to compute what tiled loops can be "fused". Given the precondition
-  // that all indexing map for the producer view is a projected permutation, we
-  // can assert that the producer iterates over the dimensions of the "fused
-  // view" only once. To be used a fused loop the producer should use this loop
-  // to access the fused view. For example, consider
-  //
-  // ```
-  //   linalg.add ins(%a, %b) outs(%c)
-  //   linalg.matmul ins(%d, %c) outs(%e)
-  // ```
-  //
-  // if `linalg.add` has the semantics of `c = a + b`, then the following
-  // tile+fuse code is correct.
-  //
-  // ```
-  // for j ... += TSj
-  //   %sa = subview %a[0, %j][...]
-  //   %sb = subview %b[0, %j][...]
-  //   %sc = subview %c[0, %j][...]
-  //   %sd = subview %d[0, 0][...]
-  //   %se = subview %e[0, %j][...]
-  //   linalg.add ins(%sa, %sb) outs(%sc)
-  //   linalg.matmul ins(%sd, %sc) outs(%se)
-  // ```
-  //
-  // On the other hand tiling along i would be incorrect
-  //
-  // ```
-  // for %i .. += TSi
-  //   %sa = subview %a[%i, 0][...]
-  //   %sb = subview %b[%i, 0][...]
-  //   %sc = subview %c[%i, 0][...]
-  //   %sc2 = subview %c[0, 0][...]
-  //   %sd = subview %d[%i, 0][...]
-  //   %se = subview %e[%i, 0][...]
-  //   linalg.add ins(%sa, %sb) outs(%sc)
-  //   linalg.matmul ins(%sd, %sc2) outs(%se)
-  // ```
-  //
-  // The write to the subview `%sc` in `linalg.add` is performed after the read
-  // from it using `%sc2` violating the RAW dependence of the original code. To
-  // find such loops indexing map of the fused view in the consumer op is
-  // used. For the above example, this indexing map is
-  //
-  //   affine_map<(d0, d1, d2) -> (d2, d1)>
-  //
-  // Since d0 is not in the result expressions of this map, it is not treated as
-  // tile + fuse loop, (but d1 is).
-  //
-  // TODO: The above is probably restrictive and there might be a generalization
-  // of these that might allow for more fusion opportunities. Explore based on
-  // needs.
-  SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
+  std::set<unsigned> fusableLoops;
+  auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
+  fusableLoops.insert(range.begin(), range.end());
   for (auto dependence : fusableDependences) {
-    unsigned consumerIdx =
-        op.getIndexOfShapedOperand(dependence.indexingView).getValue();
-    AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
-    // Previously asserted that the consumerAccess map is a projected
-    // permutation, so all results are known to be AffineDimExprs. To remove
-    // this restriction walk the expression to find which dimensions of the
-    // consumer loop appear in the `consumerAccess`.
-    DenseSet<unsigned> positions;
-    for (auto expr : consumerAccess.getResults())
-      positions.insert(expr.cast<AffineDimExpr>().getPosition());
-    commonTilableLoops.emplace_back(std::move(positions));
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t fusable :";
+      for (unsigned i : fusableLoops)
+        llvm::dbgs() << " " << i;
+      llvm::dbgs() << "\n";
+    });
+    linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
+
+    assert(!dependence.second.empty() &&
+           "unexpected producer but not dependences");
+    AffineMap producerIndexingMap = producer.getIndexingMap(
+        dependence.second.front().dependentOpView.operandIndex);
+    AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
+        producer.iterator_types().getValue(), producerIndexingMap);
+    if (!prunedProducerIndexingMap.isPermutation())
+      return {};
+
+    AffineMap consumerIndexingMap = op.getIndexingMap(
+        dependence.second.front().indexingOpView.operandIndex);
+    if (consumerIndexingMap.getNumResults() !=
+        prunedProducerIndexingMap.getNumResults())
+      return {};
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t producerMap : ";
+      producerIndexingMap.print(llvm::dbgs());
+      llvm::dbgs() << "  pruned : ";
+      prunedProducerIndexingMap.print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+      llvm::dbgs() << "\t consumerMap : ";
+      consumerIndexingMap.print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
+
+    AffineMap invProducerIndexMap =
+        inversePermutation(prunedProducerIndexingMap);
+    if (!invProducerIndexMap)
+      return {};
+
+    AffineMap consumerLoopToProducerLoop =
+        invProducerIndexMap.compose(consumerIndexingMap);
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
+      consumerLoopToProducerLoop.print(llvm::dbgs());
+    });
+
+    std::set<unsigned> candidates;
+    for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
+      AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
+      if (!dimExpr)
+        continue;
+      unsigned position = dimExpr.getPosition();
+      if (fusableLoops.count(position))
+        candidates.insert(position);
+    }
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t candidates :";
+      for (unsigned i : candidates)
+        llvm::dbgs() << " " << i;
+      llvm::dbgs() << "\n";
+    });
+    if (candidates.empty())
+      return {};
+    std::swap(candidates, fusableLoops);
   }
 
-  // 2. Of the outer parallel loops, only those loops can be tiled + fused as
-  // computed above for all the fused dependences can be used to tile and fuse.
-  DenseSet<unsigned> tilableParallelLoops;
-  for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
-    if (llvm::all_of(commonTilableLoops,
-                     [&](const DenseSet<unsigned> &tilableLoops) {
-                       return tilableLoops.count(index);
-                     }))
-      tilableParallelLoops.insert(index);
-  }
-  return tilableParallelLoops;
+  return fusableLoops;
 }
 
 /// Find all dependences that are to be fusable.
-static Optional<
-    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
+static FusableOpDependencesTy
 findAllFusableDependences(LinalgOp op,
                           const LinalgDependenceGraph &dependenceGraph,
                           const LinalgFusionOptions &fusionOptions) {
-  SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
-      fusableDependences;
-  for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
-    if (fusionOptions.indicesToFuse &&
-        !fusionOptions.indicesToFuse->count(operand.index()))
-      continue;
-    Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-        fusableDependence =
-            findFusableProducer(op, operand.index(), dependenceGraph);
+  FusableOpDependencesTy fusableDependences;
+  // TODO: Currently fusion would not be legal if the fusable dependence is to
+  // the same producer but 
diff erent indexing map in the consumer. Fix this, but
+  // in the meanwhile disallow such a fusion.
+  DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
+  for (auto operandIndex : fusionOptions.indicesToFuse) {
+    auto fusableDependence =
+        findFusableProducer(op, operandIndex, dependenceGraph);
     if (!fusableDependence)
-      continue;
+      return FusableOpDependencesTy{};
+    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+    // Do not fuse dependences that are to operations not in the same basic
+    // block. This avoid moving fused operations across loops that might
+    // themselves carry dependency making the fusion illegal.
+    if (producerOp.getOperation()->getBlock() !=
+        op.getOperation()->getBlock()) {
+      op.emitRemark("unhandled fusion of ops in 
diff erent basic blocks");
+      return FusableOpDependencesTy{};
+    }
     // Make sure that the indexing map of the view used for fusion in the
     // producer is a projected permutation.
-    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
-    Value producerView = fusableDependence->dependentOpView.view;
-    unsigned producerIdx =
-        producerOp.getIndexOfOutputBuffer(producerView).getValue();
-    AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx);
+    unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
+    AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
     if (!producerMap.isProjectedPermutation()) {
-      op.emitError("unhandled non permutation indexing map for fused view in "
-                   "producer for operand at index ")
-          << operand.index();
-      return llvm::None;
+      op.emitRemark("unhandled non permutation indexing map for fused view in "
+                    "producer for operand at index ")
+          << operandIndex;
+      return FusableOpDependencesTy{};
     }
-    Value consumerView = fusableDependence->indexingView;
-    unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue();
-    if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
-      op.emitError(
+
+    unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
+    AffineMap consumerMap = op.getIndexingMap(consumerIdx);
+    if (!consumerMap.isProjectedPermutation()) {
+      op.emitRemark(
           "unhandled case where indexing map for fused view in the consumer is "
-          "not a projected permuration while fusing at index ")
-          << operand.index();
-      return llvm::None;
+          "not a projected permutation while fusing at index ")
+          << operandIndex;
+      return FusableOpDependencesTy{};
+    }
+
+    // Check if the producer is already a fusion candidate. Cannot fuse this
+    // dependence if it has a 
diff erent indexing map when used in the consumer.
+    if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
+        fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
+      op.emitRemark("unhandled fusion to the same producer but with 
diff erent "
+                    "indexing maps");
+      return FusableOpDependencesTy{};
     }
-    fusableDependences.push_back(*fusableDependence);
-    if (!fusionOptions.indicesToFuse)
-      break;
+    fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+
+    fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
   }
   return fusableDependences;
 }
@@ -682,13 +766,10 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
   ScopedContext scope(rewriter, op.getLoc());
 
   // Find all the producers.
-  Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
-      fusableDependencesOpt =
-          findAllFusableDependences(op, dependenceGraph, fusionOptions);
-  if (!fusableDependencesOpt)
+  FusableOpDependencesTy fusableDependences =
+      findAllFusableDependences(op, dependenceGraph, fusionOptions);
+  if (fusableDependences.empty())
     return llvm::None;
-  ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
-      *fusableDependencesOpt);
 
   // Enforce the convention that "tiling by zero" skips tiling a particular
   // dimension. This convention is significantly simpler to handle instead of
@@ -704,12 +785,12 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
   TiledAndFusedLinalgOps ret;
 
   // Find the loops that can be tiled and fused.
-  DenseSet<unsigned> tileFuseLoops =
+  std::set<unsigned> tileFuseLoops =
       collectTileAndFuseLoops(op, fusableDependences);
 
   // If there are no fusable dependences or there are no tile+fusable loops,
   // just return.
-  if (fusableDependences.empty() || tileFuseLoops.empty()) {
+  if (tileFuseLoops.empty()) {
     return llvm::None;
   }
 
@@ -752,15 +833,15 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
 
   rewriter.setInsertionPoint(ret.op);
   // Fuse the operands.
-  for (auto producer : enumerate(fusableDependences)) {
-    LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
+  for (auto dependence : fusableDependences) {
+    LinalgOp producerOp = cast<LinalgOp>(dependence.first);
     unsigned producerIdx =
-        producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view)
-            .getValue();
+        dependence.second.front().dependentOpView.operandIndex;
     unsigned consumerIdx =
-        op.getIndexOfShapedOperand(producer.value().indexingView).getValue();
-    LinalgOp fusedOp =
-        fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
+        dependence.second.front().indexingOpView.operandIndex;
+    LinalgOp fusedOp = fuse(rewriter, producerOp,
+                            producerOp.getOutputIndex(producerIdx).getValue(),
+                            ret.op, consumerIdx);
     ret.fusedProducers.push_back(fusedOp);
     ret.originalProducers.push_back(producerOp);
   }

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index ba76976a17c1..1f73d07cc8ff 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -450,6 +451,22 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
                         maps.front().getContext());
 }
 
+/// Returns the map obtained from `map` by projecting out the dimensions listed
+/// in `projectedDimensions`. Each occurrence of a projected dimension is
+/// replaced by the constant 0, and the remaining dimensions are renumbered
+/// contiguously in their original order. For example, projecting {1} out of
+/// (d0, d1, d2) -> (d0 + d1, d2) yields (d0, d1) -> (d0, d1).
+AffineMap mlir::getProjectedMap(AffineMap map,
+                                ArrayRef<unsigned> projectedDimensions) {
+  DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
+                                   projectedDimensions.end());
+  MLIRContext *context = map.getContext();
+  // Build the substitution old-dim -> new-expr. Surviving dims must be numbered
+  // with a running counter; reusing their old position would reference dims
+  // beyond the reduced dim count of the replacement map.
+  unsigned numResultDims = 0;
+  SmallVector<AffineExpr, 4> resultExprs;
+  resultExprs.reserve(map.getNumDims());
+  for (unsigned dim : llvm::seq<unsigned>(0, map.getNumDims())) {
+    if (projectedDims.count(dim))
+      resultExprs.push_back(getAffineConstantExpr(0, context));
+    else
+      resultExprs.push_back(getAffineDimExpr(numResultDims++, context));
+  }
+  // Use the counted number of surviving dims rather than
+  // `getNumDims() - projectedDimensions.size()`, which is wrong when
+  // `projectedDimensions` contains duplicates.
+  return map.compose(AffineMap::get(numResultDims, /*symbolCount=*/0,
+                                    resultExprs, context));
+}
+
 //===----------------------------------------------------------------------===//
 // MutableAffineMap.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index 61e5b746deac..2ddc66651db2 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
 
 module {
   func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
@@ -295,3 +295,121 @@ module {
 //      CHECK:   }
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     __internal_linalg_transform__ = "after_lhs_fusion_original"
+
+// -----
+
+// The generic op reads the matmul result %2 through two operands that use the
+// same identity indexing map, so the producer matmul is tiled and fused once.
+module {
+  func @matmul_plus_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                           %arg2: memref<?x?xf32>) {
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %0 = dim %arg2, %c0 : memref<?x?xf32>
+    %1 = dim %arg2, %c1 : memref<?x?xf32>
+    %2 = alloc(%0, %1) : memref<?x?xf32>
+    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%2 : memref<?x?xf32>)
+    linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"],
+       __internal_linalg_transform__ = "transpose_fusion"}
+      ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>) {
+      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+        %3 = addf %arg3, %arg4 : f32
+        linalg.yield %3 : f32
+      }
+    return
+  }
+}
+//       CHECK: func @matmul_plus_matmul
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//       CHECK:   %[[T2:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     after_transpose_fusion_original
+//       CHECK:   scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:[a-zA-Z0-9_]+]])
+//       CHECK:     %[[T5:.+]] = subview %[[T2]][%[[ARG3]], %[[ARG4]]]
+//       CHECK:     %[[T6:.+]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
+//       CHECK:     %[[T8:.+]] = subview %[[ARG0]][%[[ARG3]], 0]
+//       CHECK:     %[[T9:.+]] = subview %[[ARG1]][0, %[[ARG4]]]
+//       CHECK:     linalg.matmul
+//  CHECK-SAME:       after_transpose_fusion_producer
+//  CHECK-SAME:       ins(%[[T8]], %[[T9]]
+//  CHECK-SAME:       outs(%[[T5]]
+//   CHECK-NOT:     linalg.matmul
+//       CHECK:     linalg.generic
+//  CHECK-SAME:       ins(%[[T5]], %[[T5]]
+//  CHECK-SAME:       outs(%[[T6]]
+//  CHECK-SAME:       after_transpose_fusion
+
+// -----
+
+// The generic op reads %2 through two operands whose indexing maps differ
+// (identity and transpose). Fusing the same producer under different maps is
+// not supported yet, so the pattern bails out with a remark.
+module {
+  func @matmul_plus_transpose_matmul(%arg0: memref<?x?xf32>,
+                                     %arg1: memref<?x?xf32>,
+                                     %arg2: memref<?x?xf32>) {
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %0 = dim %arg2, %c0 : memref<?x?xf32>
+    %1 = dim %arg2, %c1 : memref<?x?xf32>
+    %2 = alloc(%0, %1) : memref<?x?xf32>
+    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%2 : memref<?x?xf32>)
+    // expected-remark @+1 {{unhandled fusion to the same producer but with different indexing maps}}
+    linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d1, d0)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"],
+       __internal_linalg_transform__ = "transpose_fusion"}
+      ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>) {
+      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+        %3 = addf %arg3, %arg4 : f32
+        linalg.yield %3 : f32
+      }
+    return
+  }
+}
+
+// -----
+
+// The producer linalg.fill is in the function body, while the consumer matmul
+// is nested inside scf.parallel/scf.for. Fusion across blocks is rejected.
+#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+module {
+  func @basic_no_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                        %arg2: memref<?x?xf32>) {
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %c2 = constant 2 : index
+    %c32 = constant 32 : index
+    %c64 = constant 64 : index
+    %c16 = constant 16 : index
+    %cst = constant 0.000000e+00 : f32
+    linalg.fill(%arg2, %cst) : memref<?x?xf32>, f32
+    %0 = dim %arg0, %c0 : memref<?x?xf32>
+    %1 = dim %arg1, %c1 : memref<?x?xf32>
+    %2 = dim %arg0, %c1 : memref<?x?xf32>
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) {
+      scf.for %arg5 = %c0 to %2 step %c16 {
+        %3 = affine.min #map0(%arg3)[%0]
+        %4 = affine.min #map1(%arg4)[%1]
+        %5 = affine.min #map2(%arg5)[%2]
+        %6 = subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
+        %7 = subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
+        %8 = subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
+        // expected-remark @+1 {{unhandled fusion of ops in different basic blocks}}
+        linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
+          ins(%6, %7 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
+          outs(%8 : memref<?x?xf32, #map3>)
+      }
+      scf.yield
+    }
+    return
+  }
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 06ff91eb074b..e6e150b7bf47 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -43,7 +43,7 @@ static void fillFusionPatterns(MLIRContext *context,
       LinalgTilingOptions()
           .setTileSizes({32, 64, 16})
           .setLoopType(LinalgTilingLoopType::ParallelLoops),
-      LinalgFusionOptions(),
+      LinalgFusionOptions().setIndicesToFuse({2}),
       LinalgMarker(Identifier::get("basic_fusion", context),
                    Identifier::get("after_basic_fusion", context)),
       LinalgMarker(ArrayRef<Identifier>(),
@@ -91,6 +91,19 @@ static void fillFusionPatterns(MLIRContext *context,
       LinalgMarker(
           ArrayRef<Identifier>(),
           Identifier::get("after_two_operand_fusion_original", context)));
+
+  patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
+          LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions().setIndicesToFuse({0, 1}),
+      LinalgMarker(Identifier::get("transpose_fusion", context),
+                   Identifier::get("after_transpose_fusion", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_transpose_fusion_producer", context)),
+      LinalgMarker(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_transpose_fusion_original", context)));
 }
 
 static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {


        


More information about the Mlir-commits mailing list