[Mlir-commits] [mlir] [mlir][vector] Drop trailing 1-dims from constant_mask (PR #187383)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Wed Apr 1 05:52:57 PDT 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/187383
>From 4300e61d598fb44d4dc1055f1d1efa5400d1a0b4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Mar 2026 17:22:10 -0400
Subject: [PATCH 1/7] [mlir][vector] Drop trailing 1-dims from constant_mask
---
.../Transforms/VectorTransferOpTransforms.cpp | 93 +++++++++++--------
...ctor-transfer-drop-unit-dims-patterns.mlir | 50 +++++++++-
2 files changed, 102 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 19db8b3b48a25..444d9d0aff251 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,30 +429,40 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
-// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
-static FailureOr<Value>
-createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
- vector::CreateMaskOp op) {
+static auto getMaskDimData(vector::CreateMaskOp op) { return op.getOperands(); }
+static auto getMaskDimData(vector::ConstantMaskOp op) {
+ return op.getMaskDimSizes();
+}
+
+static bool isUnitDimMask(Value operand) {
+ auto cst = operand.getDefiningOp<arith::ConstantIndexOp>();
+ return cst && cst.value() == 1;
+}
+static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
+
+/// Rewrites a mask op to drop non-scalable unit dimensions.
+/// Supports vector.create_mask and vector.constant_mask.
+template <typename MaskOp>
+static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
+ Location loc, MaskOp op) {
auto type = op.getType();
VectorType reducedType = trimNonScalableUnitDims(type);
if (reducedType.getRank() == type.getRank())
return failure();
- SmallVector<Value> reducedOperands;
- for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
- type.getShape(), type.getScalableDims(), op.getOperands())) {
+ auto maskDimData = getMaskDimData(op);
+ using ElemType = std::decay_t<decltype(*maskDimData.begin())>;
+ SmallVector<ElemType> reduced;
+ for (auto [dim, dimIsScalable, elem] :
+ llvm::zip_equal(type.getShape(), type.getScalableDims(), maskDimData)) {
if (dim == 1 && !dimIsScalable) {
- // If the mask for the unit dim is not a constant of 1, do nothing.
- auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
- if (!constant || (constant.value() != 1))
+ if (!isUnitDimMask(elem))
return failure();
continue;
}
- reducedOperands.push_back(operand);
+ reduced.push_back(elem);
}
- return vector::CreateMaskOp::create(rewriter, loc, reducedType,
- reducedOperands)
- .getResult();
+ return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
}
namespace {
@@ -522,21 +532,23 @@ class TransferReadDropUnitDimsPattern
Value maskOp = transferReadOp.getMask();
if (maskOp) {
LDBG() << " -> Processing mask operation";
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp) {
- LDBG()
- << " -> Unsupported mask op, only 'vector.create_mask' supported";
- return rewriter.notifyMatchFailure(
- transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
- }
- FailureOr<Value> rankReducedCreateMask =
- createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
- if (failed(rankReducedCreateMask)) {
+ FailureOr<Value> rankReducedMask = failure();
+ if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+ else if (auto constantMaskOp =
+ maskOp.getDefiningOp<vector::ConstantMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+ if (failed(rankReducedMask)) {
LDBG() << " -> Failed to reduce mask dimensions";
- return failure();
+ return rewriter.notifyMatchFailure(
+ transferReadOp,
+ "unsupported mask op, only 'vector.create_mask' and "
+ "'vector.constant_mask' are currently supported");
}
- maskOp = *rankReducedCreateMask;
+ maskOp = *rankReducedMask;
LDBG() << " -> Successfully reduced mask dimensions";
}
@@ -636,22 +648,23 @@ class TransferWriteDropUnitDimsPattern
Value maskOp = transferWriteOp.getMask();
if (maskOp) {
LDBG() << " -> Processing mask operation";
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp) {
- LDBG()
- << " -> Unsupported mask op, only 'vector.create_mask' supported";
+ FailureOr<Value> rankReducedMask = failure();
+ if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+ else if (auto constantMaskOp =
+ maskOp.getDefiningOp<vector::ConstantMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+ if (failed(rankReducedMask)) {
+ LDBG() << " -> Failed to reduce mask dimensions";
return rewriter.notifyMatchFailure(
transferWriteOp,
- "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
- }
- FailureOr<Value> rankReducedCreateMask =
- createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
- if (failed(rankReducedCreateMask)) {
- LDBG() << " -> Failed to reduce mask dimensions";
- return failure();
+ "unsupported mask op, only 'vector.create_mask' and "
+ "'vector.constant_mask' are currently supported");
}
- maskOp = *rankReducedCreateMask;
+ maskOp = *rankReducedMask;
LDBG() << " -> Successfully reduced mask dimensions";
}
LDBG() << " -> Creating rank-reduced subview and new transfer_write";
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 8234351302f6b..e09192e73b6f4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -267,7 +267,55 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
-/// Only masks operands of vector.create_mask are currently supported.
+func.func @constant_masked_transfer_read_rank_reducing(
+ %arg : memref<3x1xi8>) -> vector<3x1xi8> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.constant_mask [2, 1] : vector<3x1xi1>
+ %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<3x1xi8>, vector<3x1xi8>
+ return %v : vector<3x1xi8>
+}
+// CHECK-LABEL: func @constant_masked_transfer_read_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<3x1xi8>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [2] : vector<3xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
+// CHECK-SAME: memref<3x1xi8> to memref<3xi8, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : memref<3xi8, {{.*}}>, vector<3xi8>
+
+func.func @constant_masked_transfer_write_rank_reducing(
+ %arg : memref<1x1x3x1x16x1xf32>,
+ %vec : vector<1x3x1x16x1xf32>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [1, 3, 1, 16, 1] : vector<1x3x1x16x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+ vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+ return
+}
+// CHECK-LABEL: func @constant_masked_transfer_write_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @constant_masked_transfer_write_trailing_unit_dim(
+ %arg : memref<3x1xi32>,
+ %vec : vector<4x1xi32>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [3, 1] : vector<4x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+ vector<4x1xi32>, memref<3x1xi32>
+ return
+}
+// CHECK-LABEL: func @constant_masked_transfer_write_trailing_unit_dim
+// CHECK-SAME: %[[ARG:.+]]: memref<3x1xi32>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<4xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
+// CHECK-SAME: memref<3x1xi32> to memref<3xi32, {{.*}}>
+// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : vector<4xi32>, memref<3xi32, {{.*}}>
+
+/// Only vector.create_mask and vector.constant_mask masks are supported.
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
%mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {
>From c6619c84583d114890dcaf2d9421c84e7d55ee3a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:19:02 -0400
Subject: [PATCH 2/7] Unify naming for CreateMaskOp with ConstantMaskOp
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 4 ++--
.../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..d5040a0487afc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2608,7 +2608,7 @@ def Vector_CreateMaskOp :
Vector_Op<"create_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
- Arguments<(ins Variadic<Index>:$operands)>,
+ Arguments<(ins Variadic<Index>:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a vector mask";
let description = [{
@@ -2654,7 +2654,7 @@ def Vector_CreateMaskOp :
let hasCanonicalizer = 1;
let hasVerifier = 1;
- let assemblyFormat = "$operands attr-dict `:` type(results)";
+ let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
}
def Vector_MaskOp : Vector_Op<"mask", [
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 444d9d0aff251..4e51590853960 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,7 +429,7 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
-static auto getMaskDimData(vector::CreateMaskOp op) { return op.getOperands(); }
+static auto getMaskDimData(vector::CreateMaskOp op) { return op.getMaskDimSizes(); }
static auto getMaskDimData(vector::ConstantMaskOp op) {
return op.getMaskDimSizes();
}
>From 4ba2b220f22ce54ad8376fe70842f3361f2ef08e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:31:54 -0400
Subject: [PATCH 3/7] Get rid of auto
---
.../Vector/Transforms/VectorTransferOpTransforms.cpp | 12 +++---------
1 file changed, 3 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 4e51590853960..a8dc8f692147a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,11 +429,6 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
-static auto getMaskDimData(vector::CreateMaskOp op) { return op.getMaskDimSizes(); }
-static auto getMaskDimData(vector::ConstantMaskOp op) {
- return op.getMaskDimSizes();
-}
-
static bool isUnitDimMask(Value operand) {
auto cst = operand.getDefiningOp<arith::ConstantIndexOp>();
return cst && cst.value() == 1;
@@ -450,11 +445,10 @@ static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
if (reducedType.getRank() == type.getRank())
return failure();
- auto maskDimData = getMaskDimData(op);
- using ElemType = std::decay_t<decltype(*maskDimData.begin())>;
+ using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
SmallVector<ElemType> reduced;
- for (auto [dim, dimIsScalable, elem] :
- llvm::zip_equal(type.getShape(), type.getScalableDims(), maskDimData)) {
+ for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
+ type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
if (dim == 1 && !dimIsScalable) {
if (!isUnitDimMask(elem))
return failure();
>From f55184bde13357ba6df01ccc357edbe0d695b420 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:37:45 -0400
Subject: [PATCH 4/7] rename operand -> maskDimSize
---
.../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index a8dc8f692147a..3b9ce505f93c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -429,8 +429,8 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
-static bool isUnitDimMask(Value operand) {
- auto cst = operand.getDefiningOp<arith::ConstantIndexOp>();
+static bool isUnitDimMask(Value maskDimSize) {
+ auto cst = maskDimSize.getDefiningOp<arith::ConstantIndexOp>();
return cst && cst.value() == 1;
}
static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
>From 45fb202e66921c3b141bd6e57026455cf7593c8d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:40:02 -0400
Subject: [PATCH 5/7] rankReducedMask -> rankReducedMaskOp
---
.../Vector/Transforms/VectorTransferOpTransforms.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3b9ce505f93c5..74962c3bb51fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -526,23 +526,23 @@ class TransferReadDropUnitDimsPattern
Value maskOp = transferReadOp.getMask();
if (maskOp) {
LDBG() << " -> Processing mask operation";
- FailureOr<Value> rankReducedMask = failure();
+ FailureOr<Value> rankReducedMaskOp = failure();
if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
- rankReducedMask =
+ rankReducedMaskOp =
maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
else if (auto constantMaskOp =
maskOp.getDefiningOp<vector::ConstantMaskOp>())
- rankReducedMask =
+ rankReducedMaskOp =
maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
- if (failed(rankReducedMask)) {
+ if (failed(rankReducedMaskOp)) {
LDBG() << " -> Failed to reduce mask dimensions";
return rewriter.notifyMatchFailure(
transferReadOp,
"unsupported mask op, only 'vector.create_mask' and "
"'vector.constant_mask' are currently supported");
}
- maskOp = *rankReducedMask;
+ maskOp = *rankReducedMaskOp;
LDBG() << " -> Successfully reduced mask dimensions";
}
>From 0a55a856dec3d2e6916badea808f810009edbf09 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 13:49:35 -0400
Subject: [PATCH 6/7] Follow conventions in tests.
---
...ctor-transfer-drop-unit-dims-patterns.mlir | 121 ++++++++++--------
1 file changed, 70 insertions(+), 51 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e09192e73b6f4..d30ba64c09159 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -176,7 +176,7 @@ func.func @transfer_read_dynamic_rank_reducing(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>
-func.func @masked_transfer_read_dynamic_rank_reducing_1(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_create_mask(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
%mask_dim0 : index) -> vector<[16]x1xi8> {
%c0 = arith.constant 0 : index
@@ -187,7 +187,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
return %v : vector<[16]x1xi8>
}
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
// CHECK-SAME: %[[MASK_DIM0:.+]]: index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -197,7 +197,22 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
-func.func @masked_transfer_read_dynamic_rank_reducing_2(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
+ %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+ return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing_2_create_mask(
%arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>,
%mask_dim1 : index, %mask_dim4 : index) -> vector<1x[1]x3x1x[16]x1xi8> {
%c0 = arith.constant 0 : index
@@ -209,7 +224,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
return %v : vector<1x[1]x3x1x[16]x1xi8>
}
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
// CHECK-SAME: %[[MASK_DIM1:.+]]: index, %[[MASK_DIM4:.+]]: index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -223,7 +238,28 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
-func.func @masked_transfer_write_and_vector_rank_reducing(
+func.func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask(
+ %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>) -> vector<1x[1]x3x1x[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.constant_mask [1, 1, 2, 1, 16, 1] : vector<1x[1]x3x1x[16]x1xi1>
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true, true, true, true]} :
+ memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
+ return %v : vector<1x[1]x3x1x[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 0 : i8
+// CHECK: %[[MASK:.+]] = vector.constant_mask [1, 2, 16] : vector<[1]x3x[16]xi1>
+// CHECK: %[[DIM1:.+]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+// CHECK: %[[DIM4:.+]] = memref.dim %[[ARG]], %[[C4]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
+
+func.func @masked_transfer_write_and_vector_rank_reducing_create_mask(
%arg : memref<1x1x3x1x16x1xf32>,
%vec : vector<1x3x1x16x1xf32>,
%mask_dim1 : index,
@@ -235,7 +271,7 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
return
}
-// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
// CHECK-SAME: {{.*}}: vector<1x3x1x16x1xf32>,
// CHECK-SAME: %[[MASKDIM1:.+]]: index,
@@ -245,7 +281,23 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
-func.func @masked_transfer_write_dynamic_rank_reducing(
+func.func @masked_transfer_write_and_vector_rank_reducing_constant_mask(
+ %arg : memref<1x1x3x1x16x1xf32>,
+ %vec : vector<1x3x1x16x1xf32>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [1, 2, 1, 8, 1] : vector<1x3x1x16x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+ vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [2, 8] : vector<3x16xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing_create_mask(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
%vec : vector<[16]x1xi8>,
%mask_dim0 : index) {
@@ -257,7 +309,7 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
return
}
-// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
// CHECK-SAME: %{{.*}}: vector<[16]x1xi8>,
// CHECK-SAME: %[[MASK_DIM0:.+]]: index
@@ -267,53 +319,20 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
-func.func @constant_masked_transfer_read_rank_reducing(
- %arg : memref<3x1xi8>) -> vector<3x1xi8> {
- %c0 = arith.constant 0 : index
- %pad = arith.constant 0 : i8
- %mask = vector.constant_mask [2, 1] : vector<3x1xi1>
- %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
- memref<3x1xi8>, vector<3x1xi8>
- return %v : vector<3x1xi8>
-}
-// CHECK-LABEL: func @constant_masked_transfer_read_rank_reducing
-// CHECK-SAME: %[[ARG:.+]]: memref<3x1xi8>
-// CHECK: %[[MASK:.+]] = vector.constant_mask [2] : vector<3xi1>
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
-// CHECK-SAME: memref<3x1xi8> to memref<3xi8, {{.*}}>
-// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : memref<3xi8, {{.*}}>, vector<3xi8>
-
-func.func @constant_masked_transfer_write_rank_reducing(
- %arg : memref<1x1x3x1x16x1xf32>,
- %vec : vector<1x3x1x16x1xf32>) {
- %c0 = arith.constant 0 : index
- %mask = vector.constant_mask [1, 3, 1, 16, 1] : vector<1x3x1x16x1xi1>
- vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
- vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
- return
-}
-// CHECK-LABEL: func @constant_masked_transfer_write_rank_reducing
-// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
-// CHECK-NOT: vector.constant_mask
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
-// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
-// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
-
-func.func @constant_masked_transfer_write_trailing_unit_dim(
- %arg : memref<3x1xi32>,
- %vec : vector<4x1xi32>) {
+func.func @masked_transfer_write_dynamic_rank_reducing_constant_mask(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %vec : vector<[16]x1xi8>) {
%c0 = arith.constant 0 : index
- %mask = vector.constant_mask [3, 1] : vector<4x1xi1>
+ %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
- vector<4x1xi32>, memref<3x1xi32>
+ vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
return
}
-// CHECK-LABEL: func @constant_masked_transfer_write_trailing_unit_dim
-// CHECK-SAME: %[[ARG:.+]]: memref<3x1xi32>
-// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<4xi1>
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [3, 1] [1, 1]
-// CHECK-SAME: memref<3x1xi32> to memref<3xi32, {{.*}}>
-// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true]} : vector<4xi32>, memref<3xi32, {{.*}}>
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
/// Only vector.create_mask and vector.constant_mask masks are supported.
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
>From 504e5812da6f3d09d20b8f9948c7aed215651016 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 20 Mar 2026 14:40:19 -0400
Subject: [PATCH 7/7] Use matchPattern instead of defining op
---
.../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 74962c3bb51fd..f4f598676d151 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -430,8 +431,7 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
}
static bool isUnitDimMask(Value maskDimSize) {
- auto cst = maskDimSize.getDefiningOp<arith::ConstantIndexOp>();
- return cst && cst.value() == 1;
+ return matchPattern(maskDimSize, m_One());
}
static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
More information about the Mlir-commits
mailing list