[Mlir-commits] [mlir] 80f0785 - [mlir][Linalg] NFC - Refactor fusion APIs

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jan 12 06:32:12 PST 2021


Author: Nicolas Vasilache
Date: 2021-01-12T14:27:15Z
New Revision: 80f078548868d0dd3d74ab8a1deb8aa46870cdf3

URL: https://github.com/llvm/llvm-project/commit/80f078548868d0dd3d74ab8a1deb8aa46870cdf3
DIFF: https://github.com/llvm/llvm-project/commit/80f078548868d0dd3d74ab8a1deb8aa46870cdf3.diff

LOG: [mlir][Linalg] NFC - Refactor fusion APIs

This revision makes the fusion APIs uniform so that they accept `OpOperand` and `OpResult` arguments, and adds finer-grained control over fusion.

Differential Revision: https://reviews.llvm.org/D94493

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 3fc3fa4a5556..f3b7181d71a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -726,6 +726,18 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
           getNumShapedOperands());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the OpOperand for the `i`-th shaped operand.
+      }],
+      /*retTy=*/" OpOperand&",
+      /*methodName=*/"getShapedOpOperand",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return *(this->getShapedOpOperands().begin() + i);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range over input and output operands.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d816414ef8b4..de1658f96a87 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -35,6 +35,7 @@ struct TiledLinalgOp {
   LinalgOp op;
   SmallVector<Operation *, 8> loops;
   SmallVector<Value, 4> tensorResults;
+  TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
 };
 
 /// Populates patterns for vectorization of all ConvN-D ops.
@@ -412,9 +413,8 @@ struct LinalgBaseTilingPattern : public RewritePattern {
                           LinalgTilingOptions options,
                           LinalgMarker marker = LinalgMarker(),
                           PatternBenefit benefit = 1);
-  LogicalResult
-  matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
-                      SmallVectorImpl<Value> &tensorResults) const;
+  LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+                                    TiledLinalgOp &result) const;
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -432,14 +432,14 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
                                 marker, benefit) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<Value, 4> tensorResults;
+    TiledLinalgOp tiledLinalgOp;
     if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
-                                                            tensorResults)))
+                                                            tiledLinalgOp)))
       return failure();
-    if (tensorResults.empty())
+    if (tiledLinalgOp.tensorResults.empty())
       rewriter.eraseOp(op);
     else
-      rewriter.replaceOp(op, tensorResults);
+      rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
     return success();
   }
 };

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 2ef32cfe378b..f194209f1910 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -92,26 +92,31 @@ findAllFusableDependences(ArrayRef<LinalgOp> ops,
 
 /// Fuses producer into consumer if the producer is structurally feasible and
 /// the fusion would not violate dependencies.
-/// Implements the fusion part of the "tileAndFuse on buffers"
-/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
-/// to be a `subview` op (generally obtained by applying the tiling
-/// transformation).
-Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
-                                          unsigned consumerIdx,
+/// Implements the fusion part of the "tileAndFuse on buffers" transformation
+/// and thus requires the `consumerOpOperand` to be a `subview` op (generally
+/// obtained by applying the tiling transformation).
+Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
+                                          OpOperand &consumerOpOperand,
                                           const LinalgDependenceGraph &graph);
 /// Tensor counterpart of `fuseProducerOfBuffer`.
 /// This implements the fusion part of the "tileAndFuse on tensors"
-/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
-/// to be the result of a `subtensor` op (generally obtained by applying the
-/// tiling transformation).
-Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
-                                          unsigned consumerIdx);
+/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
+/// op (generally obtained by applying the tiling transformation).
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+                                          OpOperand &consumerOpOperand);
+/// Overload of `fuseProducerOfTensor` that takes the producer explicitly.
+/// Requires `consumerOpOperand` to be a `subtensor` op (generally obtained by
+/// applying the tiling transformation) and assumes `producerOpResult` is the
+/// result of a Linalg op that produces the tensor sliced by that `subtensor`.
+/// This implements the fusion part of "tileAndFuse on tensors".
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+                                          OpResult producerOpResult,
+                                          OpOperand &consumerOpOperand);
 
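-/// Fuse linalg operation on tensors, with the producer of the operand at
-/// position `consumerIdx` of the consumer.
+/// Fuse the linalg operation on tensors that owns `consumerOpOperand` with
+/// the producer of that operand.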
 Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
-                                              Operation *consumer,
-                                              unsigned consumerIdx);
+                                              OpOperand &consumerOpOperand);
 
 /// Like `getShape`, but only returns statically-known information, without
 /// generating any new IR. For each shape dimension, returns >=0 if that

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index f9908af29313..8f02f3d83cf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -258,20 +258,19 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
 ///      `producer.getOutputBuffers()`.
 ///   2. Tensor case: `producerIdx` is the index of the tensor in
 ///      `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
-                     LinalgOp consumer, unsigned consumerIdx) {
-  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
-  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp,
+                     unsigned producerOutNumber, OpOperand &consumerOpOperand) {
+  AffineMap producerMap = producerOp.getOutputIndexingMap(producerOutNumber);
+  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerOutNumber
                           << ", producer map: " << producerMap << "\n");
   DenseMap<unsigned, Range> fusedLoopsAndRanges;
-  Location loc = consumer.getLoc();
-  Value shapedOperand = consumer.getShapedOperand(consumerIdx);
+  Value shapedOperand = consumerOpOperand.get();
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    fusedLoopsAndRanges[posInProducerLoop] =
-        getRangeFromOperandShape(b, loc, shapedOperand, en.index());
+    fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
+        b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
   }
-  return fuse(b, producer, fusedLoopsAndRanges);
+  return fuse(b, producerOp, fusedLoopsAndRanges);
 }
 
 // Encode structural fusion safety preconditions.
@@ -378,9 +377,10 @@ static bool isSameSubView(Value a, Value b) {
 }
 
 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
+findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
-  assert(consumer.hasBufferSemantics() && "revisit usage of shaped operand");
+  LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+  assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand");
 
   // Only consider RAW and WAW atm.
   for (auto depType : {
@@ -388,21 +388,16 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
            LinalgDependenceGraph::DependenceType::WAW,
        }) {
     for (auto dependence : llvm::make_filter_range(
-             dependenceGraph.getDependencesInto(consumer, depType),
-             [consumerIdx](
-                 LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
-               return elem.indexingOpView->getOperandNumber() == consumerIdx;
+             dependenceGraph.getDependencesInto(consumerOp, depType),
+             [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
+               return elem.indexingOpView->get() == consumerOpOperand.get() &&
+                      elem.indexingOpView->getOperandNumber() ==
+                          consumerOpOperand.getOperandNumber();
              })) {
 
-      // Check that the dependence is indeed on the input `consumerIdx` view.
-      Value consumedView = dependence.indexingOpView->get();
-      if (!isSameSubView(consumer.getShapedOperand(consumerIdx), consumedView))
-        continue;
-
       // Consumer consumes this view, `isStructurallyFusableProducer` also
       // checks whether it is a strict subview of the producer view.
       auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
-      Value producedView = dependence.dependentOpView->get();
       LLVM_DEBUG(llvm::dbgs()
                  << "\n"
                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
@@ -412,10 +407,10 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
                  << dependence.dependentOpView->getOperandNumber() -
                         producer.getNumInputs()
                  << "\n");
-      (void)producedView;
 
       // Simple fusability checks.
-      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
+      if (!isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
+                         producer))
         continue;
 
       return dependence;
@@ -425,29 +420,28 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
 }
 
 Optional<FusionInfo>
-mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
-                                   unsigned consumerIdx,
+mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                    const LinalgDependenceGraph &graph) {
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
-      findFusableProducer(consumer, consumerIdx, graph);
+      findFusableProducer(consumerOpOperand, graph);
   if (!fusableDependence)
     return {};
 
   LinalgOp producerOp =
       cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
   // If producer is already in the same block as consumer, we are done.
-  if (consumer->getBlock() == producerOp->getBlock())
+  if (consumerOpOperand.get().getParentBlock() ==
+      fusableDependence->dependentOpView->get().getParentBlock())
     return {};
 
   unsigned producerIdx =
       fusableDependence->dependentOpView->getOperandNumber() -
       producerOp.getNumInputs();
-  Value consumerView = consumer.getShapedOperand(consumerIdx);
 
   // Must be a subview or a slice to guarantee there are loops we can fuse
   // into.
-  auto subView = consumerView.getDefiningOp<SubViewOp>();
-  auto slice = consumerView.getDefiningOp<SliceOp>();
+  auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
+  auto slice = consumerOpOperand.get().getDefiningOp<SliceOp>();
   if (!subView && !slice) {
     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
     return {};
@@ -455,25 +449,25 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
 
   // Fuse `producer` just before `consumer`.
   OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(consumer.getOperation());
-  ScopedContext scope(b, consumer.getLoc());
-  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
+  b.setInsertionPoint(consumerOpOperand.getOwner());
+  ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc());
+  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
+                          << *consumerOpOperand.getOwner() << "\n");
 
-  auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
+  auto fusedProducer = fuse(b, producerOp, producerIdx, consumerOpOperand);
   return FusionInfo{producerOp, fusedProducer};
 }
 
 /// Walk back use-def chain through scf::For yields.
 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
-static void getProducerOfTensor(Value tensor, LinalgOp &producer,
-                                unsigned &outputIndex) {
+static void getProducerOfTensor(Value tensor, OpResult &opResult) {
   if (!tensor.getType().isa<RankedTensorType>())
     return;
 
   while (true) {
+    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
-      producer = linalgOp;
-      outputIndex = tensor.cast<OpResult>().getResultNumber();
+      opResult = tensor.cast<OpResult>();
       return;
     }
     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
@@ -482,7 +476,7 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
     }
     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
-        tensor = forOp.getResult(blockArg.getArgNumber());
+        tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
         continue;
       }
     }
@@ -490,45 +484,58 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
   }
 }
 
-Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
-                                                        LinalgOp consumer,
-                                                        unsigned consumerIdx) {
-  Value inputTensor = consumer.getInput(consumerIdx);
-  LinalgOp producerOp;
-  unsigned producerIdx;
-  getProducerOfTensor(inputTensor, producerOp, producerIdx);
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
+  Value inputTensor = consumerOpOperand.get();
+  OpResult producerOpResult;
+  getProducerOfTensor(inputTensor, producerOpResult);
+  if (!producerOpResult) {
+    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
+    return {};
+  }
+  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
+}
+
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
+                                   OpOperand &consumerOpOperand) {
+  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
+  assert(producerOp && "expected Linalg producer");
+  LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+  Value inputTensor = consumerOpOperand.get();
 
   // Must be a subtensor to guarantee there are loops we can fuse into.
   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
-  if (!subTensor || !producerOp) {
-    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
+  if (!subTensor) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "\nNot fusable, not a subtensor: " << inputTensor);
     return {};
   }
 
   // If producer is already in the same block as consumer, we are done.
-  if (consumer->getBlock() == producerOp->getBlock())
+  if (consumerOpOperand.get().getParentBlock() ==
+      producerOpResult.getParentBlock())
     return {};
 
   // Insert fused `producer` just before `consumer`.
   OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(consumer.getOperation());
-  ScopedContext scope(b, consumer.getLoc());
-  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
-  LinalgOp fusedProducer =
-      fuse(b, producerOp, producerIdx, consumer, consumerIdx);
+  b.setInsertionPoint(consumerOp);
+  ScopedContext scope(b, consumerOp->getLoc());
+  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+  LinalgOp fusedProducer = fuse(
+      b, producerOp, producerOpResult.getResultNumber(), consumerOpOperand);
 
   // Replace use.
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerIdx);
-  OpOperand &use = consumer->getOpOperand(consumerIdx);
-  Type consumerType = use.get().getType();
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
-  use.set(def);
-  return FusionInfo{producerOp, fusedProducer};
+  consumerOpOperand.set(def);
+  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
 }
 
 /// Prune all dimensions that are of reduction iterator type from `map`.
@@ -734,11 +741,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
   // in the meanwhile disallow such a fusion.
   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
-    for (auto operandIndex :
-         llvm::seq<unsigned>(0, op.getNumShapedOperands())) {
+    for (OpOperand &opOperand : op.getShapedOpOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence =
-              findFusableProducer(op, operandIndex, dependenceGraph);
+          fusableDependence = findFusableProducer(opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
       LinalgOp producerOp =
@@ -759,7 +764,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
         op.emitRemark(
             "unhandled non permutation indexing map for fused view in "
             "producer for operand at index ")
-            << operandIndex;
+            << opOperand.getOperandNumber();
         return FusableOpDependencesTy{};
       }
 
@@ -770,7 +775,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
         op.emitRemark(
             "unhandled case where indexing map for fused view in the consumer "
             "is not a projected permutation while fusing at index ")
-            << operandIndex;
+            << opOperand.getOperandNumber();
         return FusableOpDependencesTy{};
       }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 833662d282b6..670d456ad2f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -178,8 +178,10 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
 }
 
 static Optional<SmallVector<Value, 1>>
-fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
+fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
                   PatternRewriter &rewriter) {
+  LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+  unsigned consumerIdx = consumerOpOperand.getOperandNumber();
   if (!areTensorOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;
 
@@ -1027,21 +1029,19 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
 } // namespace
 
 Optional<SmallVector<Value, 1>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-                            unsigned consumerIdx) {
-  if (consumerIdx >= consumer->getNumOperands())
-    return llvm::None;
-  Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
+mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
+                            OpOperand &consumerOpOperand) {
+  Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
 
   // Fuse when consumer is GenericOp or IndexedGenericOp.
-  if (!isa<GenericOp, IndexedGenericOp>(consumer) ||
+  if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) ||
       !isa<GenericOp, IndexedGenericOp>(producer))
     return llvm::None;
 
-  return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
-                           consumerIdx, rewriter);
+  return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+                           rewriter);
 }
 
 namespace {
@@ -1053,12 +1053,12 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
   LogicalResult matchAndRewrite(LinalgOpTy op,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
-    for (auto operandNum : llvm::seq<unsigned>(0, op->getNumOperands())) {
-      Operation *producer = op->getOperand(operandNum).getDefiningOp();
+    for (OpOperand &opOperand : op.getShapedOpOperands()) {
+      Operation *producer = opOperand.get().getDefiningOp();
       if (!producer)
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseTensorOps(rewriter, op, operandNum);
+          fuseTensorOps(rewriter, opOperand);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         if (producer->use_empty())

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c5d811c41edb..5b6302a7e5a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -117,8 +117,7 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
       options(options) {}
 
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
-    Operation *op, PatternRewriter &rewriter,
-    SmallVectorImpl<Value> &tensorResults) const {
+    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
@@ -131,7 +130,7 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     return failure();
 
   // Return relevant information to derived pattern.
-  tensorResults = res->tensorResults;
+  result = *res;
 
   // New marker if specified.
   marker.replaceLinalgMarker(rewriter, res->op.getOperation());

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 046fad43c3bf..5d55f0375f37 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -135,14 +135,14 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
   bool changed = false;
   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
-    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
-      if (en.value().getType().isa<MemRefType>()) {
+    for (OpOperand &opOperand : linalgOp.getShapedOpOperands()) {
+      if (opOperand.get().getType().isa<MemRefType>()) {
         // TODO: LinalgDependenceGraph should be able to update itself.
         // The current naive and expensive reconstruction of the graph should be
         // removed.
         linalg::Aliases aliases;
         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
+        if (auto info = fuseProducerOfBuffer(b, opOperand, graph)) {
           auto *originalOp = info->originalProducer.getOperation();
           eraseSet.insert(originalOp);
           auto *originalOpInLinalgOpsVector =
@@ -151,11 +151,11 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
           changed = true;
         }
       } else {
-        assert(en.value().getType().isa<RankedTensorType>());
-        // Tile and Fuse tensor input (TODO: init_tensors too).
-        if (en.index() >= linalgOp.getNumInputs())
+        assert(opOperand.get().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input.
+        if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
           continue;
-        if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
+        if (auto info = fuseProducerOfTensor(b, opOperand)) {
           auto *originalOp = info->originalProducer.getOperation();
           auto *originalOpInLinalgOpsVector =
               std::find(linalgOps.begin(), linalgOps.end(), originalOp);


        


More information about the Mlir-commits mailing list