[Mlir-commits] [mlir] feat(vector): drop unit dims from memrefs for xfer_read/write for non… (PR #187076)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 18 05:34:10 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ege Beysel (egebeysel)

<details>
<summary>Changes</summary>

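Handles the case where the mask does not need to be trimmed, i.e. the `vector.create_mask` result already has the same rank as the reduced vector type, in the `TransferReadDropUnitDimsPattern` and `TransferWriteDropUnitDimsPattern` patterns. In that case only the memref's unit dims are dropped and the original mask is reused as-is.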

---
Full diff: https://github.com/llvm/llvm-project/pull/187076.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+12-6) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+40) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 19db8b3b48a25..6bccb1999e28c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -532,12 +532,15 @@ class TransferReadDropUnitDimsPattern
       }
       FailureOr<Value> rankReducedCreateMask =
           createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
-      if (failed(rankReducedCreateMask)) {
+      if (succeeded(rankReducedCreateMask)) {
+        maskOp = *rankReducedCreateMask;
+        LDBG() << "  -> Successfully reduced mask dimensions";
+      } else if (createMaskOp.getVectorType().getRank() !=
+                 reducedVectorType.getRank()) {
+        // Mask needs reduction but couldn't be reduced.
         LDBG() << "  -> Failed to reduce mask dimensions";
         return failure();
       }
-      maskOp = *rankReducedCreateMask;
-      LDBG() << "  -> Successfully reduced mask dimensions";
     }
 
     LDBG() << "  -> Creating rank-reduced subview and new transfer_read";
@@ -647,12 +650,15 @@ class TransferWriteDropUnitDimsPattern
       }
       FailureOr<Value> rankReducedCreateMask =
           createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
-      if (failed(rankReducedCreateMask)) {
+      if (succeeded(rankReducedCreateMask)) {
+        maskOp = *rankReducedCreateMask;
+        LDBG() << "  -> Successfully reduced mask dimensions";
+      } else if (createMaskOp.getVectorType().getRank() !=
+                 reducedVectorType.getRank()) {
+        // Mask needs reduction but couldn't be reduced.
         LDBG() << "  -> Failed to reduce mask dimensions";
         return failure();
       }
-      maskOp = *rankReducedCreateMask;
-      LDBG() << "  -> Successfully reduced mask dimensions";
     }
     LDBG() << "  -> Creating rank-reduced subview and new transfer_write";
     Value reducedShapeSource =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 8234351302f6b..417408aeaf336 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -317,6 +317,46 @@ func.func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim(
 //   CHECK-NOT: memref.subview
 //       CHECK: vector.transfer_read {{.*}} vector<[16]x[1]xi8>
 
+/// Memref has unit dims but vector has no unit dims (all scalable). The mask
+/// does not need reduction — only the memref rank should be reduced.
+func.func @masked_transfer_read_memref_only_unit_dims(
+      %arg : memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+      %mask_dim0 : index, %mask_dim1 : index) -> vector<[4]x[4]xf32> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0.0 : f32
+    %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[4]x[4]xi1>
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>, vector<[4]x[4]xf32>
+    return %v : vector<[4]x[4]xf32>
+}
+// CHECK-LABEL: func @masked_transfer_read_memref_only_unit_dims
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x?x?xf32
+//       CHECK:   %[[MASK:.+]] = vector.create_mask {{.*}} : vector<[4]x[4]xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0]
+//  CHECK-SAME:     memref<1x1x?x?xf32, {{.*}}> to memref<?x?xf32, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}}, %[[MASK]]
+//  CHECK-SAME:     memref<?x?xf32, {{.*}}>, vector<[4]x[4]xf32>
+
+func.func @masked_transfer_write_memref_only_unit_dims(
+      %arg : memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+      %vec : vector<[4]x[4]xf32>,
+      %mask_dim0 : index, %mask_dim1 : index) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[4]x[4]xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true]} :
+      vector<[4]x[4]xf32>, memref<1x1x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_memref_only_unit_dims
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x?x?xf32
+//  CHECK-SAME:     %[[VEC:.+]]: vector<[4]x[4]xf32>
+//       CHECK:   %[[MASK:.+]] = vector.create_mask {{.*}} : vector<[4]x[4]xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0]
+//  CHECK-SAME:     memref<1x1x?x?xf32, {{.*}}> to memref<?x?xf32, {{.*}}>
+//       CHECK:   vector.transfer_write %[[VEC]], %[[SUBVIEW]]{{.*}}, %[[MASK]]
+//  CHECK-SAME:     vector<[4]x[4]xf32>, memref<?x?xf32, {{.*}}>
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

``````````

</details>


https://github.com/llvm/llvm-project/pull/187076


More information about the Mlir-commits mailing list