[llvm-branch-commits] [mlir] bce318f - [mlir][Linalg] NFC: Refactor LinalgDependenceGraphElem to allow

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 22 11:25:17 PST 2021


Author: MaheshRavishankar
Date: 2021-01-22T11:19:59-08:00
New Revision: bce318f58da3741e6dce143c6713906f3af3d913

URL: https://github.com/llvm/llvm-project/commit/bce318f58da3741e6dce143c6713906f3af3d913
DIFF: https://github.com/llvm/llvm-project/commit/bce318f58da3741e6dce143c6713906f3af3d913.diff

LOG: [mlir][Linalg] NFC: Refactor LinalgDependenceGraphElem to allow
representing dependence from producer result to consumer.

With Linalg on tensors the dependence between operations can be from
the result of the producer to the consumer. This change just does a
NFC refactoring of the LinalgDependenceGraphElem to allow representing
both OpResult and OpOperand*.

Differential Revision: https://reviews.llvm.org/D95208

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
    mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/fusion-pattern.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 9aa50c25cd79..5ffe4c6c9461 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -48,16 +48,104 @@ class LinalgDependenceGraph {
   // TODO: OpOperand tracks dependencies on buffer operands. Tensor result will
   // need an extension to use OpResult.
   struct LinalgDependenceGraphElem {
+    using OpView = PointerUnion<OpOperand *, Value>;
     // dependentOpView may be either:
     //   1. src in the case of dependencesIntoGraphs.
     //   2. dst in the case of dependencesFromDstGraphs.
-    OpOperand *dependentOpView;
+    OpView dependentOpView;
     // View in the op that is used to index in the graph:
     //   1. src in the case of dependencesFromDstGraphs.
     //   2. dst in the case of dependencesIntoGraphs.
-    OpOperand *indexingOpView;
+    OpView indexingOpView;
     // Type of the dependence.
     DependenceType dependenceType;
+
+    // Return the Operation that owns the operand or result represented in
+    // `opView`.
+    static Operation *getOwner(OpView opView) {
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return operand->getOwner();
+      return opView.get<Value>().cast<OpResult>().getOwner();
+    }
+    // Return the operand or the result Value represented by the `opView`.
+    static Value getValue(OpView opView) {
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return operand->get();
+      return opView.get<Value>();
+    }
+    // Return the indexing map of the operand/result in `opView` specified in
+    // the owning LinalgOp. If the owner is not a LinalgOp returns llvm::None.
+    static Optional<AffineMap> getIndexingMap(OpView opView) {
+      auto owner = dyn_cast<LinalgOp>(getOwner(opView));
+      if (!owner)
+        return llvm::None;
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return owner.getIndexingMap(operand->getOperandNumber());
+      return owner.getOutputIndexingMap(
+          opView.get<Value>().cast<OpResult>().getResultNumber());
+    }
+    // Return the operand number if the `opView` is an OpOperand *. Otherwise
+    // return llvm::None.
+    static Optional<unsigned> getOperandNumber(OpView opView) {
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return operand->getOperandNumber();
+      return llvm::None;
+    }
+    // Return the result number if the `opView` is an OpResult. Otherwise return
+    // llvm::None.
+    static Optional<unsigned> getResultNumber(OpView opView) {
+      if (OpResult result = opView.dyn_cast<Value>().cast<OpResult>())
+        return result.getResultNumber();
+      return llvm::None;
+    }
+
+    // Return the owner of the dependent OpView.
+    Operation *getDependentOp() const { return getOwner(dependentOpView); }
+
+    // Return the owner of the indexing OpView.
+    Operation *getIndexingOp() const { return getOwner(indexingOpView); }
+
+    // Return the operand or result stored in the dependentOpView.
+    Value getDependentValue() const { return getValue(dependentOpView); }
+
+    // Return the operand or result stored in the indexingOpView.
+    Value getIndexingValue() const { return getValue(indexingOpView); }
+
+    // If the dependent OpView is an operand, return operand number. Return
+    // llvm::None otherwise.
+    Optional<unsigned> getDependentOpViewOperandNum() const {
+      return getOperandNumber(dependentOpView);
+    }
+
+    // If the indexing OpView is an operand, return operand number. Return
+    // llvm::None otherwise.
+    Optional<unsigned> getIndexingOpViewOperandNum() const {
+      return getOperandNumber(indexingOpView);
+    }
+
+    // If the dependent OpView is a result value, return the result
+    // number. Return llvm::None otherwise.
+    Optional<unsigned> getDependentOpViewResultNum() const {
+      return getResultNumber(dependentOpView);
+    }
+
+    // If the indexing OpView is a result value, return the result
+    // number. Return llvm::None otherwise.
+    Optional<unsigned> getIndexingOpViewResultNum() const {
+      return getResultNumber(indexingOpView);
+    }
+
+    // Return the indexing map of the operand/result in the dependent OpView as
+    // specified in the owner of the OpView.
+    Optional<AffineMap> getDependentOpViewIndexingMap() const {
+      return getIndexingMap(dependentOpView);
+    }
+
+    // Return the indexing map of the operand/result in the indexing OpView as
+    // specified in the owner of the OpView.
+    Optional<AffineMap> getIndexingOpViewIndexingMap() const {
+      return getIndexingMap(indexingOpView);
+    }
   };
   using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
   using DependenceGraph = DenseMap<Operation *, LinalgDependences>;

diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 1042930b1ef7..f80a00bf64d4 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -218,15 +218,14 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
   // TODO: we are not considering paths yet, just interleaved positions.
   for (auto dt : types) {
     for (auto dependence : getDependencesFrom(src, dt)) {
-      auto interimPos =
-          linalgOpPositions.lookup(dependence.dependentOpView->getOwner());
+      auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
       // Skip if not interleaved.
       if (interimPos >= dstPos || interimPos <= srcPos)
         continue;
-      Value consumerView = dependence.indexingOpView->get();
+      Value consumerView = dependence.getIndexingValue();
       if (view && !aliases.alias(view, consumerView))
         continue;
-      auto *op = dependence.dependentOpView->getOwner();
+      auto *op = dependence.getDependentOp();
       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
                         << getDependenceTypeStr(dt) << ": " << *src << " -> "
                         << *op << " on " << consumerView);
@@ -241,7 +240,7 @@ bool LinalgDependenceGraph::hasDependenceFrom(
     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
   for (auto dep : depTypes)
     for (auto dependence : getDependencesInto(dstLinalgOp, dep))
-      if (dependence.dependentOpView->getOwner() == srcLinalgOp)
+      if (dependence.getDependentOp() == srcLinalgOp)
         return true;
   return false;
 }

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 0deb4e3f59ae..5d37e8f9d782 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -258,11 +258,9 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
 ///      `producer.getOutputBuffers()`.
 ///   2. Tensor case: `producerIdx` is the index of the tensor in
 ///      `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp,
-                     unsigned producerOutNumber, OpOperand &consumerOpOperand) {
-  AffineMap producerMap = producerOp.getOutputIndexingMap(producerOutNumber);
-  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerOutNumber
-                          << ", producer map: " << producerMap << "\n");
+static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
+                     OpOperand &consumerOpOperand) {
+  LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
   DenseMap<unsigned, Range> fusedLoopsAndRanges;
   Value shapedOperand = consumerOpOperand.get();
   for (auto en : llvm::enumerate(producerMap.getResults())) {
@@ -354,6 +352,8 @@ static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
 findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
   LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+  // Note that buffer semantics implies that the dependence will only be from
+  // OpOperand -> OpOperand.
   assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand");
 
   // Only consider RAW and WAW atm.
@@ -364,22 +364,24 @@ findFusableProducer(OpOperand &consumerOpOperand,
     for (auto dependence : llvm::make_filter_range(
              dependenceGraph.getDependencesInto(consumerOp, depType),
              [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
-               return elem.indexingOpView->get() == consumerOpOperand.get() &&
-                      elem.indexingOpView->getOperandNumber() ==
+               Value v = elem.getIndexingValue();
+               Optional<unsigned> operandNum =
+                   elem.getIndexingOpViewOperandNum();
+               return isa<LinalgOp>(elem.getDependentOp()) &&
+                      v == consumerOpOperand.get() && operandNum &&
+                      operandNum.getValue() ==
                           consumerOpOperand.getOperandNumber();
              })) {
-
       // Consumer consumes this view, `isStructurallyFusableProducer` also
       // checks whether it is a strict subview of the producer view.
-      auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
+      auto producer = cast<LinalgOp>(dependence.getDependentOp());
       LLVM_DEBUG(llvm::dbgs()
                  << "\n"
                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
-                 << "producer: " << *dependence.dependentOpView->getOwner()
-                 << " view: " << dependence.dependentOpView->get()
-                 << " output index: "
-                 << dependence.dependentOpView->getOperandNumber() -
-                        producer.getNumInputs()
+                 << "producer: " << *dependence.getDependentOp() << " view: "
+                 << dependence.getDependentValue() << " output index: "
+                 << (dependence.getDependentOpViewOperandNum().getValue() -
+                     producer.getNumInputs())
                  << "\n");
 
       // Simple fusability checks.
@@ -399,18 +401,21 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
       findFusableProducer(consumerOpOperand, graph);
   if (!fusableDependence)
-    return {};
+    return llvm::None;
+
+  LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
+  if (!producerOp)
+    return llvm::None;
 
-  LinalgOp producerOp =
-      cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
-      fusableDependence->dependentOpView->get().getParentBlock())
-    return {};
+      fusableDependence->getDependentValue().getParentBlock())
+    return llvm::None;
 
-  unsigned producerIdx =
-      fusableDependence->dependentOpView->getOperandNumber() -
-      producerOp.getNumInputs();
+  Optional<AffineMap> producerMap =
+      fusableDependence->getDependentOpViewIndexingMap();
+  if (!producerMap)
+    return llvm::None;
 
   // Must be a subview or a slice to guarantee there are loops we can fuse
   // into.
@@ -418,7 +423,7 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
   auto slice = consumerOpOperand.get().getDefiningOp<SliceOp>();
   if (!subView && !slice) {
     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
-    return {};
+    return llvm::None;
   }
 
   // Fuse `producer` just before `consumer`.
@@ -428,7 +433,7 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
                           << *consumerOpOperand.getOwner() << "\n");
 
-  auto fusedProducer = fuse(b, producerOp, producerIdx, consumerOpOperand);
+  auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
   return FusionInfo{producerOp, fusedProducer};
 }
 
@@ -474,8 +479,13 @@ Optional<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
   auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
-  assert(producerOp && "expected Linalg producer");
-  LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+  if (!producerOp)
+    return llvm::None;
+
+  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
+  if (!consumerOp)
+    return llvm::None;
+
   Value inputTensor = consumerOpOperand.get();
 
   // Must be a subtensor to guarantee there are loops we can fuse into.
@@ -496,8 +506,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   b.setInsertionPoint(consumerOp);
   ScopedContext scope(b, consumerOp->getLoc());
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
-  LinalgOp fusedProducer = fuse(
-      b, producerOp, producerOpResult.getResultNumber(), consumerOpOperand);
+  LinalgOp fusedProducer =
+      fuse(b, producerOp,
+           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+           consumerOpOperand);
 
   // Replace use.
   // Canonicalizations are not guaranteed to have happened before constructing
@@ -531,30 +543,34 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
 ///       inverse(producerIndexMap).compose(consumerIndexMap)
 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
-  auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
-  AffineMap producerIndexingMap =
-      producer.getIndexingMap(dependence.dependentOpView->getOperandNumber());
-  auto consumer = cast<LinalgOp>(dependence.indexingOpView->getOwner());
-  AffineMap consumerIndexingMap =
-      consumer.getIndexingMap(dependence.indexingOpView->getOperandNumber());
+  auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
+  if (!producer)
+    return None;
+
+  Optional<AffineMap> producerIndexingMap =
+      dependence.getDependentOpViewIndexingMap();
+  Optional<AffineMap> consumerIndexingMap =
+      dependence.getIndexingOpViewIndexingMap();
+  if (!producerIndexingMap || !consumerIndexingMap)
+    return None;
 
   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
