[Mlir-commits] [mlir] [mlir][tensor] Update `GeneralizeOuterUnitDimsPackOpPattern` (PR #115312)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Nov 7 05:35:59 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/115312

>From adc0fc70c2e3dda0a3bfad0586e43a23fa273e6d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Nov 2024 21:16:54 +0000
Subject: [PATCH] [mlir][tensor] Update `GeneralizeOuterUnitDimsPackOpPattern`

Avoid generating a spurious tensor.extract_slice. This is a follow-on to #114315.

This is best demonstrated with an example. Here's the input for
`GeneralizeOuterUnitDimsPackOpPattern`:
```mlir
%pack = tensor.pack %input
  padding_value(%pad : f32)
  inner_dims_pos = [1, 0]
  inner_tiles = [2, %tile_dim_1]
  into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```

Output _before_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%extracted_slice : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

Output _after_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%padded : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

This PR also adds a check to verify that only the last N (for some value
of N) trailing source dims are being tiled. From what I can tell, that's
always the case in practice. This restriction simplifies how the
permutation for linalg.transpose is computed. If needed, it can be
relaxed in the future.
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  6 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 82 +++++++++----------
 .../Linalg/generalize-tensor-pack-tile.mlir   | 18 ++--
 .../Linalg/generalize-tensor-pack.mlir        | 36 ++++----
 4 files changed, 65 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a8662a3d6f63be..5209e1145506b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1516,7 +1516,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 };
 
 /// Rewrites a tensor::PackOp into a sequence of:
-///   * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+///   * tensor::PadOp + linalg::TransposeOp +
 ///     tensor::EmptyOp + tensor::InsertSliceOp ops.
 ///
 /// Required that all the outer dims of the input tensor::PackOp are 1.
@@ -1537,10 +1537,6 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 ///     ^bb0(...):
 ///       tensor.yield %arg2 : f32
 ///   } : tensor<5x1xf32> to tensor<?x2xf32>
-///   // ExtractSliceOp
-///   %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
-///   1]
-///     : tensor<?x2xf32> to tensor<?x2xf32>
 ///   // EmptyOp + TransposeOp
 ///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
 ///   %transposed = linalg.transpose
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 64096954f56b95..f7409f75b9e122 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1153,71 +1153,65 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   Location loc = packOp.getLoc();
 
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
-  auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
   int64_t srcRank = packOp.getSourceRank();
-
   int64_t destRank = packOp.getDestRank();
-  size_t numTiles = destRank - srcRank;
-
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
-  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+  int64_t numTiles = destRank - srcRank;
 
-  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
-  // all outer dims are 1.
-  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
-  // The shape of the output for ExtractSliceOp. All leading unit dims are
-  // effectively rank-reduced, hence skipped.
-  SmallVector<int64_t> outputShapeForExtractSlice;
+  if (!llvm::all_of(packOp.getInnerDimsPos(),
+                    [&srcRank, &numTiles](int64_t dimPos) {
+                      return dimPos >= (srcRank - numTiles - 1);
+                    }))
+    return rewriter.notifyMatchFailure(
+        packOp, "Attempting to tile non-trailing source dims!");
 
-  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
-  // be equal to the inner tile sizes.
+  // 1. Extract the inner tile sizes.
+  // Where possible, values are replaced with constant attributes (to match the
+  // behaviour of `getPackOpSourceOrPaddedSource`).
+  SmallVector<OpFoldResult> tileSizes;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      auto [tileSize, tileSizeOfr] =
+      // Rather than taking the tile size as is, extact the actual constant
+      // value Attribute where possible, e.g.:
+      //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
+      auto [_, tileSize] =
           getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
-      extractSliceSizes.push_back(tileSizeOfr);
-      outputShapeForExtractSlice.push_back(tileSize);
+      tileSizes.push_back(tileSize);
     }
   }
 
-  Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
-
-  Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
-
-  // 2. Transpose the tile to match the inner tile order:
+  // 2. Transpose the input to match the inner tile order:
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
-  // NOTE: Outer dims are 1 and hence effectively ignored.
-  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
-      inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
+  // Two assumptions are made:
+  //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
+  //  2. Inner dims position correspond to the trailing `numTiles` dims.
+  SmallVector<int64_t> tilesPermNormalized =
+      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+  SmallVector<int64_t> srcPermForTranspose;
+  for (int64_t i = 0; i < (srcRank - numTiles); i++)
+    srcPermForTranspose.push_back(i);
+
+  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
-             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+             llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
+             DBGSNL(););
 
   // 2.1 Create tensor.empty (init value for TransposeOp)
