[Mlir-commits] [mlir] 5072c02 - [mlir][vector] Drop trailing 1-dims from constant_mask (#187383)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 1 06:05:31 PDT 2026


Author: Erick Ochoa Lopez
Date: 2026-04-01T09:05:25-04:00
New Revision: 5072c020aae1636f687af797d4082a26a761c721

URL: https://github.com/llvm/llvm-project/commit/5072c020aae1636f687af797d4082a26a761c721
DIFF: https://github.com/llvm/llvm-project/commit/5072c020aae1636f687af797d4082a26a761c721.diff

LOG: [mlir][vector] Drop trailing 1-dims from constant_mask (#187383)

Generalize TransferReadDropUnitDimsPattern and
TransferWriteDropUnitDimsPattern to also drop non-scalable unit
dimensions when the transfer mask is produced by `vector::ConstantMaskOp`.

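Previously, TransferReadDropUnitDimsPattern and
TransferWriteDropUnitDimsPattern would only drop unit dimensions when the
mask was produced by `vector::CreateMaskOp` and the mask operand of every
dropped unit dimension was the constant 1. To let a single helper handle
both mask ops through `getMaskDimSizes()`, the `vector.create_mask`
operand is also renamed from `operands` to `mask_dim_sizes`.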

Assisted-by: Cursor

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f8ece8b734c42..68ef49172e662 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2608,7 +2608,7 @@ def Vector_CreateMaskOp :
   Vector_Op<"create_mask", [Pure, 
    DeclareOpInterfaceMethods<VectorUnrollOpInterface>
    ]>,
-    Arguments<(ins Variadic<Index>:$operands)>,
+    Arguments<(ins Variadic<Index>:$mask_dim_sizes)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a vector mask";
   let description = [{
@@ -2654,7 +2654,7 @@ def Vector_CreateMaskOp :
 
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$operands attr-dict `:` type(results)";
+  let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
 }
 
 def Vector_MaskOp : Vector_Op<"mask", [

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index babd321e484bd..ac7b84abc4e06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -429,30 +430,33 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
-// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
-static FailureOr<Value>
-createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
-                                  vector::CreateMaskOp op) {
+static bool isUnitDimMask(Value maskDimSize) {
+  return matchPattern(maskDimSize, m_One());
+}
+static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
+
+/// Rewrites a mask op to drop non-scalable unit dimensions.
+/// Supports vector.create_mask and vector.constant_mask.
+template <typename MaskOp>
+static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
+                                                    Location loc, MaskOp op) {
   auto type = op.getType();
   VectorType reducedType = trimNonScalableUnitDims(type);
   if (reducedType.getRank() == type.getRank())
     return failure();
 
-  SmallVector<Value> reducedOperands;
-  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
-           type.getShape(), type.getScalableDims(), op.getOperands())) {
+  using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
+  SmallVector<ElemType> reduced;
+  for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
+           type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
     if (dim == 1 && !dimIsScalable) {
-      // If the mask for the unit dim is not a constant of 1, do nothing.
-      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
-      if (!constant || (constant.value() != 1))
+      if (!isUnitDimMask(elem))
         return failure();
       continue;
     }
-    reducedOperands.push_back(operand);
+    reduced.push_back(elem);
   }
-  return vector::CreateMaskOp::create(rewriter, loc, reducedType,
-                                      reducedOperands)
-      .getResult();
+  return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
 }
 
 namespace {
@@ -522,21 +526,23 @@ class TransferReadDropUnitDimsPattern
     Value maskOp = transferReadOp.getMask();
     if (maskOp) {
       LDBG() << "  -> Processing mask operation";
-      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp) {
-        LDBG()
-            << "  -> Unsupported mask op, only 'vector.create_mask' supported";
-        return rewriter.notifyMatchFailure(
-            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
-                            "currently supported");
-      }
-      FailureOr<Value> rankReducedCreateMask =
-          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
-      if (failed(rankReducedCreateMask)) {
+      FailureOr<Value> rankReducedMaskOp = failure();
+      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+        rankReducedMaskOp =
+            maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      else if (auto constantMaskOp =
+                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
+        rankReducedMaskOp =
+            maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+      if (failed(rankReducedMaskOp)) {
         LDBG() << "  -> Failed to reduce mask dimensions";
-        return failure();
+        return rewriter.notifyMatchFailure(
+            transferReadOp,
+            "unsupported mask op, only 'vector.create_mask' and "
+            "'vector.constant_mask' are currently supported");
       }
-      maskOp = *rankReducedCreateMask;
+      maskOp = *rankReducedMaskOp;
       LDBG() << "  -> Successfully reduced mask dimensions";
     }
 
@@ -636,22 +642,23 @@ class TransferWriteDropUnitDimsPattern
     Value maskOp = transferWriteOp.getMask();
     if (maskOp) {
       LDBG() << "  -> Processing mask operation";
-      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp) {
-        LDBG()
-            << "  -> Unsupported mask op, only 'vector.create_mask' supported";
+      FailureOr<Value> rankReducedMask = failure();
+      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      else if (auto constantMaskOp =
+                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
+        rankReducedMask =
+            maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+      if (failed(rankReducedMask)) {
+        LDBG() << "  -> Failed to reduce mask dimensions";
         return rewriter.notifyMatchFailure(
             transferWriteOp,
-            "unsupported mask op, only 'vector.create_mask' is "
-            "currently supported");
-      }
-      FailureOr<Value> rankReducedCreateMask =
-          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
-      if (failed(rankReducedCreateMask)) {
-        LDBG() << "  -> Failed to reduce mask dimensions";
-        return failure();
+            "unsupported mask op, only 'vector.create_mask' and "
+            "'vector.constant_mask' are currently supported");
       }
-      maskOp = *rankReducedCreateMask;
+      maskOp = *rankReducedMask;
       LDBG() << "  -> Successfully reduced mask dimensions";
     }
     LDBG() << "  -> Creating rank-reduced subview and new transfer_write";

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 8234351302f6b..d30ba64c09159 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -176,7 +176,7 @@ func.func @transfer_read_dynamic_rank_reducing(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>
 
-func.func @masked_transfer_read_dynamic_rank_reducing_1(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_create_mask(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
       %mask_dim0 : index) -> vector<[16]x1xi8> {
     %c0 = arith.constant 0 : index
@@ -187,7 +187,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
       memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
     return %v : vector<[16]x1xi8>
 }
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
 //  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
 //       CHECK:   %[[C0:.+]] = arith.constant 0 : index
@@ -197,7 +197,22 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
 
-func.func @masked_transfer_read_dynamic_rank_reducing_2(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
+    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+    return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT:   vector.constant_mask
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing_2_create_mask(
       %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>,
       %mask_dim1 : index, %mask_dim4 : index) -> vector<1x[1]x3x1x[16]x1xi8> {
     %c0 = arith.constant 0 : index
@@ -209,7 +224,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
       memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
     return %v : vector<1x[1]x3x1x[16]x1xi8>
 }
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
 //  CHECK-SAME:     %[[MASK_DIM1:.+]]: index, %[[MASK_DIM4:.+]]: index
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
@@ -223,7 +238,28 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
 
-func.func @masked_transfer_write_and_vector_rank_reducing(
+func.func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask(
+      %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>) -> vector<1x[1]x3x1x[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.constant_mask [1, 1, 2, 1, 16, 1] : vector<1x[1]x3x1x[16]x1xi1>
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true, true, true, true]} :
+      memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
+    return %v : vector<1x[1]x3x1x[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[PAD:.+]] = arith.constant 0 : i8
+//       CHECK:   %[[MASK:.+]] = vector.constant_mask [1, 2, 16] : vector<[1]x3x[16]xi1>
+//       CHECK:   %[[DIM1:.+]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+//       CHECK:   %[[DIM4:.+]] = memref.dim %[[ARG]], %[[C4]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
+
+func.func @masked_transfer_write_and_vector_rank_reducing_create_mask(
       %arg : memref<1x1x3x1x16x1xf32>,
       %vec : vector<1x3x1x16x1xf32>,
       %mask_dim1 : index,
@@ -235,7 +271,7 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
       vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
     return
 }
-// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
 //  CHECK-SAME:     {{.*}}: vector<1x3x1x16x1xf32>,
 //  CHECK-SAME:     %[[MASKDIM1:.+]]: index,
@@ -245,7 +281,23 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
 
-func.func @masked_transfer_write_dynamic_rank_reducing(
+func.func @masked_transfer_write_and_vector_rank_reducing_constant_mask(
+      %arg : memref<1x1x3x1x16x1xf32>,
+      %vec : vector<1x3x1x16x1xf32>) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.constant_mask [1, 2, 1, 8, 1] : vector<1x3x1x16x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+      vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+//       CHECK:   %[[MASK:.+]] = vector.constant_mask [2, 8] : vector<3x16xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing_create_mask(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
       %vec : vector<[16]x1xi8>,
       %mask_dim0 : index) {
@@ -257,7 +309,7 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
       vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
     return
 }
-// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_create_mask
 //  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
 //  CHECK-SAME:     %{{.*}}: vector<[16]x1xi8>,
 //  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
@@ -267,7 +319,22 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
 //       CHECK:   vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
 
-/// Only masks operands of vector.create_mask are currently supported.
+func.func @masked_transfer_write_dynamic_rank_reducing_constant_mask(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %vec : vector<[16]x1xi8>) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+      vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_constant_mask
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT:   vector.constant_mask
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_write {{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
+/// Only vector.create_mask and vector.constant_mask masks are supported.
 func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
       %mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {


        


More information about the Mlir-commits mailing list