-      producer.iterator_types().getValue(), producerIndexingMap);
+      producer.iterator_types().getValue(), *producerIndexingMap);
   if (!prunedProducerIndexingMap.isPermutation())
     return None;
 
-  if (consumerIndexingMap.getNumResults() !=
+  if (consumerIndexingMap->getNumResults() !=
       prunedProducerIndexingMap.getNumResults())
     return None;
 
   LLVM_DEBUG({
     llvm::dbgs() << "\t producerMap : ";
-    producerIndexingMap.print(llvm::dbgs());
+    producerIndexingMap->print(llvm::dbgs());
     llvm::dbgs() << "  pruned : ";
     prunedProducerIndexingMap.print(llvm::dbgs());
     llvm::dbgs() << "\n";
     llvm::dbgs() << "\t consumerMap : ";
-    consumerIndexingMap.print(llvm::dbgs());
+    consumerIndexingMap->print(llvm::dbgs());
     llvm::dbgs() << "\n";
   });
 
@@ -562,7 +578,7 @@ static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
   if (!invProducerIndexMap)
     return None;
 
-  return invProducerIndexMap.compose(consumerIndexingMap);
+  return invProducerIndexMap.compose(*consumerIndexingMap);
 }
 
 /// Given a projected permutation `map`, returns true if the map changes the
