[Mlir-commits] [mlir] [mlir][vector] Drop trailing 1-dims from constant_mask (PR #187383)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Apr 1 05:52:57 PDT 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/187383

>From 4300e61d598fb44d4dc1055f1d1efa5400d1a0b4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Mar 2026 17:22:10 -0400
Subject: [PATCH 1/7] [mlir][vector] Drop trailing 1-dims from constant_mask

---
 .../Transforms/VectorTransferOpTransforms.cpp | 93 +++++++++++--------
 ...ctor-transfer-drop-unit-dims-patterns.mlir | 50 +++++++++-
 2 files changed, 102 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 19db8b3b48a25..444d9d0aff251 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,30 +429,40 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
-static FailureOr<Value>
-createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
-                                  vector::CreateMaskOp op) {
+static auto getMaskDimData(vector::CreateMaskOp op) { return op.getOperands(); }
+static auto getMaskDimData(vector::ConstantMaskOp op) {
+  return op.getMaskDimSizes();
+}
+
+static bool isUnitDimMask(Value operand) {
+  auto cst = operand.getDefiningOp<arith::ConstantIndexOp>();
+  return cst && cst.value() == 1;
+}
+static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
+
+/// Rewrites a mask op to drop non-scalable unit dimensions.
+/// Supports vector.create_mask and vector.constant_mask.
+template <typename MaskOp>
+static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
+                                                    Location loc, MaskOp op) {
   auto type = op.getType();
   VectorType reducedType = trimNonScalableUnitDims(type);
   if (reducedType.getRank() == type.getRank())
     return failure();
 
-  SmallVector<Value> reducedOperands;
-  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
-           type.getShape(), type.getScalableDims(), op.getOperands())) {
+  auto maskDimData = getMaskDimData(op);
+  using ElemType = std::decay_t<decltype(*maskDimData.begin())>;
+  SmallVector<ElemType> reduced;
+  for (auto [dim, dimIsScalable, elem] :
+       llvm::zip_equal(type.getShape(), type.getScalableDims(), maskDimData)) {
     if (dim == 1 && !dimIsScalable) {
-      // If the mask for the unit dim is not a constant of 1, do nothing.
-      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
-      if (!constant || (constant.value() != 1))
+      if (!isUnitDimMask(elem))
         return failure();
       continue;
     }
-    reducedOperands.push_back(operand);
+    reduced.push_back(elem);
   }
-  return vector::CreateMaskOp::create(rewriter, loc, reducedType,
-                                      reducedOperands)
-      .getResult();
+  return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
 }
 
 namespace {
@@ -522,21 +532,23 @@ class TransferReadDropUnitDimsPattern
     Value maskOp = transferReadOp.getMask();
     if (maskOp) {
       LDBG() << "  -> Processing mask operation";
-      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp) {
-        LDBG()
-            << "  -> Unsupported mask op, only 'vector.create_mask' supported";
-        return rewriter.notifyMatchFailure(
-            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
-                            "currently supported");
-      }
-      FailureOr<Value> rankReducedCreateMask =
-          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
-      if (failed(rankReducedCreateMask)) {
+      FailureOr<Value> rankReducedMask = failure();
+      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      else if (auto constantMaskOp =
+                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+      if (failed(rankReducedMask)) {
         LDBG() << "  -> Failed to reduce mask dimensions";
-        return failure();
+        return rewriter.notifyMatchFailure(
+            transferReadOp,
+            "unsupported mask op, only 'vector.create_mask' and "
+            "'vector.constant_mask' are currently supported");
       }
-      maskOp = *rankReducedCreateMask;
+      maskOp = *rankReducedMask;
       LDBG() << "  -> Successfully reduced mask dimensions";
     }
 
@@ -636,22 +648,23 @@ class TransferWriteDropUnitDimsPattern
     Value maskOp = transferWriteOp.getMask();
     if (maskOp) {
       LDBG() << "  -> Processing mask operation";
-      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp) {
-        LDBG()
-            << "  -> Unsupported mask op, only 'vector.create_mask' supported";
+      FailureOr<Value> rankReducedMask = failure();
+      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      else if (auto constantMaskOp =
+                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+      if (failed(rankReducedMask)) {
+        LDBG() << "  -> Failed to reduce mask dimensions";
         return rewriter.notifyMatchFailure(
             transferWriteOp,
-            "unsupported mask op, only 'vector.create_mask' is "
-            "currently supported");
-      }
-      FailureOr<Value> rankReducedCreateMask =
-          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
-      if (failed(rankReducedCreateMask)) {
-        LDBG() << "  -> Failed to reduce mask dimensions";
-        return failure();
+            "unsupported mask op, only 'vector.create_mask' and "
+            "'vector.constant_mask' are currently supported");
       }
-      maskOp = *rankReducedCreateMask;
+      maskOp = *rankReducedMask;
       LDBG() << "  -> Successfully reduced mask dimensions";
     }
     LDBG() << "  -> Creating rank-reduced subview and new transfer_write";
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 8234351302f6b..e09192e73b6f4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -267,7 +267,55 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
 //       CHECK:   vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
 
-/// Only masks operands of vector.create_mask are currently supported.
+func.func @constant_masked_transfer_read_rank_reducing(
+      %arg : memref<3x1xi8>) -> vector<3x1xi8> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.constant_mask [2, 1] : vector<3x1xi1>
+    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<3x1xi8>, vector<3x1xi8>
+    return %v : vector<3x1xi8>
+}
+// CHECK-LABEL: func @constant_masked_transfer_read_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<3x1xi8>
+//       CHECK:   %[[MASK:.+]] = vector.constant_mask [2] : vector<3xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
+//  CHECK-SAME:     memref<3x1xi8> to memref<3xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : memref<3xi8, {{.*}}>, vector<3xi8>
+
+func.func @constant_masked_transfer_write_rank_reducing(
+      %arg : memref<1x1x3x1x16x1xf32>,
+      %vec : vector<1x3x1x16x1xf32>) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.constant_mask [1, 3, 1, 16, 1] : vector<1x3x1x16x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+      vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+    return
+}
+// CHECK-LABEL: func @constant_masked_transfer_write_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+//   CHECK-NOT:   vector.constant_mask
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @constant_masked_transfer_write_trailing_unit_dim(
+      %arg : memref<3x1xi32>,
+      %vec : vector<4x1xi32>) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.constant_mask [3, 1] : vector<4x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+      vector<4x1xi32>, memref<3x1xi32>
+    return
+}
+// CHECK-LABEL: func @constant_masked_transfer_write_trailing_unit_dim
+//  CHECK-SAME:     %[[ARG:.+]]: memref<3x1xi32>
+//       CHECK:   %[[MASK:.+]] = vector.constant_mask [3] : vector<4xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
+//  CHECK-SAME:     memref<3x1xi32> to memref<3xi32, {{.*}}>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : vector<4xi32>, memref<3xi32, {{.*}}>
+
+/// Only vector.create_mask and vector.constant_mask masks are supported.
 func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
       %mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {

>From c6619c84583d114890dcaf2d9421c84e7d55ee3a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:19:02 -0400
Subject: [PATCH 2/7] Unify naming for CreateMaskOp with ConstantMaskOp

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td              | 4 ++--
 .../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp  | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..d5040a0487afc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2608,7 +2608,7 @@ def Vector_CreateMaskOp :
   Vector_Op<"create_mask", [Pure, 
    DeclareOpInterfaceMethods<VectorUnrollOpInterface>
    ]>,
-    Arguments<(ins Variadic<Index>:$operands)>,
+    Arguments<(ins Variadic<Index>:$mask_dim_sizes)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a vector mask";
   let description = [{
@@ -2654,7 +2654,7 @@ def Vector_CreateMaskOp :
 
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$operands attr-dict `:` type(results)";
+  let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
 }
 
 def Vector_MaskOp : Vector_Op<"mask", [
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 444d9d0aff251..4e51590853960 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,7 +429,7 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-static auto getMaskDimData(vector::CreateMaskOp op) { return op.getOperands(); }
+static auto getMaskDimData(vector::CreateMaskOp op) { return op.getMaskDimSizes(); }
 static auto getMaskDimData(vector::ConstantMaskOp op) {
   return op.getMaskDimSizes();
 }

>From 4ba2b220f22ce54ad8376fe70842f3361f2ef08e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:31:54 -0400
Subject: [PATCH 3/7] Get rid of auto

---
 .../Vector/Transforms/VectorTransferOpTransforms.cpp | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 4e51590853960..a8dc8f692147a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,11 +429,6 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-static auto getMaskDimData(vector::CreateMaskOp op) { return op.getMaskDimSizes(); }
-static auto getMaskDimData(vector::ConstantMaskOp op) {
-  return op.getMaskDimSizes();
-}
-
 static bool isUnitDimMask(Value operand) {
   auto cst = operand.getDefiningOp<arith::ConstantIndexOp>();
   return cst && cst.value() == 1;
@@ -450,11 +445,10 @@ static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
   if (reducedType.getRank() == type.getRank())
     return failure();
 
-  auto maskDimData = getMaskDimData(op);
-  using ElemType = std::decay_t<decltype(*maskDimData.begin())>;
+  using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
   SmallVector<ElemType> reduced;
-  for (auto [dim, dimIsScalable, elem] :
-       llvm::zip_equal(type.getShape(), type.getScalableDims(), maskDimData)) {
+  for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
+           type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
     if (dim == 1 && !dimIsScalable) {
       if (!isUnitDimMask(elem))
         return failure();

>From f55184bde13357ba6df01ccc357edbe0d695b420 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:37:45 -0400
Subject: [PATCH 4/7] rename operand -> maskDimSize

---
 .../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp  | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index a8dc8f692147a..3b9ce505f93c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,8 +429,8 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-static bool isUnitDimMask(Value operand) {
-  auto cst = operand.getDefiningOp<arith::ConstantIndexOp>();
+static bool isUnitDimMask(Value maskDimSize) {
+  auto cst = maskDimSize.getDefiningOp<arith::ConstantIndexOp>();
   return cst && cst.value() == 1;
 }
 static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }

>From 45fb202e66921c3b141bd6e57026455cf7593c8d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:40:02 -0400
Subject: [PATCH 5/7] rankReducedMask -> rankReducedMaskOp

---
 .../Vector/Transforms/VectorTransferOpTransforms.cpp   | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3b9ce505f93c5..74962c3bb51fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -526,23 +526,23 @@ class TransferReadDropUnitDimsPattern
     Value maskOp = transferReadOp.getMask();
     if (maskOp) {
       LDBG() << "  -> Processing mask operation";
-      FailureOr<Value> rankReducedMask = failure();
+      FailureOr<Value> rankReducedMaskOp = failure();
       if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
-        rankReducedMask =
+        rankReducedMaskOp =
             maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
       else if (auto constantMaskOp =
                    maskOp.getDefiningOp<vector::ConstantMaskOp>())
-        rankReducedMask =
+        rankReducedMaskOp =
             maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
 
-      if (failed(rankReducedMask)) {
+      if (failed(rankReducedMaskOp)) {
         LDBG() << "  -> Failed to reduce mask dimensions";
         return rewriter.notifyMatchFailure(
             transferReadOp,
             "unsupported mask op, only 'vector.create_mask' and "
             "'vector.constant_mask' are currently supported");
       }
-      maskOp = *rankReducedMask;
+      maskOp = *rankReducedMaskOp;
       LDBG() << "  -> Successfully reduced mask dimensions";
     }
 

>From 0a55a856dec3d2e6916badea808f810009edbf09 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:49:35 -0400
Subject: [PATCH 6/7] Follow conventions in tests.

---
 ...ctor-transfer-drop-unit-dims-patterns.mlir | 121 ++++++++++--------
 1 file changed, 70 insertions(+), 51 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e09192e73b6f4..d30ba64c09159 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -176,7 +176,7 @@ func.func @transfer_read_dynamic_rank_reducing(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>
 
-func.func @masked_transfer_read_dynamic_rank_reducing_1(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_create_mask(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
       %mask_dim0 : index) -> vector<[16]x1xi8> {
     %c0 = arith.constant 0 : index
@@ -187,7 +187,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
       memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
     return %v : vector<[16]x1xi8>
 }
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
 //  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
 //       CHECK:   %[[C0:.+]] = arith.constant 0 : index
@@ -197,7 +197,22 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
 
-func.func @masked_transfer_read_dynamic_rank_reducing_2(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
+    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+    return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT:   vector.constant_mask
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing_2_create_mask(
       %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>,
       %mask_dim1 : index, %mask_dim4 : index) -> vector<1x[1]x3x1x[16]x1xi8> {
     %c0 = arith.constant 0 : index
@@ -209,7 +224,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
       memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
     return %v : vector<1x[1]x3x1x[16]x1xi8>
 }
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
 //  CHECK-SAME:     %[[MASK_DIM1:.+]]: index, %[[MASK_DIM4:.+]]: index
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
@@ -223,7 +238,28 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
 
-func.func @masked_transfer_write_and_vector_rank_reducing(
+func.func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask(
+      %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>) -> vector<1x[1]x3x1x[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.constant_mask [1, 1, 2, 1, 16, 1] : vector<1x[1]x3x1x[16]x1xi1>
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true, true, true, true]} :
+      memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
+    return %v : vector<1x[1]x3x1x[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[PAD:.+]] = arith.constant 0 : i8
+//       CHECK:   %[[MASK:.+]] = vector.constant_mask [1, 2, 16] : vector<[1]x3x[16]xi1>
+//       CHECK:   %[[DIM1:.+]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+//       CHECK:   %[[DIM4:.+]] = memref.dim %[[ARG]], %[[C4]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
+
+func.func @masked_transfer_write_and_vector_rank_reducing_create_mask(
       %arg : memref<1x1x3x1x16x1xf32>,
       %vec : vector<1x3x1x16x1xf32>,
       %mask_dim1 : index,
@@ -235,7 +271,7 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
       vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
     return
 }
-// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
 //  CHECK-SAME:     {{.*}}: vector<1x3x1x16x1xf32>,
 //  CHECK-SAME:     %[[MASKDIM1:.+]]: index,
@@ -245,7 +281,23 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
 
-func.func @masked_transfer_write_dynamic_rank_reducing(
+func.func @masked_transfer_write_and_vector_rank_reducing_constant_mask(
+      %arg : memref<1x1x3x1x16x1xf32>,
+      %vec : vector<1x3x1x16x1xf32>) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.constant_mask [1, 2, 1, 8, 1] : vector<1x3x1x16x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+      vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+//       CHECK:   %[[MASK:.+]] = vector.constant_mask [2, 8] : vector<3x16xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing_create_mask(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
       %vec : vector<[16]x1xi8>,
       %mask_dim0 : index) {
@@ -257,7 +309,7 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
       vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
     return
 }
-// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
 //  CHECK-SAME:     %{{.*}}: vector<[16]x1xi8>,
 //  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
@@ -267,53 +319,20 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
 //       CHECK:   vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
 
-func.func @constant_masked_transfer_read_rank_reducing(
-      %arg : memref<3x1xi8>) -> vector<3x1xi8> {
-    %c0 = arith.constant 0 : index
-    %pad = arith.constant 0 : i8
-    %mask = vector.constant_mask [2, 1] : vector<3x1xi1>
-    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
-      memref<3x1xi8>, vector<3x1xi8>
-    return %v : vector<3x1xi8>
-}
-// CHECK-LABEL: func @constant_masked_transfer_read_rank_reducing
-//  CHECK-SAME:     %[[ARG:.+]]: memref<3x1xi8>
-//       CHECK:   %[[MASK:.+]] = vector.constant_mask [2] : vector<3xi1>
-//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
-//  CHECK-SAME:     memref<3x1xi8> to memref<3xi8, {{.*}}>
-//       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : memref<3xi8, {{.*}}>, vector<3xi8>
-
-func.func @constant_masked_transfer_write_rank_reducing(
-      %arg : memref<1x1x3x1x16x1xf32>,
-      %vec : vector<1x3x1x16x1xf32>) {
-    %c0 = arith.constant 0 : index
-    %mask = vector.constant_mask [1, 3, 1, 16, 1] : vector<1x3x1x16x1xi1>
-    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
-      vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
-    return
-}
-// CHECK-LABEL: func @constant_masked_transfer_write_rank_reducing
-//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
-//   CHECK-NOT:   vector.constant_mask
-//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
-//  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
-//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
-
-func.func @constant_masked_transfer_write_trailing_unit_dim(
-      %arg : memref<3x1xi32>,
-      %vec : vector<4x1xi32>) {
+func.func @masked_transfer_write_dynamic_rank_reducing_constant_mask(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %vec : vector<[16]x1xi8>) {
     %c0 = arith.constant 0 : index
-    %mask = vector.constant_mask [3, 1] : vector<4x1xi1>
+    %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
     vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
-      vector<4x1xi32>, memref<3x1xi32>
+      vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
     return
 }
-// CHECK-LABEL: func @constant_masked_transfer_write_trailing_unit_dim
-//  CHECK-SAME:     %[[ARG:.+]]: memref<3x1xi32>
-//       CHECK:   %[[MASK:.+]] = vector.constant_mask [3] : vector<4xi1>
-//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
-//  CHECK-SAME:     memref<3x1xi32> to memref<3xi32, {{.*}}>
-//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : vector<4xi32>, memref<3xi32, {{.*}}>
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT:   vector.constant_mask
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_write {{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
 
 /// Only vector.create_mask and vector.constant_mask masks are supported.
 func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(

>From 504e5812da6f3d09d20b8f9948c7aed215651016 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 14:40:19 -0400
Subject: [PATCH 7/7] Use matchPattern instead of defining op

---
 .../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp  | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 74962c3bb51fd..f4f598676d151 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -430,8 +431,7 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
 }
 
 static bool isUnitDimMask(Value maskDimSize) {
-  auto cst = maskDimSize.getDefiningOp<arith::ConstantIndexOp>();
-  return cst && cst.value() == 1;
+  return matchPattern(maskDimSize, m_One());
 }
 static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
 



More information about the Mlir-commits mailing list