[Mlir-commits] [mlir] bce951c - [mlir][linalg] Update vectorization logic for linalg.unpack (#149156)
llvmlistbot at llvm.org
Thu Jul 17 01:14:20 PDT 2025
Author: Andrzej Warzyński
Date: 2025-07-17T09:14:17+01:00
New Revision: bce951c572465c6ccd59b73a58c536641abc43eb
URL: https://github.com/llvm/llvm-project/commit/bce951c572465c6ccd59b73a58c536641abc43eb
DIFF: https://github.com/llvm/llvm-project/commit/bce951c572465c6ccd59b73a58c536641abc43eb.diff
LOG: [mlir][linalg] Update vectorization logic for linalg.unpack (#149156)
This PR makes sure that we don't generate an unnecessary `tensor.empty`
when vectorizing `linalg.unpack`.
To illustrate the change, consider this IR:
```mlir
func.func @example(
%source: tensor<8x4x16x16xf32>,
%dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
%res = linalg.unpack %source
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [16, 16]
into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
return %res : tensor<64x127xf32>
}
```
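For reference, vectorization in these examples is driven through the transform dialect. A minimal driver sketch, following the `transform.with_named_sequence` pattern used in the test file below (the sequences there may additionally pass explicit `vector_sizes`), would look like:
```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match the linalg.unpack op and vectorize it (no explicit vector sizes,
    // so the sizes are inferred from the static shapes).
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}
```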
Below is the output after vectorization, BEFORE and AFTER this PR.
BEFORE (note the `tensor.empty` and that `%arg1` is unused):
```mlir
func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
%1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
%2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32>
%3 = tensor.empty() : tensor<64x127xf32>
%c0_0 = arith.constant 0 : index
%4 = vector.transfer_write %2, %3[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
return %4 : tensor<64x127xf32>
}
```
AFTER (note that `%arg1` is correctly used):
```mlir
func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
%1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
%2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32>
%c0_0 = arith.constant 0 : index
%3 = vector.transfer_write %2, %arg1[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
return %3 : tensor<64x127xf32>
}
```
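Note that the write keeps `in_bounds = [true, false]`: the shape cast yields a `vector<64x128xf32>` while the destination is `tensor<64x127xf32>`, so `vector.transfer_write` skips the out-of-bounds trailing column. Writing directly into `%arg1` preserves the destination-passing style of `linalg.unpack`, which is what makes the intermediate `tensor.empty` redundant.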
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..458ed543b8216 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1928,11 +1928,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
unpackOp.getDestType().hasStaticShape()
? vectorSizes
: shapeCastOp.getResultVectorType().getShape());
- Value dest = rewriter.create<tensor::EmptyOp>(
- loc, reifiedRetShapes[0],
- shapeCastOp.getResult().getType().getElementType());
Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, shapeCastOp.getResult(), dest,
+ rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 9e501affdd2a5..679adf0a52175 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1158,6 +1158,7 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0
// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
@@ -1175,9 +1176,8 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
-// CHECK: %[[empt0:.*]] = tensor.empty
// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
-// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[ARG_0]]
// CHECK: return %[[write0]]
%ret = linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
return %ret : tensor<?x?xf32>
@@ -1193,6 +1193,8 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func @test_vectorize_unpack
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]]= arith.constant 0 : index
@@ -1201,15 +1203,14 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
- // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] { vector.transfer_read %[[SRC]]{{.*}}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
// CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
- // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
// CHECK: %[[C01:.*]] = arith.constant 0 : index
// CHECK: %[[C256:.*]] = arith.constant 256 : index
// CHECK: %[[C128:.*]] = arith.constant 128 : index
// CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
- // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+ // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] { vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<512x128xi1> -> tensor<256x128xf32>
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
%0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
@@ -1225,15 +1226,16 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
// -----
// CHECK-LABEL: func @test_vectorize_unpack_no_masks
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
- // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
- // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32>
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
%0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
@@ -1248,16 +1250,17 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
// -----
- // CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+// CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
- // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
- // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32>
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
%0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
@@ -1327,15 +1330,17 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
- // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
- // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32>
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
%0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
@@ -1350,15 +1355,17 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>,
// -----
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_slice_output
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x4x16x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<64x127xf32>
func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
- // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x127xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
- // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[EMPT]]{{\[}}%[[C00]], %[[C00]]]
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]
// CHECK-SAME: {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
// CHECK: return %[[WRIT]] : tensor<64x127xf32>
%0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
@@ -1374,18 +1381,20 @@ func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x
// -----
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_permute
+// CHECK-SAME: %[[SRC:.*]]: tensor<4x7x4xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x16xf32>
func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
%0 = linalg.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
return %0 : tensor<7x16xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
- // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
- // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32>
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<7x16xf32>, tensor<7x16xf32>
// CHECK: return %[[WRIT]] : tensor<7x16xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {