[Mlir-commits] [mlir] 01defcc - [mlir][Linalg] Extend tile+fuse to work on Linalg operation on tensors.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 22 11:34:01 PST 2021


Author: MaheshRavishankar
Date: 2021-01-22T11:33:35-08:00
New Revision: 01defcc8d74e65f3d304274bc4ede44d838ff22b

URL: https://github.com/llvm/llvm-project/commit/01defcc8d74e65f3d304274bc4ede44d838ff22b
DIFF: https://github.com/llvm/llvm-project/commit/01defcc8d74e65f3d304274bc4ede44d838ff22b.diff

LOG: [mlir][Linalg] Extend tile+fuse to work on Linalg operation on tensors.

Differential Revision: https://reviews.llvm.org/D93086

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
    mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/test/Dialect/Linalg/fusion-sequence.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 5ffe4c6c9461..fecaeff1c8df 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -247,8 +247,9 @@ class LinalgDependenceGraph {
   // Uses std::pair to keep operations and view together and avoid usage errors
   // related to src/dst and producer/consumer terminology in the context of
   // dependences.
-  void addDependenceElem(DependenceType dt, OpOperand *indexingOpView,
-                         OpOperand *dependentOpView);
+  void addDependenceElem(
+      DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
+      LinalgDependenceGraphElem::OpView dependentOpView);
 
   /// Implementation detail for findCoveringxxx.
   SmallVector<Operation *, 8>

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index f80a00bf64d4..59004867a333 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -113,18 +113,21 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
   }
 }
 
-void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
-                                              OpOperand *indexingOpView,
-                                              OpOperand *dependentOpView) {
+void LinalgDependenceGraph::addDependenceElem(
+    DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
+    LinalgDependenceGraphElem::OpView dependentOpView) {
   LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
-                    << indexingOpView->get() << " @"
-                    << indexingOpView->getOperandNumber() << ") -> \n\t\t("
-                    << dependentOpView->get() << " @"
-                    << dependentOpView->getOperandNumber() << ")");
-  dependencesFromGraphs[dt][indexingOpView->getOwner()].push_back(
-      LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
-  dependencesIntoGraphs[dt][dependentOpView->getOwner()].push_back(
-      LinalgDependenceGraphElem{indexingOpView, dependentOpView, dt});
+                    << LinalgDependenceGraphElem::getValue(indexingOpView)
+                    << ") -> \n\t\t("
+                    << LinalgDependenceGraphElem::getValue(dependentOpView)
+                    << ")");
+  dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
+      .push_back(
+          LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
+  dependencesIntoGraphs[dt]
+                       [LinalgDependenceGraphElem::getOwner(dependentOpView)]
+                           .push_back(LinalgDependenceGraphElem{
+                               indexingOpView, dependentOpView, dt});
 }
 
 LinalgDependenceGraph::dependence_range
@@ -158,6 +161,20 @@ LinalgDependenceGraph::getDependencesInto(
 }
 
 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
+  if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
+    // For operations on tensors only RAW dependences are recorded: one for
+    // every input operand of `dst` whose value is defined by `src`.
+    for (OpOperand &dstOpOperand : dst.getInputOpOperands()) {
+      // Check if the operand is defined by the src.
+      auto definingOp = dstOpOperand.get().getDefiningOp<LinalgOp>();
+      if (definingOp && definingOp == src)
+        addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
+                          &dstOpOperand);
+    }
+    return;
+  }
+  assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
+         "unhandled dependence tracking for mixed buffer/tensor operations");
   for (OpOperand *srcOpOperand : src.getOutputBuffersOpOperands()) { // W
     // RAW graph
     for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 5d37e8f9d782..714bb0f97777 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -348,13 +349,15 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
   return true;
 }
 
+/// Find the Linalg operation that is the fusable producer of
+/// `consumerOpOperand`: the last writer for buffers, the defining op for
+/// tensors. The dependence is returned as an element of `dependenceGraph`.
 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
 findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
-  LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
-  // Note that buffer semantics implies that the dependence will only be from
-  // OpOperand -> OpOperand.
-  assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand");
+  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
+  if (!consumerOp)
+    return {};
 
   // Only consider RAW and WAW atm.
   for (auto depType : {
@@ -378,18 +381,21 @@ findFusableProducer(OpOperand &consumerOpOperand,
       LLVM_DEBUG(llvm::dbgs()
                  << "\n"
                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
-                 << "producer: " << *dependence.getDependentOp() << " view: "
-                 << dependence.getDependentValue() << " output index: "
-                 << (dependence.getDependentOpViewOperandNum().getValue() -
-                     producer.getNumInputs())
-                 << "\n");
-
-      // Simple fusability checks.
-      if (!isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
-                         producer))
-        continue;
-
-      return dependence;
+                 << "producer: " << *dependence.getDependentOp()
+                 << " view: " << dependence.getDependentValue() << "\n");
+
+      // If the producer and consumer have tensor semantics, the only dependence
+      // between them is through a RAW dependence and they are fusable by
+      // construction. Operations with buffer semantics need additional checks.
+      if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
+          isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
+                        producer))
+        return dependence;
+      if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
+        assert(dependence.dependenceType ==
+               LinalgDependenceGraph::DependenceType::RAW);
+        return dependence;
+      }
     }
   }
   return {};