@@ -710,10 +726,7 @@ collectFusableLoops(ArrayRef<LinalgOp> ops,
 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
   FusableOpDependencesTy fusableDependences;
-  // TODO: Currently fusion would not be legal if the fusable dependence is to
-  // the same producer but different indexing map in the consumer. Fix this, but
-  // in the meanwhile disallow such a fusion.
-  DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
+  DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
     for (OpOperand &opOperand : op.getShapedOpOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
@@ -721,54 +734,47 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
       if (!fusableDependence)
         continue;
       LinalgOp producerOp =
-          cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
+          dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
+      if (!producerOp)
+        continue;
       // Do not fuse dependences that are to operations not in the same basic
       // block. This avoid moving fused operations across loops that might
       // themselves carry dependency making the fusion illegal.
-      if (producerOp->getBlock() != op->getBlock()) {
-        op.emitRemark("unhandled fusion of ops in different basic blocks");
-        return FusableOpDependencesTy{};
-      }
+      if (producerOp->getBlock() != op->getBlock())
+        continue;
+
       // Make sure that the indexing map of the view used for fusion in the
       // producer is a projected permutation.
-      unsigned producerIdx =
-          fusableDependence->dependentOpView->getOperandNumber();
-      AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
-      if (!producerMap.isProjectedPermutation()) {
-        op.emitRemark(
-            "unhandled non permutation indexing map for fused view in "
-            "producer for operand at index ")
-            << opOperand.getOperandNumber();
-        return FusableOpDependencesTy{};
-      }
-
-      unsigned consumerIdx =
-          fusableDependence->indexingOpView->getOperandNumber();
-      AffineMap consumerMap = op.getIndexingMap(consumerIdx);
-      if (!consumerMap.isProjectedPermutation()) {
-        op.emitRemark(
-            "unhandled case where indexing map for fused view in the consumer "
-            "is not a projected permutation while fusing at index ")
-            << opOperand.getOperandNumber();
-        return FusableOpDependencesTy{};
-      }
-
-      // Check if the producer is already a fusion candidate. Cannot fuse this
-      // dependence if it has a different indexing map when used in the
-      // consumer.
-      if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
-          fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
-        op.emitRemark(
-            "unhandled fusion to the same producer but with different "
-            "indexing maps");
-        return FusableOpDependencesTy{};
-      }
-      fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+      Optional<AffineMap> producerMap =
+          fusableDependence->getDependentOpViewIndexingMap();
+      Optional<AffineMap> consumerMap =
+          fusableDependence->getIndexingOpViewIndexingMap();
+      assert(
+          consumerMap &&
+          "unable to find indexing map of operand/result of indexing OpView");
+      fusedProducerIndexingMap[producerOp.getOperation()].push_back(
+          *consumerMap);
+      if (!producerMap || !producerMap->isProjectedPermutation() ||
+          !consumerMap->isProjectedPermutation())
+        continue;
 
       fusableDependences[producerOp.getOperation()].push_back(
           *fusableDependence);
     }
   }
