[Mlir-commits] [mlir] bce951c - [mlir][linalg] Update vectorization logic for linalg.unpack (#149156)

llvmlistbot at llvm.org
Thu Jul 17 01:14:20 PDT 2025


Author: Andrzej WarzyƄski
Date: 2025-07-17T09:14:17+01:00
New Revision: bce951c572465c6ccd59b73a58c536641abc43eb

URL: https://github.com/llvm/llvm-project/commit/bce951c572465c6ccd59b73a58c536641abc43eb
DIFF: https://github.com/llvm/llvm-project/commit/bce951c572465c6ccd59b73a58c536641abc43eb.diff

LOG: [mlir][linalg] Update vectorization logic for linalg.unpack (#149156)

This PR ensures that we no longer generate an unnecessary `tensor.empty`
when vectorizing `linalg.unpack`; the op's existing destination tensor is
used as the write target instead.

To illustrate the change, consider this IR:
```mlir
func.func @example(
  %source: tensor<8x4x16x16xf32>,
  %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {

  %res = linalg.unpack %source
      outer_dims_perm = [1, 0]
      inner_dims_pos = [0, 1]
      inner_tiles = [16, 16]
      into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>

  return %res : tensor<64x127xf32>
}
```
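
For reference, the vectorization shown below can be driven with a
transform-dialect sequence along these lines (a minimal sketch following the
conventions of the modified test file; the exact sequences used in the tests
may differ):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match every linalg.unpack in the payload and vectorize it
    // (no explicit vector sizes, as in the *_no_vector_sizes tests).
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}
```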

Below is the output after vectorization, BEFORE and AFTER this PR.

BEFORE (note the `tensor.empty` and that `%arg1` is unused):
```mlir
  func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
    %1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
    %2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32>
    %3 = tensor.empty() : tensor<64x127xf32>
    %c0_0 = arith.constant 0 : index
    %4 = vector.transfer_write %2, %3[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
    return %4 : tensor<64x127xf32>
  }
```

AFTER (note that `%arg1` is now used as the write target):
```mlir
  func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
    %1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
    %2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32>
    %c0_0 = arith.constant 0 : index
    %3 = vector.transfer_write %2, %arg1[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
    return %3 : tensor<64x127xf32>
  }
```
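
Mechanically, the fix is small: in `vectorizeAsTensorUnpackOp`, rather than
materializing a fresh `tensor.empty` as the write target, the op's existing
destination operand is reused. In sketch form (taken from the diff that
follows, with the names used in the patch):

```cpp
// Write the vectorized (transposed + shape-cast) value directly into the
// unpack op's destination tensor instead of a newly created tensor.empty.
Operation *write = createWriteOrMaskedWrite(
    rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
    /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
```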

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..458ed543b8216 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1928,11 +1928,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       unpackOp.getDestType().hasStaticShape()
           ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
-  Value dest = rewriter.create<tensor::EmptyOp>(
-      loc, reifiedRetShapes[0],
-      shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), dest,
+      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();

diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 9e501affdd2a5..679adf0a52175 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1158,6 +1158,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+// CHECK-SAME:      %[[ARG_0:.*]]: tensor<?x?xf32>,
 func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
 // CHECK: %[[C0:.*]] = arith.constant 0
 // CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
@@ -1175,9 +1176,8 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
 // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
 // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
 // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
-// CHECK: %[[empt0:.*]] = tensor.empty
 // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
-// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[ARG_0]]
 // CHECK: return %[[write0]]
  %ret = linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
  return %ret : tensor<?x?xf32>
@@ -1193,6 +1193,8 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func @test_vectorize_unpack
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME:      %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
     // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
    // CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -1201,15 +1203,14 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
     // CHECK: %[[C32:.*]] = arith.constant 32 : index
     // CHECK: %[[C16:.*]] = arith.constant 16 : index
     // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
-    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] { vector.transfer_read %[[SRC]]{{.*}}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
     // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
     // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
-    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
     // CHECK: %[[C01:.*]] = arith.constant 0 : index
     // CHECK: %[[C256:.*]] = arith.constant 256 : index
     // CHECK: %[[C128:.*]] = arith.constant 128 : index
     // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
-    // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+    // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] { vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<512x128xi1> -> tensor<256x128xf32>
     // CHECK: return %[[WRIT]] : tensor<256x128xf32>
    %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
    return %0 : tensor<256x128xf32>
@@ -1225,15 +1226,16 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
 // -----
 
 // CHECK-LABEL: func @test_vectorize_unpack_no_masks
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME:      %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> 
   // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
   // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
-  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
   // CHECK: %[[C00:.*]] = arith.constant 0 : index
-  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32> 
   // CHECK: return %[[WRIT]] : tensor<256x128xf32>
    %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
    return %0 : tensor<256x128xf32>
@@ -1248,16 +1250,17 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
 
   // -----
 
-  // CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+// CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME:      %[[DEST:.*]]: tensor<256x128xf32>
   func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> 
   // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
   // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
-  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
   // CHECK: %[[C00:.*]] = arith.constant 0 : index
-  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32> 
   // CHECK: return %[[WRIT]] : tensor<256x128xf32>
    %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
    return %0 : tensor<256x128xf32>
@@ -1327,15 +1330,17 @@ module attributes {transform.with_named_sequence} {
 
   // -----
 
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME:      %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> 
   // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
   // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
-  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
   // CHECK: %[[C00:.*]] = arith.constant 0 : index
-  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32> 
   // CHECK: return %[[WRIT]] : tensor<256x128xf32>
    %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
    return %0 : tensor<256x128xf32>
@@ -1350,15 +1355,17 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>,
 
   // -----
 
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_slice_output
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x4x16x16xf32>
+// CHECK-SAME:      %[[DEST:.*]]: tensor<64x127xf32>
 func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
   //      CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
+  //      CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
   //      CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
   //      CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
-  //      CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x127xf32>
   //      CHECK: %[[C00:.*]] = arith.constant 0 : index
-  //      CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[EMPT]]{{\[}}%[[C00]], %[[C00]]]
+  //      CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]
   // CHECK-SAME:  {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
   //      CHECK: return %[[WRIT]] : tensor<64x127xf32>
    %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
@@ -1374,18 +1381,20 @@ func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x
 
 // -----
 
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_permute
+// CHECK-SAME:      %[[SRC:.*]]:  tensor<4x7x4xf32>
+// CHECK-SAME:      %[[DEST:.*]]:  tensor<7x16xf32>
 func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
    %0 = linalg.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
    return %0 : tensor<7x16xf32>
  }
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
+  // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
   // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
   // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
-  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32>
   // CHECK: %[[C00:.*]] = arith.constant 0 : index
-  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32>
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<7x16xf32>, tensor<7x16xf32> 
   // CHECK: return %[[WRIT]] : tensor<7x16xf32>
  module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
