[Mlir-commits] [mlir] 29c31cb - [mlir][linalg] Add support for transitive fusion.

Tobias Gysi llvmlistbot at llvm.org
Thu Nov 4 09:25:40 PDT 2021


Author: Tobias Gysi
Date: 2021-11-04T16:25:06Z
New Revision: 29c31cb79b57594381aa15bcebe8c71b9fa64aef

URL: https://github.com/llvm/llvm-project/commit/29c31cb79b57594381aa15bcebe8c71b9fa64aef
DIFF: https://github.com/llvm/llvm-project/commit/29c31cb79b57594381aa15bcebe8c71b9fa64aef.diff

LOG: [mlir][linalg] Add support for transitive fusion.

Extend fusion on tensors to fuse producers greedily.

Reviewed By: nicolasvasilache, hanchung

Differential Revision: https://reviews.llvm.org/D110262

Added: 
    mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b32d8e1c12b0..0924a5c59dcc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -212,6 +212,7 @@ class TileLoopNest {
   bool isEmpty();
 
   /// Returns true if the tile loop nest invariants are satisfied:
+  /// - The `rootOp` has been tiled at least once.
   /// - The number of tile loop operations and dimensions match.
   /// - The innermost tile loop is the parent of `tiledOp`.
   /// - The tile loops are directly nested.
@@ -233,8 +234,8 @@ class TileLoopNest {
   bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
 
   LinalgOp rootOp;
-  SmallVector<scf::ForOp> loopOps;
-  SmallVector<int64_t> loopDims;
+  SmallVector<scf::ForOp> tileLoopOps;
+  DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
 };
 
 /// Tiles `consumerOp` and fuses its dependencies if possible. Uses the

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index bfac63b30586..7156515cedae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -42,19 +42,62 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
   AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
 
   // Search the slice dimensions tiled by a tile loop dimension.
-  DenseSet<int64_t> tiledSliceDims;
+  DenseSet<int64_t> tiledSliceDimIndices;
   for (auto en : enumerate(indexingMap.getResults())) {
     for (auto tiledLoopDim : tiledLoopDims) {
       if (en.value().isFunctionOfDim(tiledLoopDim))
-        tiledSliceDims.insert(en.index());
+        tiledSliceDimIndices.insert(en.index());
     }
   }
-  return {tiledSliceDims.begin(), tiledSliceDims.end()};
+  return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()};
+}
+
+/// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
+/// of the producer result slice, returns the tiled producer loop dimensions.
+/// Example:
+/// ```
+/// %res = linalg.fill(%cst, %input)
+/// scf.for %i
+///   scf.for %j
+///     %slice = tensor.extract_slice %res[%i, %j]
+/// ```
+/// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
+static SmallVector<int64_t>
+getTiledProducerLoops(OpResult producerResult,
+                      ArrayRef<int64_t> tiledSliceDimIndices) {
+  LinalgOp producerOp = producerResult.getOwner();
+
+  // Get the indexing map of the `producerOp` output operand that matches
+  // `producerResult`.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
+      producerOp.getOutputOperand(producerResult.getResultNumber()));
+
+  // Keep only the tiled result slice dimensions of `producerIndexingMap`.
+  AffineMap tiledProducerIndexingSubMap =
+      producerIndexingMap.getSubMap(SmallVector<unsigned>(
+          tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()));
+
+  // Compute the producer loop indices mapped to the tiled result slice
+  // dimensions. As the output indexing maps of structured operations are
+  // projected permutations, `tiledProducerIndexingSubMap` has to be a
+  // projected permutation as well. We can thus obtain the producer loop indices
+  // by getting the positions of the result dimensions.
+  // Example:
+  // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
+  assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
+         "expect slice and producer loop dimensions map one-to-one");
+  SmallVector<int64_t> tiledProducerLoopIndices;
+  transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
+            std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
+              return tiledProducerIndexingSubMap.getDimPosition(idx);
+            });
+
+  return tiledProducerLoopIndices;
 }
 
 /// Returns the producer fused in place of `sliceOp`. Tile the producer operands
-/// along the `tiledSliceDims` and clone the producer. Consider the case of
-/// fusion of an output tensor:
+/// along the `tiledSliceDimIndices` and clone the producer. Consider the case
+/// of fusion of an output tensor:
 /// ```
 /// %1 = producer ins(...) outs(%0)
 /// %2 = consumer ins(...) outs(%1)
@@ -84,7 +127,8 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
 /// producer is fused into a consumer and fold away unused iter_args.
 static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
                                  tensor::ExtractSliceOp sliceOp,
-                                 ArrayRef<int64_t> tiledSliceDims,
+                                 ArrayRef<int64_t> tiledSliceDimIndices,
+                                 ArrayRef<int64_t> tiledProducerLoopIndices,
                                  OpOperand *iterArg) {
   // Clone the producer after `sliceOp` since the slice may be reused to pass in
   // the producer result.
@@ -102,23 +146,16 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
             [](Range range) { return range.size; });
   SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
 
