[Mlir-commits] [mlir] [mlir][vector] Add support for dropping inner unit dims for transfer_read/write with masks. (PR #188841)

Han-Chung Wang llvmlistbot at llvm.org
Fri Mar 27 09:38:55 PDT 2026


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/188841

>From a624f778d35ef4819eca1bd2c67203667bc837c0 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 26 Mar 2026 13:51:24 -0700
Subject: [PATCH 1/2] [mlir][vector] Add support for dropping inner unit dims
 for transfer_read/write with masks.

This revision resolves a long-standing TODO: the patterns now support
transfer_read/write ops that have a mask, by inserting a vector.shape_cast
op that drops the same inner unit dims from the mask.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Vector/Transforms/VectorTransforms.cpp    | 39 ++++++++++++-------
 ...tor-transfer-collapse-inner-most-dims.mlir | 33 ++++++++++++++++
 2 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c694f4f58faa1..a2d59010a2901 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1563,10 +1563,6 @@ class DropInnerMostUnitDimsTransferRead
     if (readOp.getTransferRank() == 0)
       return failure();
 
-    // TODO: support mask.
-    if (readOp.getMask())
-      return failure();
-
     auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
     if (!srcType)
       return failure();
@@ -1614,12 +1610,22 @@ class DropInnerMostUnitDimsTransferRead
                                   readOp.getBase(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+    // If there is a mask, shape_cast it to drop the same inner unit dims.
+    Value mask = readOp.getMask();
+    if (mask) {
+      auto maskType = cast<VectorType>(mask.getType());
+      auto reducedMaskType = VectorType::get(
+          maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+          maskType.getScalableDims().drop_back(dimsToDrop));
+      mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+                                                        mask);
+    }
+
     Value result = vector::TransferReadOp::create(
         rewriter, loc, resultTargetVecType, rankedReducedView,
         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
-        readOp.getPadding(),
-        // TODO: support mask.
-        /*mask=*/Value(), inBoundsAttr);
+        readOp.getPadding(), mask, inBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                      result);
     return success();
@@ -1654,10 +1660,6 @@ class DropInnerMostUnitDimsTransferWrite
     if (writeOp.getTransferRank() == 0)
       return failure();
 
-    // TODO: support mask.
-    if (writeOp.getMask())
-      return failure();
-
     auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
     if (!srcType)
       return failure();
@@ -1709,11 +1711,22 @@ class DropInnerMostUnitDimsTransferWrite
 
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, resultTargetVecType, writeOp.getVector());
+
+    // If there is a mask, shape_cast it to drop the same inner unit dims.
+    Value mask = writeOp.getMask();
+    if (mask) {
+      auto maskType = cast<VectorType>(mask.getType());
+      auto reducedMaskType = VectorType::get(
+          maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+          maskType.getScalableDims().drop_back(dimsToDrop));
+      mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+                                                        mask);
+    }
+
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         writeOp, shapeCast, rankedReducedView,
         writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
-        // TODO: support mask.
-        /*mask=*/Value(), inBoundsAttr);
+        mask, inBoundsAttr);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 18c28799a62e5..ad1aef9a7cbb2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -266,6 +266,24 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable_inner_dim(%src: me
 
 // -----
 
+func.func @contiguous_inner_most_masked_read(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
+  return %v : vector<1x8x1xf32>
+}
+
+//      CHECK: func @contiguous_inner_most_masked_read(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+//      CHECK:   %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME:    memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
+//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // NOTE: This is an out-of-bounds access.
 
 func.func @negative_non_unit_inner_vec_dim(%src: memref<4x1xf32>) -> vector<4x8xf32> {
@@ -580,6 +598,21 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000
 
 // -----
 
+func.func @contiguous_inner_most_masked_write(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true, true]} : vector<1x8x1xf32>, memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
+  return
+}
+
+//      CHECK: func @contiguous_inner_most_masked_write(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
+//      CHECK:   %[[DEST_0:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+//      CHECK:   %[[REDUCED_VEC:.+]] = vector.shape_cast %[[VEC]] : vector<1x8x1xf32> to vector<1x8xf32>
+//      CHECK:   %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+//      CHECK:   vector.transfer_write %[[REDUCED_VEC]], %[[DEST_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME:    vector<1x8xf32>, memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// -----
+
 // NOTE: This is an out-of-bounds access.
 
 func.func @negative_non_unit_inner_vec_dim(%dest: memref<4x1xf32>, %vec: vector<4x8xf32>) {

>From 5b0dd142b06b2a2bba8ae55e0543202648111aa7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 27 Mar 2026 09:38:31 -0700
Subject: [PATCH 2/2] address comments

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../vector-transfer-collapse-inner-most-dims.mlir      | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index ad1aef9a7cbb2..1bedce7ea6a67 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -266,14 +266,13 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable_inner_dim(%src: me
 
 // -----
 
-func.func @contiguous_inner_most_masked_read(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
+func.func @contiguous_inner_most_with_mask(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
   %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
   return %v : vector<1x8x1xf32>
 }
-
-//      CHECK: func @contiguous_inner_most_masked_read(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
+//      CHECK: func @contiguous_inner_most_with_mask(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
 //      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
 // CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
 //      CHECK:   %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
@@ -598,13 +597,12 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000
 
 // -----
 
-func.func @contiguous_inner_most_masked_write(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
+func.func @contiguous_inner_most_with_mask(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true, true]} : vector<1x8x1xf32>, memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
   return
 }
-
-//      CHECK: func @contiguous_inner_most_masked_write(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
+//      CHECK: func @contiguous_inner_most_with_mask(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
 //      CHECK:   %[[DEST_0:.+]] = memref.subview %[[DEST]]
 // CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
 //      CHECK:   %[[REDUCED_VEC:.+]] = vector.shape_cast %[[VEC]] : vector<1x8x1xf32> to vector<1x8xf32>



More information about the Mlir-commits mailing list