[Mlir-commits] [mlir] 29c31cb - [mlir][linalg] Add support for transitive fusion.
Tobias Gysi
llvmlistbot at llvm.org
Thu Nov 4 09:25:40 PDT 2021
Author: Tobias Gysi
Date: 2021-11-04T16:25:06Z
New Revision: 29c31cb79b57594381aa15bcebe8c71b9fa64aef
URL: https://github.com/llvm/llvm-project/commit/29c31cb79b57594381aa15bcebe8c71b9fa64aef
DIFF: https://github.com/llvm/llvm-project/commit/29c31cb79b57594381aa15bcebe8c71b9fa64aef.diff
LOG: [mlir][linalg] Add support for transitive fusion.
Extend fusion on tensors to fuse producers greedily.
Reviewed By: nicolasvasilache, hanchung
Differential Revision: https://reviews.llvm.org/D110262
Added:
mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b32d8e1c12b0..0924a5c59dcc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -212,6 +212,7 @@ class TileLoopNest {
bool isEmpty();
/// Returns true if the tile loop nest invariants are satisfied:
+ /// - The `rootOp` has been tiled at least once.
/// - The number of tile loop operations and dimensions match.
/// - The innermost tile loop is the parent of `tiledOp`.
/// - The tile loops are directly nested.
@@ -233,8 +234,8 @@ class TileLoopNest {
bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
LinalgOp rootOp;
- SmallVector<scf::ForOp> loopOps;
- SmallVector<int64_t> loopDims;
+ SmallVector<scf::ForOp> tileLoopOps;
+ DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
};
/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index bfac63b30586..7156515cedae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -42,19 +42,62 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
// Search the slice dimensions tiled by a tile loop dimension.
- DenseSet<int64_t> tiledSliceDims;
+ DenseSet<int64_t> tiledSliceDimIndices;
for (auto en : enumerate(indexingMap.getResults())) {
for (auto tiledLoopDim : tiledLoopDims) {
if (en.value().isFunctionOfDim(tiledLoopDim))
- tiledSliceDims.insert(en.index());
+ tiledSliceDimIndices.insert(en.index());
}
}
- return {tiledSliceDims.begin(), tiledSliceDims.end()};
+ return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()};
+}
+
+/// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
+/// of the producer result slice returns the tiled producer loop dimensions.
+/// Example:
+/// ```
+/// %res = linalg.fill(%cst, %input)
+/// scf.for %i
+/// scf.for %j
+/// %slice = tensor.extract_slice %res[%i, %j]
+/// ```
+/// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
+static SmallVector<int64_t>
+getTiledProducerLoops(OpResult producerResult,
+ ArrayRef<int64_t> tiledSliceDimIndices) {
+ LinalgOp producerOp = producerResult.getOwner();
+
+ // Get the indexing map of the `producerOp` output operand that matches
+ // `producerResult`.
+ AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
+ producerOp.getOutputOperand(producerResult.getResultNumber()));
+
+ // Keep only the tiled result slice dimensions of `producerIndexingMap`.
+ AffineMap tiledProducerIndexingSubMap =
+ producerIndexingMap.getSubMap(SmallVector<unsigned>(
+ tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()));
+
+ // Compute the producer loop indices mapped to the tiled result slice
+ // dimensions. As the output indexing maps of structured operations are
+ // projected permutations, `tiledProducerIndexingSubMap` has to be a
+ // projected permutation as well. We can thus obtain the producer loop indices
+ // by getting the positions of the result dimensions.
+ // Example:
+ // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
+ assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
+ "expect slice and producer loop dimensions map one-to-one");
+ SmallVector<int64_t> tiledProducerLoopIndices;
+ transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
+ std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
+ return tiledProducerIndexingSubMap.getDimPosition(idx);
+ });
+
+ return tiledProducerLoopIndices;
}
/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
-/// along the `tiledSliceDims` and clone the producer. Consider the case of
-/// fusion of an output tensor:
+/// along the `tiledSliceDimIndices` and clone the producer. Consider the case
+/// of fusion of an output tensor:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = consumer ins(...) outs(%1)
@@ -84,7 +127,8 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
/// producer is fused into a consumer and fold away unused iter_args.
static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
tensor::ExtractSliceOp sliceOp,
- ArrayRef<int64_t> tiledSliceDims,
+ ArrayRef<int64_t> tiledSliceDimIndices,
+ ArrayRef<int64_t> tiledProducerLoopIndices,
OpOperand *iterArg) {
// Clone the producer after `sliceOp` since the slice may be reused to pass in
// the producer result.
@@ -102,23 +146,16 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
[](Range range) { return range.size; });
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
- // Get the producer result indexing map.
- AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
- producerOp.getOutputOperand(producerResult.getResultNumber()));
-
// Tile the producer operands given the `sliceOp` ranges. Iterate the
- // `tiledSliceDims` and store the tile offset and size for the tiled slice
- // dimension. Assumes the mapping from slice dimensions to producer loops is a
- // permutation.
+ // `tiledSliceDimIndices` and store the tile offset and size for the tiled
+ // slice dimension.
auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
- for (int64_t tiledSliceDim : tiledSliceDims) {
- AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim];
- assert(result.isa<AffineDimExpr>() &&
- "expect producer indexing map is a projected permutation");
- int64_t tiledProducerLoop = result.cast<AffineDimExpr>().getPosition();
+ for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
+ int64_t tiledSliceDim = std::get<0>(it);
+ int64_t tiledProducerLoop = std::get<1>(it);
tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
@@ -156,22 +193,26 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
// TileLoopNest specific helpers.
//===----------------------------------------------------------------------===//
-bool TileLoopNest::isEmpty() { return loopOps.empty(); }
+bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); }
bool TileLoopNest::isValid() {
- // Check if the number of `tileLoopOps` and `tileLoopDims` match.
- if (loopOps.size() != loopDims.size())
+ // Check if `rootOp` has been tiled at least once.
+ if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0)
+ return false;
+
+ // Check if the number of loop operations and dimensions match.
+ if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size())
return false;
// Check if the innermost tile loop is the parent of `tiledOp`.
- if (rootOp->getParentOp() != loopOps.back())
+ if (rootOp->getParentOp() != tileLoopOps.back())
return false;
// Check if the tile loops are directly nested.
- return std::adjacent_find(loopOps.begin(), loopOps.end(),
+ return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
[](Operation *op1, Operation *op2) {
return op1 != op2->getParentOp();
- }) == loopOps.end();
+ }) == tileLoopOps.end();
}
SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
@@ -179,7 +220,7 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
SmallVector<BlockArgument> bbArgs;
// Search all tile loop block arguments from inner to outer.
- for (auto tileLoop : reverse(loopOps)) {
+ for (auto tileLoop : reverse(tileLoopOps)) {
if (bbArg.getOwner()->getParentOp() != tileLoop)
return {};
bbArgs.push_back(bbArg);
@@ -194,9 +235,9 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
// Search all block arguments and return the matching iteration argument.
SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
- if (bbArgs.size() != loopOps.size())
+ if (bbArgs.size() != tileLoopOps.size())
return nullptr;
- return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
+ return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
}
bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
@@ -255,24 +296,29 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
if (!isEmpty())
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
+ // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
+ if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
+ tiledRootAndFusedOpsLoops[tiledRootOp->op] =
+ tiledRootAndFusedOpsLoops[rootOp];
+ }
+
// Update the root operation and append the loops and tile loop dimensions.
rootOp = tiledRootOp->op;
- loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
+ tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
for (auto en : enumerate(tileSizes)) {
// Copy only the tiled loop dimensions with non-zero tile size.
if (en.value() == 0)
continue;
- loopDims.push_back(tileInterchange[en.index()]);
+ tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
}
assert(isValid() && "expect tile loop nest to be valid after tiling");
-
return success();
}
FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
- OpOperand *rootOpOperand) {
- assert(rootOpOperand->getOwner() == rootOp &&
- "expect the root op to be the owner of the operand to fuse");
+ OpOperand *consumerOpOperand) {
+ assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 &&
+ "expect the operand owner is the root operation or a fused producer");
assert(this->isValid() &&
"expect the tile loop nest to satisfy all invariants");
@@ -280,13 +326,16 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
if (isEmpty())
return failure();
- // Check `rootOpOperand` is defined by an ExtractSliceOp.
- auto sliceOp = rootOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ // Check `consumerOpOperand` is defined by an ExtractSliceOp.
+ auto sliceOp =
+ consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp)
return failure();
- // Check `sliceOp` is tiled by the tile loop nest.
- if (sliceOp->getParentOp() != rootOp->getParentOp())
+ // Check `sliceOp` and `consumerOp` are in the same block as `rootOp`.
+ LinalgOp consumerOp = consumerOpOperand->getOwner();
+ if (sliceOp->getBlock() != rootOp->getBlock() ||
+ consumerOp->getBlock() != rootOp->getBlock())
return failure();
// Check if the producer is a LinalgOp possibly passed by iteration argument.
@@ -302,19 +351,24 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
return failure();
- // Compute the tiled producer slice dimensions given the tiled root operation
- // loop dimensions `loopDims`.
- SmallVector<int64_t> tiledSliceDims =
- getTiledSliceDims(rootOpOperand, loopDims);
- if (tiledSliceDims.empty())
+ // Compute the tiled producer slice dimensions given the tiled consumer loops.
+ SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
+ consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
+ if (tiledSliceDimIndices.empty())
return failure();
+ // Compute the tiled producer loop indices.
+ SmallVector<int64_t> tiledProducerLoopIndices =
+ getTiledProducerLoops(producerResult, tiledSliceDimIndices);
+
// Tile the producer operands and clone the producer in place of `sliceOp`.
LinalgOp clonedOp =
- getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg);
+ getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
+ tiledProducerLoopIndices, iterArg);
+ tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;
// Cast the `clonedOp` result to gap type mismatches before canonicalization.
- Type consumerOperandType = rootOpOperand->get().getType();
+ Type consumerOperandType = consumerOpOperand->get().getType();
Value newResult = clonedOp->getResult(producerResult.getResultNumber());
if (newResult.getType() != consumerOperandType) {
OpBuilder::InsertionGuard guard(b);
@@ -330,7 +384,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
ValueRange TileLoopNest::getRootOpReplacementResults() {
assert(!isEmpty() && "expect tile loop nest to be non-empty");
- return loopOps.front()->getOpResults();
+ return tileLoopOps.front()->getOpResults();
}
//===----------------------------------------------------------------------===//
@@ -359,14 +413,25 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
});
int64_t split = std::distance(iterTypes.begin(), it);
+ // Helper to fuse the producers greedily using a stack of fusion candidates.
+ auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
+ SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
+ while (!candidates.empty()) {
+ FailureOr<LinalgOp> fusedProducer =
+ tileLoopNest.fuseProducer(b, candidates.pop_back_val());
+ if (failed(fusedProducer))
+ continue;
+ candidates.append(fusedProducer->getInputAndOutputOperands());
+ }
+ };
+
// Tile the outer parallel loops and fuse the output operands.
SmallVector<int64_t> outerTileSizes;
outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
outerTileSizes.append(tileSizes.size() - split, 0);
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
return failure();
- for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands())
- (void)tileLoopNest.fuseProducer(b, opOperand);
+ fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
// Tile the remaining loops and fuse the input operands.
SmallVector<int64_t> innerTileSizes;
@@ -374,10 +439,7 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
return failure();
- SmallVector<OpOperand *> inputOperands =
- tileLoopNest.getRootOp().getInputOperands();
- for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands())
- (void)tileLoopNest.fuseProducer(b, opOperand);
+ fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
return tileLoopNest;
}
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
new file mode 100644
index 000000000000..1578d230017b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=4,4,0,0 tile-interchange=0,1,2,3" -cse --canonicalize -split-input-file | FileCheck %s
+
+// CHECK: fuse_conv_chain
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32>
+// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
+ %arg1: tensor<11x11xf32>,
+ %arg2: tensor<10x10xf32>,
+ %arg3: tensor<9x9xf32>,
+ %arg4: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %cst = arith.constant 1.0 : f32
+
+ // Do not tile the filter fill since the filter dimensions are not tiled.
+ // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+ %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32>
+
+ // Fuse all other operations.
+ // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]]
+ // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]]
+
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+ // CHECK-SAME: %[[IV0]], %[[IV1]]
+ // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
+ // CHECK-SAME: %[[IV0]], %[[IV1]]
+ // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+ // CHECK: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]]
+ %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32>
+ %2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32>
+
+ // CHECK: %[[T5:.*]] = tensor.extract_slice %[[ARG3]]
+ // CHECK-SAME: %[[IV0]], %[[IV1]]
+ // CHECK: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]])
+ // CHECK: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]]
+ %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32>
+ %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32>
+
+ // Use the argument passed in by iteration argument.
+ // CHECK: %[[T8:.*]] = tensor.extract_slice %[[ARG6]]
+ // CHECK-SAME: %[[IV0]], %[[IV1]]
+ // CHECK: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]])
+ // CHECK: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]]
+ %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
+ %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32>
+ return %6 : tensor<8x8xf32>
+}
+
+// -----
+
+// CHECK: fuse_matmul_chain
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c25 = arith.constant 25 : index
+ %c24 = arith.constant 24 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.000000e+00 : f32
+
+ // Do not tile rhs fill of the producer matmul since none of its loop dimensions is tiled.
+ // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+ %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
+
+ // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]]
+ // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]]
+
+ // Only the outermost loop of the producer matmul is tiled.
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK-SAME: %[[IV0]], 0
+ // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+ // CHECK: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}}
+ %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
+
+ // Use the argument passed in by iteration argument.
+ // CHECK: %[[T4:.*]] = tensor.extract_slice %[[ARG2]]
+ // CHECK-SAME: %[[IV0]], %[[IV1]]
+ // CHECK: %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]])
+ // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]]
+ %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
+ return %2 : tensor<8x8xf32>
+}
More information about the Mlir-commits
mailing list