[Mlir-commits] [mlir] [mlir][vector] Add support for dropping inner unit dims for transfer_read/write with masks. (PR #188841)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 26 13:57:41 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

The revision resolves a long-overdue TODO: the patterns that drop inner-most unit dims now also handle transfer_read/write ops that have a mask, by inserting a vector.shape_cast op that drops the same inner unit dims from the mask.

---
Full diff: https://github.com/llvm/llvm-project/pull/188841.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+26-13) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+33) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c694f4f58faa1..a2d59010a2901 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1563,10 +1563,6 @@ class DropInnerMostUnitDimsTransferRead
     if (readOp.getTransferRank() == 0)
       return failure();
 
-    // TODO: support mask.
-    if (readOp.getMask())
-      return failure();
-
     auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
     if (!srcType)
       return failure();
@@ -1614,12 +1610,22 @@ class DropInnerMostUnitDimsTransferRead
                                   readOp.getBase(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+    // If there is a mask, shape_cast it to drop the same inner unit dims.
+    Value mask = readOp.getMask();
+    if (mask) {
+      auto maskType = cast<VectorType>(mask.getType());
+      auto reducedMaskType = VectorType::get(
+          maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+          maskType.getScalableDims().drop_back(dimsToDrop));
+      mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+                                                        mask);
+    }
+
     Value result = vector::TransferReadOp::create(
         rewriter, loc, resultTargetVecType, rankedReducedView,
         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
-        readOp.getPadding(),
-        // TODO: support mask.
-        /*mask=*/Value(), inBoundsAttr);
+        readOp.getPadding(), mask, inBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                      result);
     return success();
@@ -1654,10 +1660,6 @@ class DropInnerMostUnitDimsTransferWrite
     if (writeOp.getTransferRank() == 0)
       return failure();
 
-    // TODO: support mask.
-    if (writeOp.getMask())
-      return failure();
-
     auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
     if (!srcType)
       return failure();
@@ -1709,11 +1711,22 @@ class DropInnerMostUnitDimsTransferWrite
 
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, resultTargetVecType, writeOp.getVector());
+
+    // If there is a mask, shape_cast it to drop the same inner unit dims.
+    Value mask = writeOp.getMask();
+    if (mask) {
+      auto maskType = cast<VectorType>(mask.getType());
+      auto reducedMaskType = VectorType::get(
+          maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
+          maskType.getScalableDims().drop_back(dimsToDrop));
+      mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
+                                                        mask);
+    }
+
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         writeOp, shapeCast, rankedReducedView,
         writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
-        // TODO: support mask.
-        /*mask=*/Value(), inBoundsAttr);
+        mask, inBoundsAttr);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 18c28799a62e5..ad1aef9a7cbb2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -266,6 +266,24 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable_inner_dim(%src: me
 
 // -----
 
+func.func @contiguous_inner_most_masked_read(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
+  return %v : vector<1x8x1xf32>
+}
+
+//      CHECK: func @contiguous_inner_most_masked_read(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>)
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+//      CHECK:   %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME:    memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
+//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // NOTE: This is an out-of-bounds access.
 
 func.func @negative_non_unit_inner_vec_dim(%src: memref<4x1xf32>) -> vector<4x8xf32> {
@@ -580,6 +598,21 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000
 
 // -----
 
+func.func @contiguous_inner_most_masked_write(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true, true]} : vector<1x8x1xf32>, memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
+  return
+}
+
+//      CHECK: func @contiguous_inner_most_masked_write(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>)
+//      CHECK:   %[[DEST_0:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+//      CHECK:   %[[REDUCED_VEC:.+]] = vector.shape_cast %[[VEC]] : vector<1x8x1xf32> to vector<1x8xf32>
+//      CHECK:   %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1>
+//      CHECK:   vector.transfer_write %[[REDUCED_VEC]], %[[DEST_0]]{{.*}}, %[[REDUCED_MASK]]
+// CHECK-SAME:    vector<1x8xf32>, memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+// -----
+
 // NOTE: This is an out-of-bounds access.
 
 func.func @negative_non_unit_inner_vec_dim(%dest: memref<4x1xf32>, %vec: vector<4x8xf32>) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/188841


More information about the Mlir-commits mailing list