@@ -439,6 +445,10 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
 
 /// Walk back use-def chain through scf::For yields.
-/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
+/// Sets `opResult` if it finds a producer LinalgOp.
+///
+/// TODO(ravishankarm, ntv): This can be moved into the dependence graph's
+/// dependence tracking since that tracking is similar to what is done for
+/// buffers.
 static void getProducerOfTensor(Value tensor, OpResult &opResult) {
   if (!tensor.getType().isa<RankedTensorType>())
     return;
@@ -798,7 +808,7 @@ static Optional<TiledLinalgOp> tileRootOperation(
 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
 /// `tiledOp`.
 static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
+fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
                ArrayRef<LinalgOp> fusionCandidates,
                const FusableOpDependencesTy &fusableDependences,
                const std::set<unsigned> &fusedLoops) {
@@ -812,9 +822,33 @@ fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
   }
 
   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
+  DenseMap<Operation *, LinalgOp> origOpToFusedOp;
+  origOpToFusedOp[rootOp.getOperation()] = tiledOp;
   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
-    LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
+    LinalgOp origOp = candidate.value();
+    LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
+    origOpToFusedOp[origOp.getOperation()] = fusedOp;
     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
+    // If the producer consumer operations are linalg operations on tensors, the
+    // dependence is due to value produced (as a return tensor) by the producer
+    // and used in the consumer. The returned value of the fused op needs to be
+    // made the operand of the tiled/fused consumer operation. By construction
+    // the value returned by the producer is the value used by the consumer.
+    for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
+      if (origOp.hasTensorSemantics() &&
+          dependence.dependenceType ==
+              LinalgDependenceGraph::DependenceType::RAW) {
+        unsigned resultIndex =
+            dependence.getDependentOpViewResultNum().getValue();
+        LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
+        if (!consumer)
+          continue;
+        Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
+        consumer.getOperation()->setOperand(
+            dependence.getIndexingOpViewOperandNum().getValue(),
+            replacementValue);
+      }
+    }
     builder.setInsertionPoint(fusedOp);
   }
   return fusedOps;
@@ -828,14 +862,16 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
   if (ops.size() < 2)
     return llvm::None;
   LinalgOp rootOp = ops.back();