-  // Get the producer result indexing map.
-  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
-      producerOp.getOutputOperand(producerResult.getResultNumber()));
-
   // Tile the producer operands given the `sliceOp` ranges. Iterate the
-  // `tiledSliceDims` and store the tile offset and size for the tiled slice
-  // dimension. Assumes the mapping from slice dimensions to producer loops is a
-  // permutation.
+  // `tiledSliceDimIndices` and store the tile offset and size for the tiled
+  // slice dimension.
   auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
   SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
   SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
-  for (int64_t tiledSliceDim : tiledSliceDims) {
-    AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim];
-    assert(result.isa<AffineDimExpr>() &&
-           "expect producer indexing map is a projected permutation");
-    int64_t tiledProducerLoop = result.cast<AffineDimExpr>().getPosition();
+  for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
+    int64_t tiledSliceDim = std::get<0>(it);
+    int64_t tiledProducerLoop = std::get<1>(it);
     tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
     tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
     allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
@@ -156,22 +193,26 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
 // TileLoopNest specific helpers.
 //===----------------------------------------------------------------------===//
 
-bool TileLoopNest::isEmpty() { return loopOps.empty(); }
+bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); }
 
 bool TileLoopNest::isValid() {
-  // Check if the number of `tileLoopOps` and `tileLoopDims` match.
-  if (loopOps.size() != loopDims.size())
+  // Check if `rootOp` has been tiled at least once.
+  if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0)
+    return false;
+
+  // Check if the number of loop operations and dimensions match.
+  if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size())
     return false;
 
   // Check if the innermost tile loop is the parent of `tiledOp`.
-  if (rootOp->getParentOp() != loopOps.back())
+  if (rootOp->getParentOp() != tileLoopOps.back())
     return false;
 
   // Check if the tile loops are directly nested.
-  return std::adjacent_find(loopOps.begin(), loopOps.end(),
+  return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
                             [](Operation *op1, Operation *op2) {
                               return op1 != op2->getParentOp();
-                            }) == loopOps.end();
+                            }) == tileLoopOps.end();
 }
 
 SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
@@ -179,7 +220,7 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
   SmallVector<BlockArgument> bbArgs;
 
   // Search all tile loop block arguments from inner to outer.
-  for (auto tileLoop : reverse(loopOps)) {
+  for (auto tileLoop : reverse(tileLoopOps)) {
     if (bbArg.getOwner()->getParentOp() != tileLoop)
       return {};
     bbArgs.push_back(bbArg);
@@ -194,9 +235,9 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
 OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
   // Search all block arguments and return the matching iteration argument.
   SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
-  if (bbArgs.size() != loopOps.size())
+  if (bbArgs.size() != tileLoopOps.size())
     return nullptr;
-  return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
+  return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
 }
 
 bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
@@ -255,24 +296,29 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
   if (!isEmpty())
     rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
 
+  // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
+  if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
+    tiledRootAndFusedOpsLoops[tiledRootOp->op] =
+        tiledRootAndFusedOpsLoops[rootOp];
+  }
+
   // Update the root operation and append the loops and tile loop dimensions.
   rootOp = tiledRootOp->op;
-  loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
+  tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
   for (auto en : enumerate(tileSizes)) {
     // Copy only the tiled loop dimensions with non-zero tile size.
     if (en.value() == 0)
       continue;
-    loopDims.push_back(tileInterchange[en.index()]);
+    tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
   }
   assert(isValid() && "expect tile loop nest to be valid after tiling");
-
   return success();
 }
 
 FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
-                                               OpOperand *rootOpOperand) {
-  assert(rootOpOperand->getOwner() == rootOp &&
-         "expect the root op to be the owner of the operand to fuse");
+                                               OpOperand *consumerOpOperand) {
+  assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 &&
+         "expect the operand owner is the root operation or a fused producer");
   assert(this->isValid() &&
          "expect the tile loop nest to satisfy all invariants");
 
@@ -280,13 +326,16 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
   if (isEmpty())
     return failure();
 
-  // Check `rootOpOperand` is defined by an ExtractSliceOp.
-  auto sliceOp = rootOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  // Check `consumerOpOperand` is defined by an ExtractSliceOp.
+  auto sliceOp =
+      consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)
     return failure();
 
-  // Check `sliceOp` is tiled by the tile loop nest.
-  if (sliceOp->getParentOp() != rootOp->getParentOp())
+  // Check `sliceOp` and `consumerOp` are in the same block as `rootOp`.
+  LinalgOp consumerOp = consumerOpOperand->getOwner();
+  if (sliceOp->getBlock() != rootOp->getBlock() ||
+      consumerOp->getBlock() != rootOp->getBlock())
     return failure();
 
   // Check if the producer is a LinalgOp possibly passed by iteration argument.
@@ -302,19 +351,24 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
   if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
     return failure();
 
