[Mlir-commits] [mlir] [mlir][vector] Add support for dropping inner unit dims for transfer_read/write with masks. (PR #188841)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 26 13:57:41 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Han-Chung Wang (hanhanW)
<details>
<summary>Changes</summary>
This revision clears a long-overdue TODO: it adds support for the lowering when transfer_read/write ops have a mask, by inserting a vector.shape_cast op on the mask so that it drops the same inner unit dims as the transferred vector.
---
Full diff: https://github.com/llvm/llvm-project/pull/188841.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+26-13)
- (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+33)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c694f4f58faa1..a2d59010a2901 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1563,10 +1563,6 @@ class DropInnerMostUnitDimsTransferRead
if (readOp.getTransferRank() == 0)
return failure();
- // TODO: support mask.
- if (readOp.getMask())
- return failure();
-
auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
if (!srcType)
return failure();
@@ -1614,12 +1610,22 @@ class DropInnerMostUnitDimsTransferRead
readOp.getBase(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+ // If there is a mask, shape_cast it to drop the same inner unit dims.
+ Value mask = readOp.getMask();
+ if (mask) {
+ auto maskType = cast<VectorType>(mask.getType());
+ auto reducedMaskType = VectorType::get(
+ maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+ maskType.getScalableDims().drop_back(dimsToDrop));
+ mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+ mask);
+ }
+
Value result = vector::TransferReadOp::create(
rewriter, loc, resultTargetVecType, rankedReducedView,
readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
- readOp.getPadding(),
- // TODO: support mask.
- /*mask=*/Value(), inBoundsAttr);
+ readOp.getPadding(), mask, inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
result);
return success();
@@ -1654,10 +1660,6 @@ class DropInnerMostUnitDimsTransferWrite
if (writeOp.getTransferRank() == 0)
return failure();
- // TODO: support mask.
- if (writeOp.getMask())
- return failure();
-
auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
if (!srcType)
return failure();
@@ -1709,11 +1711,22 @@ class DropInnerMostUnitDimsTransferWrite
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, resultTargetVecType, writeOp.getVector());
+
+ // If there is a mask, shape_cast it to drop the same inner unit dims.
+ Value mask = writeOp.getMask();
+ if (mask) {
+ auto maskType = cast<VectorType>(mask.getType());
+ auto reducedMaskType = VectorType::get(
+ maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+ maskType.getScalableDims().drop_back(dimsToDrop));
+ mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+ mask);
+ }
+
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
writeOp, shapeCast, rankedReducedView,
writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
- // TODO: support mask.
- /*mask=*/Value(), inBoundsAttr);
+ mask, inBoundsAttr);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 18c28799a62e5..ad1aef9a7cbb2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -266,6 +266,24 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable_inner_dim(%src: me
// -----
+func.func @contiguous_inner_most_masked_read(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
+ return %v : vector<1x8x1xf32>
+}
+
+// CHECK: func @contiguous_inner_most_masked_read(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
+// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// CHECK: %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
// NOTE: This is an out-of-bounds access.
func.func @negative_non_unit_inner_vec_dim(%src: memref<4x1xf32>) -> vector<4x8xf32> {
@@ -580,6 +598,21 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000
// -----
+func.func @contiguous_inner_most_masked_write(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true, true]} : vector<1x8x1xf32>, memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
+ return
+}
+
+// CHECK: func @contiguous_inner_most_masked_write(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
+// CHECK: %[[DEST_0:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// CHECK: %[[REDUCED_VEC:.+]] = vector.shape_cast %[[VEC]] : vector<1x8x1xf32> to vector<1x8xf32>
+// CHECK: %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+// CHECK: vector.transfer_write %[[REDUCED_VEC]], %[[DEST_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME: vector<1x8xf32>, memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// -----
+
// NOTE: This is an out-of-bounds access.
func.func @negative_non_unit_inner_vec_dim(%dest: memref<4x1xf32>, %vec: vector<4x8xf32>) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/188841
More information about the Mlir-commits
mailing list