[Mlir-commits] [mlir] feat(vector): drop unit dims from memrefs for xfer_read/write for non… (PR #187076)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 18 05:34:10 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ege Beysel (egebeysel)
<details>
<summary>Changes</summary>
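Handles the case in the `TransferReadDropUnitDimsPattern` / `TransferWriteDropUnitDimsPattern` rewrites where the mask does not need to be trimmed, i.e. it already has the same rank as the rank-reduced vector type (only the memref has unit dims to drop), so the original mask can be reused as-is.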
---
Full diff: https://github.com/llvm/llvm-project/pull/187076.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+12-6)
- (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+40)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 19db8b3b48a25..6bccb1999e28c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -532,12 +532,15 @@ class TransferReadDropUnitDimsPattern
}
FailureOr<Value> rankReducedCreateMask =
createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
- if (failed(rankReducedCreateMask)) {
+ if (succeeded(rankReducedCreateMask)) {
+ maskOp = *rankReducedCreateMask;
+ LDBG() << " -> Successfully reduced mask dimensions";
+ } else if (createMaskOp.getVectorType().getRank() !=
+ reducedVectorType.getRank()) {
+ // Mask needs reduction but couldn't be reduced.
LDBG() << " -> Failed to reduce mask dimensions";
return failure();
}
- maskOp = *rankReducedCreateMask;
- LDBG() << " -> Successfully reduced mask dimensions";
}
LDBG() << " -> Creating rank-reduced subview and new transfer_read";
@@ -647,12 +650,15 @@ class TransferWriteDropUnitDimsPattern
}
FailureOr<Value> rankReducedCreateMask =
createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
- if (failed(rankReducedCreateMask)) {
+ if (succeeded(rankReducedCreateMask)) {
+ maskOp = *rankReducedCreateMask;
+ LDBG() << " -> Successfully reduced mask dimensions";
+ } else if (createMaskOp.getVectorType().getRank() !=
+ reducedVectorType.getRank()) {
+ // Mask needs reduction but couldn't be reduced.
LDBG() << " -> Failed to reduce mask dimensions";
return failure();
}
- maskOp = *rankReducedCreateMask;
- LDBG() << " -> Successfully reduced mask dimensions";
}
LDBG() << " -> Creating rank-reduced subview and new transfer_write";
Value reducedShapeSource =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 8234351302f6b..417408aeaf336 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -317,6 +317,46 @@ func.func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim(
// CHECK-NOT: memref.subview
// CHECK: vector.transfer_read {{.*}} vector<[16]x[1]xi8>
+/// Memref has unit dims but vector has no unit dims (all scalable). The mask
+/// does not need reduction — only the memref rank should be reduced.
+func.func @masked_transfer_read_memref_only_unit_dims(
+ %arg : memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+ %mask_dim0 : index, %mask_dim1 : index) -> vector<[4]x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[4]x[4]xi1>
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>, vector<[4]x[4]xf32>
+ return %v : vector<[4]x[4]xf32>
+}
+// CHECK-LABEL: func @masked_transfer_read_memref_only_unit_dims
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x?x?xf32
+// CHECK: %[[MASK:.+]] = vector.create_mask {{.*}} : vector<[4]x[4]xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0]
+// CHECK-SAME: memref<1x1x?x?xf32, {{.*}}> to memref<?x?xf32, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}}, %[[MASK]]
+// CHECK-SAME: memref<?x?xf32, {{.*}}>, vector<[4]x[4]xf32>
+
+func.func @masked_transfer_write_memref_only_unit_dims(
+ %arg : memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+ %vec : vector<[4]x[4]xf32>,
+ %mask_dim0 : index, %mask_dim1 : index) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[4]x[4]xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true]} :
+ vector<[4]x[4]xf32>, memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_memref_only_unit_dims
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x?x?xf32
+// CHECK-SAME: %[[VEC:.+]]: vector<[4]x[4]xf32>
+// CHECK: %[[MASK:.+]] = vector.create_mask {{.*}} : vector<[4]x[4]xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0]
+// CHECK-SAME: memref<1x1x?x?xf32, {{.*}}> to memref<?x?xf32, {{.*}}>
+// CHECK: vector.transfer_write %[[VEC]], %[[SUBVIEW]]{{.*}}, %[[MASK]]
+// CHECK-SAME: vector<[4]x[4]xf32>, memref<?x?xf32, {{.*}}>
+
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
``````````
</details>
https://github.com/llvm/llvm-project/pull/187076
More information about the Mlir-commits
mailing list