[Mlir-commits] [mlir] 58da789 - [mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern (#119379)
llvmlistbot at llvm.org
Mon Dec 16 12:38:04 PST 2024
Author: Andrzej Warzyński
Date: 2024-12-16T20:38:00Z
New Revision: 58da789e72c3d26c9dac1b29884f5ce62b8150b1
URL: https://github.com/llvm/llvm-project/commit/58da789e72c3d26c9dac1b29884f5ce62b8150b1
DIFF: https://github.com/llvm/llvm-project/commit/58da789e72c3d26c9dac1b29884f5ce62b8150b1.diff
LOG: [mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern (#119379)
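This pattern decomposes tensor.unpack ops whose tiled outer dimensions are all unit-sized into a rank-reduced tensor.extract_slice, a linalg.transpose, and a final tensor.extract_slice that truncates any incomplete tile. A minimal before/after sketch, reconstructed from the dynamic-tile-plus-transpose test case added below (types, attributes and value names are all taken from that test):

  // Before:
  %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim]
    into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>

  // After:
  %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 2, %tile_dim]
    [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
  %empty = tensor.empty(%tile_dim) : tensor<?x2xf32>
  %transp = linalg.transpose ins(%tile : tensor<2x?xf32>)
    outs(%empty : tensor<?x2xf32>) permutation = [1, 0]
  %res = tensor.extract_slice %transp[0, 0] [5, 1] [1, 1]
    : tensor<?x2xf32> to tensor<5x1xf32>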
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ad629b7588e224..60cf897b00de37 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1254,64 +1254,98 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
"require the tiled outer dimensions of the result are all 1s");
}
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+ // %extracted_tile = tensor.extract_slice(%unpack_op_input)
Location loc = unpackOp.getLoc();
Value source = unpackOp.getSource();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
unpackOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
- SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
- SmallVector<OpFoldResult> readSizes;
- SmallVector<int64_t> readShape;
- SmallVector<Value> dynamicDims;
+
+ // The output shape for ExtractSliceOp. Due to rank-reducing (and
+ // outer-tiled-dims being all 1), this will be
+ // [ outer-untiled-dims, tile-sizes ]
+ SmallVector<int64_t> readShapeForExtractSlice;
+ // The sizes attribute for ExtractSliceOp. Note that this will consist of
+ // 3 blocks of dims:
+ // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
+ SmallVector<OpFoldResult> extractSliceSizes;
+ // The offset and strides attributes for ExtractSliceOp.
+ SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
+ SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
+
+ // Shape for EmptyOp that's used as the init value for TransposeOp below.
+ // This should be:
+ // [ outer-untiled-dims, tile-sizes ]
+ // However, skip unit dims - TransposeOp (below) applies a rank-reduced
+ // permutation.
+ SmallVector<OpFoldResult> shapeForEmptyOp;
+
for (auto i : llvm::seq<unsigned>(0, destRank)) {
+ // Compute the sizes attribute for ExtractSliceOp - outer-tiled-dims.
+ //
+ // All outer tiled dims are 1, so the corresponding slice size to read
+ // will also be 1. As this will be a rank-reducing "extract slice"
+ // (i.e. the unit dims will be "collapsed"), there's no need to
+ // update:
+ // * the output shape for ExtractSliceOp, nor
+ // * the shape for EmptyOp.
if (dimAndTileMapping.count(i)) {
- readSizes.push_back(oneIdxAttr);
+ extractSliceSizes.push_back(oneIdxAttr);
continue;
}
+ // Compute the sizes attribute for ExtractSliceOp + EmptyOp -
+ // outer-untiled-dims.
if (ShapedType::isDynamic(srcShape[i])) {
- Value dynamicDim =
+ OpFoldResult dynamicDim =
rewriter.create<tensor::DimOp>(loc, source, i).getResult();
- readSizes.push_back(dynamicDim);
- dynamicDims.push_back(dynamicDim);
+ extractSliceSizes.push_back(dynamicDim);
+ shapeForEmptyOp.push_back(dynamicDim);
} else {
- readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+ extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+ if (srcShape[i] != 1)
+ shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
+ }
+ // Compute the output shape for ExtractSliceOp - outer-untiled-dims
+ // (taking rank reduction into account).
+ if (srcShape[i] != 1) {
+ readShapeForExtractSlice.push_back(srcShape[i]);
}
- if (srcShape[i] != 1)
- readShape.push_back(srcShape[i]);
}
+ // Append the tile sizes to the sizes attribute for ExtractSliceOp and to
+ // the shape for EmptyOp.
auto mixedTiles = unpackOp.getMixedTiles();
- readSizes.append(mixedTiles.begin(), mixedTiles.end());
+ extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
+ shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
// Explicitly create the type for extract_slice op because the inner tile
// size could be 1. We want to represent the whole inner tile in this case.
auto tileShape = srcShape.drop_front(destRank);
// Append the inner tile shape to the permuted and rank-reduced outer shape.
- readShape.append(tileShape.begin(), tileShape.end());
+ readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
Type elemType = unpackOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(readShape, elemType);
+ auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
+ loc, readType, unpackOp.getSource(), extractSliceOffsets,
+ extractSliceSizes, extractSliceStrides);
// 2. Transpose the tile to match the corresponding outer tile order.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
// Unpack is a transition out of packed space so we invert the permutation.
perm = invertPermutationVector(perm);
- SmallVector<int64_t> transpShape(readShape);
- applyPermutationToVector<int64_t>(transpShape, perm);
+ applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
Value empty =
- rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
+ rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
// 3. Handle incomplete tiles if needed. It truncates trailing data from the
// transposed tile.
- int numLoops = transpShape.size();
+ int numLoops = shapeForEmptyOp.size();
SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
SmallVector<OpFoldResult> tileSizes;
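The functional fix in the hunk above: the shape for EmptyOp is now kept as a single vector of OpFoldResults (shapeForEmptyOp) and permuted as a whole, so dynamic sizes - including dynamic inner tile sizes - reach tensor.empty as operands. Previously the static shape and the dynamic dims travelled in separate vectors, and dynamic inner tile sizes were never added to the latter, so the pattern emitted a tensor.empty with too few dynamic-size operands whenever a dynamic tile was transposed. A sketch of the difference at the IR level, using the dynamic-tile-plus-transpose test below:

  // Previously emitted, and rejected by the verifier with:
  //   'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
  %empty = tensor.empty() : tensor<?x2xf32>

  // Emitted now - the dynamic tile size is threaded through as an operand:
  %empty = tensor.empty(%tile_dim) : tensor<?x2xf32>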
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
index a720c655e4be51..bd60504f533456 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
@@ -35,15 +35,15 @@ func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tens
/// Same as example above, but with 1 dynamic tile size.
-func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
- %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+ %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
return %0 : tensor<5x1xf32>
}
// CHECK-LABEL: func.func @simple_unpack_dynamic_tile
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[TILE_DIM_1:[a-zA-Z0-9]+]]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
+// CHECK-SAME: %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM]], 2] [1, 1, 1, 1]
// CHECK-NOT: linalg.transpose
// They have the same type, so the insert_slice op is folded
// away.
@@ -52,13 +52,23 @@ func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tens
/// Same as example above, but with 1 dynamic tile size and a transpose
-/// FIXME: This is currently broken:
-/// * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
+func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+ %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim] into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
+ return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_dynamic_tile_transpose
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK: %[[TILE:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM]]] [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM]]) : tensor<?x2xf32>
+// CHECK: %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]] : tensor<2x?xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x2xf32>)
+// CHECK-SAME: permutation = [1, 0]
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1] : tensor<?x2xf32> to tensor<5x1xf32>
+// CHECK: return %[[SLICE]] : tensor<5x1xf32>
-//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-// %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
-// return %0 : tensor<5x1xf32>
-//}
/// Same as example above, but with 1 scalable tile size.