+  // TODO: Currently fusion would not be legal if the fusable dependence is to
+  // the same producer but different indexing map in the consumer. Fix this, but
+  // in the meanwhile disallow such a fusion.
+  for (auto useIndexingMapsList : fusedProducerIndexingMap) {
+    AffineMap map1 = useIndexingMapsList.second.front();
+    for (AffineMap map2 :
+         ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
+      if (map1 != map2) {
+        fusableDependences.erase(useIndexingMapsList.first);
+        break;
+      }
+    }
+  }
   return fusableDependences;
 }
 
@@ -819,7 +825,7 @@ static Optional<TiledAndFusedLinalgOps>
 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                          const LinalgDependenceGraph &dependenceGraph,
                          const LinalgTilingOptions &tilingOptions) {
-  if (ops.empty())
+  if (ops.size() < 2)
     return llvm::None;
   LinalgOp rootOp = ops.back();
   for (auto op : enumerate(ops)) {
@@ -827,14 +833,14 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
     // buffers. This check can be removed after it is tested on tensors.
     LinalgOp linalgOp = op.value();
     if (!linalgOp.hasBufferSemantics()) {
-      linalgOp.emitError("tile and fuse only tested for buffer operation");
+      linalgOp.emitRemark("tile and fuse only tested for buffer operation");
       return llvm::None;
     }
   }
   // TODO: Support interchange with tile + fuse. This might actually help do
   // better fusion.
   if (!tilingOptions.interchangeVector.empty()) {
-    rootOp.emitError("unable to handle tile and fuse with interchange");
+    rootOp.emitRemark("unable to handle tile and fuse with interchange");
     return llvm::None;
   }
 
@@ -864,7 +870,7 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
   if (!tiledRootOp) {
-    rootOp.emitError("failed to tile the fused loops");
+    rootOp.emitRemark("failed to tile the fused loops");
     return llvm::None;
   }
   ret.op = tiledRootOp->op;

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index b6171ff9c5b1..283ff20f611b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -161,12 +161,16 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
 
   DenseSet<Operation *> producers;
   producers.insert(linalgOp);
-  for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
-    if (!fusionOptions.indicesToFuse.count(
-            dependence.indexingOpView->getOperandNumber()))
+  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
+    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
+    // When looking at dependences into, indexingOp is always OpOperand. We
+    // could assert, but continue if this is not the case.
+    if (!operandNumber)
       continue;
-    if (isa<LinalgOp>(dependence.dependentOpView->getOwner()))
-      producers.insert(dependence.dependentOpView->getOwner());
+    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
+      continue;
+    if (isa<LinalgOp>(dependence.getDependentOp()))
+      producers.insert(dependence.getDependentOp());
   }
 
   SmallVector<LinalgOp, 1> fusionOps;

diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index d14effa027ad..ca30a32e75b1 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
 
 module {
   func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
@@ -371,7 +371,6 @@ module {
     %2 = alloc(%0, %1) : memref<?x?xf32>
     linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
       outs(%2 : memref<?x?xf32>)
-    // expected-remark @+1 {{unhandled fusion to the same producer but with different indexing maps}}
     linalg.generic
       {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d1, d0)>,
@@ -387,6 +386,15 @@ module {
     return
   }
 }
+// CHECK-LABEL: func @matmul_plus_transpose_matmul
+//   CHECK-NOT:   scf.parallel
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.matmul
+//   CHECK-NOT:   scf.parallel
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.generic
+//   CHECK-NOT:   scf.parallel
+//   CHECK-NOT:   scf.for
 
 // -----
 
@@ -416,7 +424,6 @@ module {
         %6 = subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
         %7 = subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
         %8 = subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
-	// expected-remark @+1 {{unhandled fusion of ops in different basic blocks}}
         linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
           ins(%6, %7 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
           outs(%8 : memref<?x?xf32, #map3>)
@@ -426,6 +433,13 @@ module {
     return
   }
 }
+// CHECK-LABEL: func @basic_no_fusion
+//   CHECK-NOT:   scf.parallel
+//       CHECK:   linalg.fill
+//       CHECK:   scf.parallel
+//       CHECK:     scf.for
+//   CHECK-NOT:     linalg.fill
+//       CHECK:     linalg.matmul
 
 // -----
 


        


More information about the llvm-branch-commits mailing list