[Mlir-commits] [mlir] [mlir][vector] Add support for dropping inner unit dims for transfer_read/write with masks. (PR #188841)
Han-Chung Wang
llvmlistbot at llvm.org
Fri Mar 27 09:38:55 PDT 2026
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/188841
>From a624f778d35ef4819eca1bd2c67203667bc837c0 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 26 Mar 2026 13:51:24 -0700
Subject: [PATCH 1/2] [mlir][vector] Add support for dropping inner unit dims
for transfer_read/write with masks.
This revision resolves a long-standing TODO: the patterns now support
transfer_read/write ops that have a mask, by inserting a vector.shape_cast
op that drops the same inner unit dims from the mask.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Vector/Transforms/VectorTransforms.cpp | 39 ++++++++++++-------
...tor-transfer-collapse-inner-most-dims.mlir | 33 ++++++++++++++++
2 files changed, 59 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c694f4f58faa1..a2d59010a2901 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1563,10 +1563,6 @@ class DropInnerMostUnitDimsTransferRead
if (readOp.getTransferRank() == 0)
return failure();
- // TODO: support mask.
- if (readOp.getMask())
- return failure();
-
auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
if (!srcType)
return failure();
@@ -1614,12 +1610,22 @@ class DropInnerMostUnitDimsTransferRead
readOp.getBase(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+ // If there is a mask, shape_cast it to drop the same inner unit dims.
+ Value mask = readOp.getMask();
+ if (mask) {
+ auto maskType = cast<VectorType>(mask.getType());
+ auto reducedMaskType = VectorType::get(
+ maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+ maskType.getScalableDims().drop_back(dimsToDrop));
+ mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+ mask);
+ }
+
Value result = vector::TransferReadOp::create(
rewriter, loc, resultTargetVecType, rankedReducedView,
readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
- readOp.getPadding(),
- // TODO: support mask.
- /*mask=*/Value(), inBoundsAttr);
+ readOp.getPadding(), mask, inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
result);
return success();
@@ -1654,10 +1660,6 @@ class DropInnerMostUnitDimsTransferWrite
if (writeOp.getTransferRank() == 0)
return failure();
- // TODO: support mask.
- if (writeOp.getMask())
- return failure();
-
auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
if (!srcType)
return failure();
@@ -1709,11 +1711,22 @@ class DropInnerMostUnitDimsTransferWrite
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, resultTargetVecType, writeOp.getVector());
+
+ // If there is a mask, shape_cast it to drop the same inner unit dims.
+ Value mask = writeOp.getMask();
+ if (mask) {
+ auto maskType = cast<VectorType>(mask.getType());
+ auto reducedMaskType = VectorType::get(
+ maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+ maskType.getScalableDims().drop_back(dimsToDrop));
+ mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+ mask);
+ }
+
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
writeOp, shapeCast, rankedReducedView,
writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
- // TODO: support mask.
- /*mask=*/Value(), inBoundsAttr);
+ mask, inBoundsAttr);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 18c28799a62e5..ad1aef9a7cbb2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -266,6 +266,24 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable_inner_dim(%src: me
// -----
+func.func @contiguous_inner_most_masked_read(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
+ return %v : vector<1x8x1xf32>
+}
+
+// CHECK: func @contiguous_inner_most_masked_read(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
+// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// CHECK: %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
// NOTE: This is an out-of-bounds access.
func.func @negative_non_unit_inner_vec_dim(%src: memref<4x1xf32>) -> vector<4x8xf32> {
@@ -580,6 +598,21 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000
// -----
+func.func @contiguous_inner_most_masked_write(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true, true]} : vector<1x8x1xf32>, memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
+ return
+}
+
+// CHECK: func @contiguous_inner_most_masked_write(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
+// CHECK: %[[DEST_0:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// CHECK: %[[REDUCED_VEC:.+]] = vector.shape_cast %[[VEC]] : vector<1x8x1xf32> to vector<1x8xf32>
+// CHECK: %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+// CHECK: vector.transfer_write %[[REDUCED_VEC]], %[[DEST_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME: vector<1x8xf32>, memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// -----
+
// NOTE: This is an out-of-bounds access.
func.func @negative_non_unit_inner_vec_dim(%dest: memref<4x1xf32>, %vec: vector<4x8xf32>) {
>From 5b0dd142b06b2a2bba8ae55e0543202648111aa7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 27 Mar 2026 09:38:31 -0700
Subject: [PATCH 2/2] address comments
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../vector-transfer-collapse-inner-most-dims.mlir | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index ad1aef9a7cbb2..1bedce7ea6a67 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -266,14 +266,13 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable_inner_dim(%src: me
// -----
-func.func @contiguous_inner_most_masked_read(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
+func.func @contiguous_inner_most_with_mask(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
return %v : vector<1x8x1xf32>
}
-
-// CHECK: func @contiguous_inner_most_masked_read(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
+// CHECK: func @contiguous_inner_most_with_mask(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
// CHECK: %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
@@ -598,13 +597,12 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000
// -----
-func.func @contiguous_inner_most_masked_write(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
+func.func @contiguous_inner_most_with_mask(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true, true]} : vector<1x8x1xf32>, memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
return
}
-
-// CHECK: func @contiguous_inner_most_masked_write(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
+// CHECK: func @contiguous_inner_most_with_mask(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
// CHECK: %[[DEST_0:.+]] = memref.subview %[[DEST]]
// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
// CHECK: %[[REDUCED_VEC:.+]] = vector.shape_cast %[[VEC]] : vector<1x8x1xf32> to vector<1x8xf32>
More information about the Mlir-commits
mailing list