-  // Compute the tiled producer slice dimensions given the tiled root operation
-  // loop dimensions `loopDims`.
-  SmallVector<int64_t> tiledSliceDims =
-      getTiledSliceDims(rootOpOperand, loopDims);
-  if (tiledSliceDims.empty())
+  // Compute the tiled producer slice dimensions given the tiled consumer loops.
+  SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
+      consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
+  if (tiledSliceDimIndices.empty())
     return failure();
 
+  // Compute the tiled producer loop indices.
+  SmallVector<int64_t> tiledProducerLoopIndices =
+      getTiledProducerLoops(producerResult, tiledSliceDimIndices);
+
   // Tile the producer operands and clone the producer in place of `sliceOp`.
   LinalgOp clonedOp =
-      getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg);
+      getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
+                       tiledProducerLoopIndices, iterArg);
+  tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;
 
   // Cast the `clonedOp` result to gap type mismatches before canonicalization.
-  Type consumerOperandType = rootOpOperand->get().getType();
+  Type consumerOperandType = consumerOpOperand->get().getType();
   Value newResult = clonedOp->getResult(producerResult.getResultNumber());
   if (newResult.getType() != consumerOperandType) {
     OpBuilder::InsertionGuard guard(b);
@@ -330,7 +384,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
 
 ValueRange TileLoopNest::getRootOpReplacementResults() {
   assert(!isEmpty() && "expect tile loop nest to be non-empty");
-  return loopOps.front()->getOpResults();
+  return tileLoopOps.front()->getOpResults();
 }
 
 //===----------------------------------------------------------------------===//
@@ -359,14 +413,25 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
   });
   int64_t split = std::distance(iterTypes.begin(), it);
 
+  // Helper to fuse the producers greedily using a stack of fusion candidates.
+  auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
+    SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
+    while (!candidates.empty()) {
+      FailureOr<LinalgOp> fusedProducer =
+          tileLoopNest.fuseProducer(b, candidates.pop_back_val());
+      if (failed(fusedProducer))
+        continue;
+      candidates.append(fusedProducer->getInputAndOutputOperands());
+    }
+  };
+
   // Tile the outer parallel loops and fuse the output operands.
   SmallVector<int64_t> outerTileSizes;
   outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
   outerTileSizes.append(tileSizes.size() - split, 0);
   if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
     return failure();
-  for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands())
-    (void)tileLoopNest.fuseProducer(b, opOperand);
+  fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
 
   // Tile the remaining loops and fuse the input operands.
   SmallVector<int64_t> innerTileSizes;
@@ -374,10 +439,7 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
   innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
   if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
     return failure();
-  SmallVector<OpOperand *> inputOperands =
-      tileLoopNest.getRootOp().getInputOperands();
-  for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands())
-    (void)tileLoopNest.fuseProducer(b, opOperand);
+  fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
 
   return tileLoopNest;
 }

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
new file mode 100644
index 000000000000..1578d230017b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=4,4,0,0 tile-interchange=0,1,2,3" -cse --canonicalize -split-input-file | FileCheck %s
+
+//      CHECK:  fuse_conv_chain
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32>
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32>
+// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32>
+// CHECK-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32>
+// CHECK-SAME:    %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
+                              %arg1: tensor<11x11xf32>,
+                              %arg2: tensor<10x10xf32>,
+                              %arg3: tensor<9x9xf32>,
+                              %arg4: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %cst = arith.constant 1.0 : f32
+
+  // Do not tile the filter fill since the filter dimensions are not tiled.
+  //      CHECK:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32>
+
+  // Fuse all other operations.
+  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]]
+  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]]
+
+  //      CHECK:          %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CHECK:          %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
+  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CHECK:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+  //      CHECK:          %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]]
+  %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32>
+  %2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32>
+
+  //      CHECK:          %[[T5:.*]] = tensor.extract_slice %[[ARG3]]
+  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CHECK:          %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]])
+  //      CHECK:          %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]]
+  %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32>
+  %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32>
+
+  // Use the argument passed in by iteration argument.
+  //      CHECK:          %[[T8:.*]] = tensor.extract_slice %[[ARG6]]
+  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CHECK:          %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]])
+  //      CHECK:          %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]]
+  %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
+  %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32>
+  return %6 : tensor<8x8xf32>
+}
+
+// -----
+
+//      CHECK:  fuse_matmul_chain
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c25 = arith.constant 25 : index
+  %c24 = arith.constant 24 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // Do not tile the rhs fill of the producer matmul since none of its loop dimensions are tiled.
+  //      CHECK:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
+
+  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]]
+  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]]
+
+  // Only the outermost loop of the producer matmul is tiled.
+  //      CHECK:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                        %[[IV0]], 0
+  //      CHECK:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+  //      CHECK:      %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}}
+  %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
+
+  // Use the argument passed in by iteration argument.
+  //      CHECK:      %[[T4:.*]] = tensor.extract_slice %[[ARG2]]
+  // CHECK-SAME:                                        %[[IV0]], %[[IV1]]
+  //      CHECK:      %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]])
+  //      CHECK:      %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]]
+  %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
+  return %2 : tensor<8x8xf32>
+}


        


More information about the Mlir-commits mailing list