-  for (auto op : enumerate(ops)) {
-    // TODO: Nothing in the fusion of sequence of ops is specific to
-    // buffers. This check can be removed after it is tested on tensors.
-    LinalgOp linalgOp = op.value();
-    if (!linalgOp.hasBufferSemantics()) {
-      linalgOp.emitRemark("tile and fuse only tested for buffer operation");
-      return llvm::None;
-    }
+  if (!llvm::all_of(
+          ops,
+          [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
+      !llvm::all_of(ops, [](LinalgOp linalgOp) {
+        return linalgOp.hasTensorSemantics();
+      })) {
+    rootOp.emitError(
+        "unable to fuse operations that have tensor semantics with operations "
+        "that have buffer semantics and vice versa");
+    return llvm::None;
   }
   // TODO: Support interchange with tile + fuse. This might actually help do
   // better fusion.
@@ -877,8 +913,9 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
 
   // Fuse the other operations into the fused inter-tile loops produced above.
-  ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
+  ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
                                       fusableDependences, ret.fusedLoopDims);
+
   return ret;
 }
 

diff  --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
index a02c878ef341..2738eb0f9114 100644
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
@@ -58,7 +58,7 @@ module {
 module {
   func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                            %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
-			   %arg4: memref<?x?xf32>) {
+                           %arg4: memref<?x?xf32>) {
     %cst = constant 0.000000e+00 : f32
     %c0 = constant 0 : index
     %c1 = constant 1 : index
@@ -131,3 +131,115 @@ module {
//       CHECK:     scf.yield
//       CHECK:   }
 
+// -----
+
+module {
+  func @tensor_op_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                         %arg2: tensor<?x?xf32>, %arg3: tensor<?xf32>)
+    -> tensor<?x?xf32> {
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %1 = dim %0, %c0 : tensor<?x?xf32>
+    %2 = dim %0, %c1 : tensor<?x?xf32>
+    %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
+    %4 = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%3 : tensor<?x?xf32>) {
+      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
+        %5 = addf %arg4, %arg5 : f32
+        linalg.yield %5 : f32
+      } -> tensor<?x?xf32>
+    return %4 : tensor<?x?xf32>
+  }
+}
+// CHECK-LABEL: func @tensor_op_fusion
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?xf32>
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//       CHECK:   %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+//       CHECK:     %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor<?x?xf32>) {
+//   CHECK-DAG:       %[[STARG3:.+]] = subtensor %[[ARG3]]
+//   CHECK-DAG:       %[[STARG7:.+]] = subtensor %[[ARG7]]
+//   CHECK-DAG:       %[[STARG0:.+]] = subtensor %[[ARG0]]
+//   CHECK-DAG:       %[[STARG1:.+]] = subtensor %[[ARG1]]
+//   CHECK-DAG:       %[[STARG2:.+]] = subtensor %[[ARG2]]
+//       CHECK:       %[[T0:.+]] = linalg.matmul
+//  CHECK-SAME:         ins(%[[STARG0]], %[[STARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+//  CHECK-SAME:         outs(%[[STARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+//       CHECK:       %[[T1:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[T0]], %[[STARG3]] : tensor<?x?xf32>, tensor<?xf32>)
+//  CHECK-SAME:         outs(%[[STARG7]] : tensor<?x?xf32>)
+//       CHECK:       %[[RESULT:.+]] = subtensor_insert %[[T1]] into %[[ARG7]]
+//       CHECK:       scf.yield %[[RESULT]]
+//       CHECK:     }
+//       CHECK:     scf.yield %[[R1]]
+//       CHECK:   }
+//       CHECK:   return %[[R0]]
+
+// -----
+
+module {
+  func @tensor_matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                             %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
+                             %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
+                             %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
+    %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
+    %2 = linalg.matmul ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
+    return %2 : tensor<?x?xf32>
+  }
+}
+// CHECK-LABEL: func @tensor_matmul_fusion(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//       CHECK:   %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] =
+//  CHECK-SAME:     iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
+//       CHECK:       %[[N3:.+]] = dim %[[ARG8]], %[[C1]]
+//       CHECK:       %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
+//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N3]]]
+//       CHECK:       %[[N2:.+]] = dim %[[ARG3]], %[[C1]]
+//       CHECK:       %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
+//       CHECK:       %[[STARG3:.+]] = subtensor %[[ARG3]][0, 0]
+//  CHECK-SAME:         [%[[N1]], %[[N2]]]
+//       CHECK:       %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
+//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N2]]]
+//       CHECK:       %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
+//       CHECK:       %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N0]]]
+//       CHECK:       %[[STARG1:.+]] = subtensor %[[ARG1]][0, 0]
+//  CHECK-SAME:         [%[[N0]], %[[N1]]]
+//       CHECK:       %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
+//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N1]]]
+//       CHECK:       %[[T0:.+]] = linalg.matmul
+//  CHECK-SAME:         ins(%[[STARG0]], %[[STARG1]]
+//  CHECK-SAME:         ) outs(%[[STARG2]] : tensor<?x?xf32>)
+//       CHECK:       %[[T1:.+]] = linalg.matmul
+//  CHECK-SAME:         ins(%[[T0]], %[[STARG3]]
+//  CHECK-SAME:         ) outs(%[[STARG4]] : tensor<?x?xf32>)
+//       CHECK:       %[[T2:.+]] = linalg.matmul
+//  CHECK-SAME:         ins(%[[T1]], %[[ARG5]]
+//  CHECK-SAME:         ) outs(%[[STARG6]] : tensor<?x?xf32>)
+//       CHECK:       %[[R1:.+]] = subtensor_insert %[[T2]]
+//  CHECK-SAME:         into %[[ARG8]][%[[IV0]], %[[C0]]]
+//       CHECK:       scf.yield %[[R1]]
+//       CHECK:     }
+//       CHECK:     return %[[R0]]
+//       CHECK:   }

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 5d55f0375f37..4ed00e4fbefc 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -226,14 +226,29 @@ struct TestLinalgTileAndFuseSequencePass
     Aliases aliases;
     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
     OpBuilder builder(funcOp.getContext());
+    // Tile operations on tensors with sequential loops and operations on
+    // buffers with parallel loops.
+    LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
+    if (llvm::all_of(linalgOps, [](LinalgOp linalgOp) {
+          return linalgOp.hasTensorSemantics();
+        }))
+      loopType = LinalgTilingLoopType::Loops;
     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
         builder, linalgOps, dependenceGraph,
-        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
-            LinalgTilingLoopType::ParallelLoops));
+        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
     if (!tileAndFuseOps)
       return signalPassFailure();
+    if (linalgOps.back().hasTensorSemantics()) {
+      // With tensor semantics the outermost fused loop produces the value of
+      // the root operation; redirect all uses of the root to that loop.
+      linalgOps.back().getOperation()->replaceAllUsesWith(
+          tileAndFuseOps->fusedLoops.front());
+    }
+    // Only operations on buffers are erased here; the original operations on
+    // tensors are left in place.
     for (auto op : linalgOps)
-      op.erase();
+      if (op.hasBufferSemantics())
+        op.erase();
   }
 };
 } // namespace


        


More information about the Mlir-commits mailing list