[Mlir-commits] [mlir] [mlir][tensor] Update `GeneralizeOuterUnitDimsPackOpPattern` (PR #115312)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Nov 7 05:03:36 PST 2024
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/115312
Avoid generating a spurious tensor.extract_slice; this is a follow-up to #114315.
This is best demonstrated with an example. Here's the input for
`GeneralizeOuterUnitDimsPackOpPattern`:
```mlir
%pack = tensor.pack %input
padding_value(%pad : f32)
inner_dims_pos = [1, 0]
inner_tiles = [2, %tile_dim_1]
into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```
Output _before_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
ins(%extracted_slice : tensor<?x2xf32>)
outs(%empty : tensor<2x?xf32>)
permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
tensor<2x?xf32> into tensor<1x1x2x?xf32>
```
Output _after_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
ins(%padded : tensor<?x2xf32>)
outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
tensor<2x?xf32> into tensor<1x1x2x?xf32>
```
This PR also adds a check to verify that only the last N (for some value
of N) trailing dims are being tiled. From what I can tell, that's
always the case in practice. For this PR, the restriction simplifies how the
permutation for linalg.transpose is computed. If needed, it can be
relaxed in the future.
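To illustrate the new restriction, here is a hypothetical pack op (shapes chosen for illustration, not taken from the patch) that the updated pattern would now reject, since the tiled source dim is not one of the trailing dims:
```mlir
// Dim 0 of the source is tiled, but it is not a trailing source dim, so the
// pattern bails out with "Attempting to tile non-trailing source dims!".
%pack = tensor.pack %input
    inner_dims_pos = [0]
    inner_tiles = [8]
    into %output : tensor<8x1x1xf32> -> tensor<1x1x1x8xf32>
```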
From 8ed63a25bc7f9b4c49bd920cd6407be7147c3b2f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Nov 2024 21:16:54 +0000
Subject: [PATCH] [mlir][tensor] Update `GeneralizeOuterUnitDimsPackOpPattern`
Avoid generating a spurious tensor.extract_slice; this is a follow-up to #114315.
This is best demonstrated with an example. Here's the input for
`GeneralizeOuterUnitDimsPackOpPattern`:
```mlir
%pack = tensor.pack %input
padding_value(%pad : f32)
inner_dims_pos = [1, 0]
inner_tiles = [2, %tile_dim_1]
into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```
Output _before_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
ins(%extracted_slice : tensor<?x2xf32>)
outs(%empty : tensor<2x?xf32>)
permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
tensor<2x?xf32> into tensor<1x1x2x?xf32>
```
Output _after_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
ins(%padded : tensor<?x2xf32>)
outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
tensor<2x?xf32> into tensor<1x1x2x?xf32>
```
This PR also adds a check to verify that only the last N (for some value
of N) trailing dims are being tiled. From what I can tell, that's
always the case in practice. For this PR, the restriction simplifies how the
permutation for linalg.transpose is computed. If needed, it can be
relaxed in the future.
---
.../Dialect/Linalg/Transforms/Transforms.h | 6 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 78 +++++++++----------
.../Linalg/generalize-tensor-pack-tile.mlir | 18 ++---
.../Linalg/generalize-tensor-pack.mlir | 36 ++++-----
4 files changed, 62 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a8662a3d6f63be..5209e1145506b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1516,7 +1516,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
};
/// Rewrites a tensor::PackOp into a sequence of:
-/// * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+/// * tensor::PadOp + linalg::TransposeOp +
/// tensor::EmptyOp + tensor::InsertSliceOp ops.
///
/// Required that all the outer dims of the input tensor::PackOp are 1.
@@ -1537,10 +1537,6 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// ^bb0(...):
/// tensor.yield %arg2 : f32
/// } : tensor<5x1xf32> to tensor<?x2xf32>
-/// // ExtractSliceOp
-/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
-/// 1]
-/// : tensor<?x2xf32> to tensor<?x2xf32>
/// // EmptyOp + TransposeOp
/// %empty = tensor.empty(%arg3) : tensor<2x?xf32>
/// %transposed = linalg.transpose
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 64096954f56b95..0be8799f327441 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1153,71 +1153,63 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
Location loc = packOp.getLoc();
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
- auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
-
int64_t destRank = packOp.getDestRank();
- size_t numTiles = destRank - srcRank;
-
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
- // %extracted_tile = tensor.extract_slice(%pack_op_input)
- SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+ int64_t numTiles = destRank - srcRank;
- // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
- // all outer dims are 1.
- SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
- // The shape of the output for ExtractSliceOp. All leading unit dims are
- // effectively rank-reduced, hence skipped.
- SmallVector<int64_t> outputShapeForExtractSlice;
+ if (!llvm::all_of(packOp.getInnerDimsPos(),
+ [&srcRank, &numTiles](int64_t dimPos) {
+ return dimPos >= (srcRank - numTiles - 1);
+ }))
+ return rewriter.notifyMatchFailure(
+ packOp, "Attempting to tile non-trailing source dims!");
- // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
- // be equal to the inner tile sizes.
+ // 1. Extract the inner tile sizes.
+ // Where possible, values are replaced with constant attributes (to match the
+ // behaviour of `getPackOpSourceOrPaddedSource`).
+ SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- auto [tileSize, tileSizeOfr] =
+ // Rather than taking the tile size as is, extract the actual constant
+ // value Attribute where possible, e.g.:
+ // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
+ auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
- extractSliceSizes.push_back(tileSizeOfr);
- outputShapeForExtractSlice.push_back(tileSize);
+ tileSizes.push_back(tileSize);
}
}
- Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
-
- Value tile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, input, readOffsets, extractSliceSizes, readStrides);
-
- // 2. Transpose the tile to match the inner tile order:
+ // 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
- // NOTE: Outer dims are 1 and hence effectively ignored.
- SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
- inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
+ // Two assumptions are made:
+ // 1. All outer dims are 1 - the corresponding transposition doesn't matter.
+ // 2. Inner dims position correspond to the trailing `numTiles` dims.
+ SmallVector<int64_t> tilesPermNormalized =
+ getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+ SmallVector<int64_t> srcPermForTranspose;
+ for (int64_t i = 0; i < (srcRank - numTiles); i++)
+ srcPermForTranspose.push_back(i);
+
+ srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
- llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+ llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: "); DBGSNL(););
// 2.1 Create tensor.empty (init value for TransposeOp)
- SmallVector<OpFoldResult> transShapeForEmptyOp;
-
- // Acquire tensor shape required to create EmptyOp. This will match the inner
- // tile sizes.
- size_t idx = numTiles;
- while (idx != 0) {
- transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
- idx--;
- }
+ SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
+ oneIdxAttr);
+ transShapeForEmptyOp.append(tileSizes);
- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
- Value empty =
- rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
+ applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, srcPermForTranspose);
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
// 2.2 Create linalg.transpose
auto transposedOp =
- rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
+ rewriter.create<linalg::TransposeOp>(loc, input, empty, srcPermForTranspose);
// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index d0c53ae4680013..8be3e7413bfc81 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -9,19 +9,19 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
// CHECK: func.func @KCRS_to_KCRSsr
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
-// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK: scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK: scf.for %[[S:[a-zA-Z0-9]+]] {{.*}} iter_args(%[[ITER_SLICE:.*]] =
// CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
// CHECK: %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]]
-// CHECK-SAME: outs(%[[EMPTY]]
-// CHECK-SAME: permutation = [1, 0]
+// CHECK: %[[TILE:.*]] = tensor.extract_slice %[[ITER_SLICE]]
+// CHECK-SAME: [0, 0, %[[R]], %[[S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x4x8x8x32xf32> to tensor<1x1x1x1x8x32xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x8x32xf32>
+// CHECK: %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC_SLICE]] : tensor<1x1x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2]
// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 8abf7a11bed5c9..f4b1d9a55f0914 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -63,8 +63,7 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
@@ -95,10 +94,10 @@ func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NEXT: } : tensor<5x1xf32> to tensor<?x2xf32>
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
// CHECK: %[[TR:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME: ins(%[[PAD:.*]] : tensor<?x2xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x?xf32>)
// CHECK-SAME: permutation = [1, 0]
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
// CHECK: return %[[RES]] : tensor<1x1x2x?xf32>
@@ -128,10 +127,10 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
/// Same as example above, but with both tile sizes dynamic.
func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %tile_dim_0: index, %tile_dim_1: index) -> tensor<1x1x?x?xf32> {
@@ -149,8 +148,7 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x?xf32>
// -----
@@ -170,12 +168,13 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x
// CHECK: ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index):
// CHECK: tensor.yield %[[VAL_2]] : f32
// CHECK: } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32>
-// CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_11:.*]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[VAL_12:.*]] = tensor.empty(%[[VAL_3]]) : tensor<2x?xf32>
-// CHECK: %[[VAL_13:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<?x2xf32>) outs(%[[VAL_12]] : tensor<2x?xf32>) permutation = [1, 0]
-// CHECK: %[[VAL_14:.*]] = tensor.insert_slice %[[VAL_13]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x1x1x2x?xf32>
-// CHECK: return %[[VAL_14]] : tensor<1x1x1x1x2x?xf32>
-// CHECK: }
+// CHECK: %[[VAL_10:.*]] = tensor.empty(%[[VAL_3]]) : tensor<1x1x2x?xf32>
+// CHECK: %[[VAL_11:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[VAL_12:.*]] : tensor<1x1x?x2xf32>)
+// CHECK-SAME: outs(%[[VAL_10]] : tensor<1x1x2x?xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2]
+// CHECK: %[[VAL_13:.*]] = tensor.insert_slice %[[VAL_11]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<1x1x2x?xf32> into tensor<1x1x1x1x2x?xf32>
+// CHECK: return %[[VAL_13]] : tensor<1x1x1x1x2x?xf32>
// -----
@@ -218,12 +217,11 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
// CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x32xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME: permutation = [1, 0]
+// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x32x8xf32>
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]