[Mlir-commits] [mlir] 58da789 - [mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern (#119379)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 16 12:38:04 PST 2024

Author: Andrzej Warzyński
Date: 2024-12-16T20:38:00Z
New Revision: 58da789e72c3d26c9dac1b29884f5ce62b8150b1

URL: https://github.com/llvm/llvm-project/commit/58da789e72c3d26c9dac1b29884f5ce62b8150b1
DIFF: https://github.com/llvm/llvm-project/commit/58da789e72c3d26c9dac1b29884f5ce62b8150b1.diff

LOG: [mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern (#119379)




diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ad629b7588e224..60cf897b00de37 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1254,64 +1254,98 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
         "require the tiled outer dimensions of the result are all 1s");
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
   Location loc = unpackOp.getLoc();
   Value source = unpackOp.getSource();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
-  SmallVector<Value> dynamicDims;
+  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
+  // dims:
+  //    [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
+  SmallVector<int64_t> readShapeForExtractSlice;
+  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
+  // outer-tiled-dims being all 1), this will be
+  //    [ outer-untiled-dims, tile-sizes ]
+  SmallVector<OpFoldResult> extractSliceSizes;
+  // The offset and strides attributes for ExtractSliceOp.
+  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
+  // Shape for EmptyOp that's used as the init value for TransposeOp below.
+  // This should be:
+  //    [ outer-untiled-dims, tile-sizes ]
+  // However, skip unit dims - TransposeOp (below) applies rank-reduced
+  // permutation.
+  SmallVector<OpFoldResult> shapeForEmptyOp;
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
+    // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
+    //
+    // As all outer tiled dims are 1, the corresponding slice size to
+    // read will also be 1. As this will be a rank-reducing "extract
+    // slice" (i.e. the unit dims will be "collapsed"), there's no need to
+    // update:
+    //  * the output shape for ExtractSliceOp, nor
+    //  * the shape for EmptyOp.
     if (dimAndTileMapping.count(i)) {
-      readSizes.push_back(oneIdxAttr);
+      extractSliceSizes.push_back(oneIdxAttr);
+    // Compute sizes attribute for ExtractSliceOp + EmptyOp -
+    // outer-untiled-dims
     if (ShapedType::isDynamic(srcShape[i])) {
-      Value dynamicDim =
+      OpFoldResult dynamicDim =
           rewriter.create<tensor::DimOp>(loc, source, i).getResult();
-      readSizes.push_back(dynamicDim);
-      dynamicDims.push_back(dynamicDim);
+      extractSliceSizes.push_back(dynamicDim);
+      shapeForEmptyOp.push_back(dynamicDim);
     } else {
-      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      if (srcShape[i] != 1)
+        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
+    }
+    // Compute the output shape for ExtractSliceOp  - outer-untiled-dims (take
+    // into account rank-reducing)
+    if (srcShape[i] != 1) {
+      readShapeForExtractSlice.push_back(srcShape[i]);
-    if (srcShape[i] != 1)
-      readShape.push_back(srcShape[i]);
+  // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
+  // shape for EmptyOp.
   auto mixedTiles = unpackOp.getMixedTiles();
-  readSizes.append(mixedTiles.begin(), mixedTiles.end());
+  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
+  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
   // Explicitly create the type for extract_slice op because the inner tile
   // size could be 1. We want to represent the whole inner tile in this case.
   auto tileShape = srcShape.drop_front(destRank);
   // Append the inner tile shape to the permuted and rank-reduced outer shape.
-  readShape.append(tileShape.begin(), tileShape.end());
+  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
   Type elemType = unpackOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
   Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
+      loc, readType, unpackOp.getSource(), extractSliceOffsets,
+      extractSliceSizes, extractSliceStrides);
   // 2. Transpose the tile to match the outer corresponding tile order.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
   // Unpack is a transition out of packed space so we invert the permutation.
   perm = invertPermutationVector(perm);
-  SmallVector<int64_t> transpShape(readShape);
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
   Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
+      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
   // 3. Handle incomplete tiles if needed. It truncates trailing data from the
   // transposed tile.
-  int numLoops = transpShape.size();
+  int numLoops = shapeForEmptyOp.size();
   SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
   SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
   SmallVector<OpFoldResult> tileSizes;

diff  --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
index a720c655e4be51..bd60504f533456 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
@@ -35,15 +35,15 @@ func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tens
 /// Same as example above, but with 1 dynamic tile size.
-func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
   return %0 : tensor<5x1xf32>
 // CHECK-LABEL: func.func @simple_unpack_dynamic_tile
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[TILE_DIM_1:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
+// CHECK-SAME:    %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM]], 2] [1, 1, 1, 1]
 // CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
@@ -52,13 +52,23 @@ func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tens
 /// Same as example above, but with 1 dynamic tile size and a transpose
-/// FIXME: This is currently broken:
-///   * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
+func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim] into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
+  return %0 : tensor<5x1xf32>
+// CHECK-LABEL:   func.func @simple_unpack_dynamic_tile_transpose
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:           %[[TILE:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM]]] [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
+// CHECK:           %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM]]) : tensor<?x2xf32>
+// CHECK:           %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME:        ins(%[[TILE]] : tensor<2x?xf32>)
+// CHECK-SAME:        outs(%[[EMPTY]] : tensor<?x2xf32>)
+// CHECK-SAME:        permutation = [1, 0]
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1] : tensor<?x2xf32> to tensor<5x1xf32>
+// CHECK:           return %[[SLICE]] : tensor<5x1xf32>
-//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-//  %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
-//  return %0 : tensor<5x1xf32>
 /// Same as example above, but with 1 scalable tile size.


More information about the Mlir-commits mailing list