-  SmallVector<OpFoldResult> transShapeForEmptyOp;
-
-  // Acquire tensor shape required to create EmptyOp. This will match the inner
-  // tile sizes.
-  size_t idx = numTiles;
-  while (idx != 0) {
-    transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
-    idx--;
-  }
+  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
+                                                 oneIdxAttr);
+  transShapeForEmptyOp.append(tileSizes);
 
-  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
-  Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
+                                         srcPermForTranspose);
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
 
   // 2.2 Create linalg.transpose
-  auto transposedOp =
-      rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
+  auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
+                                                           srcPermForTranspose);
 
   // 3. Insert the inner tile to the destination:
   //  %inserted_tile = tensor.insert_slice(%transposed_tile)
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index d0c53ae4680013..8be3e7413bfc81 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -9,19 +9,19 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
 // CHECK:       func.func @KCRS_to_KCRSsr
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
-// CHECK:           %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK:         scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK:           scf.for %[[S:[a-zA-Z0-9]+]] {{.*}} iter_args(%[[ITER_SLICE:.*]] =
 // CHECK:             %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
 // CHECK:             %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
 // CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
 // CHECK-SAME:          [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]]
-// CHECK-SAME:          outs(%[[EMPTY]]
-// CHECK-SAME:          permutation = [1, 0]
+// CHECK:             %[[TILE:.*]] = tensor.extract_slice %[[ITER_SLICE]]
+// CHECK-SAME:          [0, 0, %[[R]], %[[S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x4x8x8x32xf32> to tensor<1x1x1x1x8x32xf32>
+// CHECK:             %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x8x32xf32>
+// CHECK:             %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME:          ins(%[[SRC_SLICE]] : tensor<1x1x32x8xf32>)
+// CHECK-SAME:          outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME:          permutation = [0, 1, 3, 2]
 // CHECK:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
 
 module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 8abf7a11bed5c9..f4b1d9a55f0914 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -63,8 +63,7 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
 
 func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
@@ -95,10 +94,10 @@ func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:            tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NEXT:      } : tensor<5x1xf32> to tensor<?x2xf32>
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
 // CHECK:           %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
 // CHECK:           %[[TR:.*]] = linalg.transpose
-// CHECK-SAME:        ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME:        ins(%[[PAD:.*]] : tensor<?x2xf32>)
+// CHECK-SAME:        outs(%[[EMPTY]] : tensor<2x?xf32>)
 // CHECK-SAME:        permutation = [1, 0]
 // CHECK:           %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x2x?xf32>
@@ -128,10 +127,10 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
 // CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
 
+
 /// Same as example above, but with both tile sizes dynamic.
 
 func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %tile_dim_0: index, %tile_dim_1: index) -> tensor<1x1x?x?xf32> {
@@ -149,8 +148,7 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x?xf32>
 
 // -----
@@ -170,12 +168,13 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index):
 // CHECK:             tensor.yield %[[VAL_2]] : f32
 // CHECK:           } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32>
-// CHECK:           %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_11:.*]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[VAL_12:.*]] = tensor.empty(%[[VAL_3]]) : tensor<2x?xf32>
-// CHECK:           %[[VAL_13:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<?x2xf32>) outs(%[[VAL_12]] : tensor<2x?xf32>) permutation = [1, 0]
-// CHECK:           %[[VAL_14:.*]] = tensor.insert_slice %[[VAL_13]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x1x1x2x?xf32>
-// CHECK:           return %[[VAL_14]] : tensor<1x1x1x1x2x?xf32>
-// CHECK:         }
+// CHECK:           %[[VAL_10:.*]] = tensor.empty(%[[VAL_3]]) : tensor<1x1x2x?xf32>
+// CHECK:           %[[VAL_11:.*]] = linalg.transpose
+// CHECK-SAME:        ins(%[[VAL_12:.*]] : tensor<1x1x?x2xf32>)
+// CHECK-SAME:        outs(%[[VAL_10]] : tensor<1x1x2x?xf32>)
+// CHECK-SAME:        permutation = [0, 1, 3, 2]
+// CHECK:           %[[VAL_13:.*]] = tensor.insert_slice %[[VAL_11]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<1x1x2x?xf32> into tensor<1x1x1x1x2x?xf32>
+// CHECK:           return %[[VAL_13]] : tensor<1x1x1x1x2x?xf32>
 
 // -----
 
@@ -218,12 +217,11 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
 // CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x32xf32>
 // CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME:      permutation = [1, 0]
+// CHECK-SAME:      ins(%[[SRC]] : tensor<1x1x32x8xf32>
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME:      permutation = [0, 1, 3, 2]
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]



More information about the Mlir-commits mailing list