[Mlir-commits] [mlir] 5072c02 - [mlir][vector] Drop trailing 1-dims from constant_mask (#187383)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 1 06:05:31 PDT 2026
Author: Erick Ochoa Lopez
Date: 2026-04-01T09:05:25-04:00
New Revision: 5072c020aae1636f687af797d4082a26a761c721
URL: https://github.com/llvm/llvm-project/commit/5072c020aae1636f687af797d4082a26a761c721
DIFF: https://github.com/llvm/llvm-project/commit/5072c020aae1636f687af797d4082a26a761c721.diff
LOG: [mlir][vector] Drop trailing 1-dims from constant_mask (#187383)
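Generalize TransferReadDropUnitDimsPattern and TransferWriteDropUnitDimsPattern
to also drop non-scalable unit dimensions when the mask is produced by
`vector::ConstantMaskOp`.
Previously, these patterns would only drop unit dimensions when the mask was
produced by `vector::CreateMaskOp` whose operands for the dropped unit
dimensions are the constant 1.
To let both mask ops share the same rewrite helper, the `operands` argument of
`vector.create_mask` is renamed to `mask_dim_sizes`, so both ops expose
`getMaskDimSizes()`.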
Assisted-by: Cursor
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f8ece8b734c42..68ef49172e662 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2608,7 +2608,7 @@ def Vector_CreateMaskOp :
Vector_Op<"create_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
- Arguments<(ins Variadic<Index>:$operands)>,
+ Arguments<(ins Variadic<Index>:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a vector mask";
let description = [{
@@ -2654,7 +2654,7 @@ def Vector_CreateMaskOp :
let hasCanonicalizer = 1;
let hasVerifier = 1;
- let assemblyFormat = "$operands attr-dict `:` type(results)";
+ let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
}
def Vector_MaskOp : Vector_Op<"mask", [
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index babd321e484bd..ac7b84abc4e06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -429,30 +430,33 @@ static VectorType trimNonScalableUnitDims(VectorType oldType) {
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
-// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
-static FailureOr<Value>
-createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
- vector::CreateMaskOp op) {
+static bool isUnitDimMask(Value maskDimSize) {
+ return matchPattern(maskDimSize, m_One());
+}
+static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
+
+/// Rewrites a mask op to drop non-scalable unit dimensions.
+/// Supports vector.create_mask and vector.constant_mask.
+template <typename MaskOp>
+static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
+ Location loc, MaskOp op) {
auto type = op.getType();
VectorType reducedType = trimNonScalableUnitDims(type);
if (reducedType.getRank() == type.getRank())
return failure();
- SmallVector<Value> reducedOperands;
- for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
- type.getShape(), type.getScalableDims(), op.getOperands())) {
+ using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
+ SmallVector<ElemType> reduced;
+ for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
+ type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
if (dim == 1 && !dimIsScalable) {
- // If the mask for the unit dim is not a constant of 1, do nothing.
- auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
- if (!constant || (constant.value() != 1))
+ if (!isUnitDimMask(elem))
return failure();
continue;
}
- reducedOperands.push_back(operand);
+ reduced.push_back(elem);
}
- return vector::CreateMaskOp::create(rewriter, loc, reducedType,
- reducedOperands)
- .getResult();
+ return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
}
namespace {
@@ -522,21 +526,23 @@ class TransferReadDropUnitDimsPattern
Value maskOp = transferReadOp.getMask();
if (maskOp) {
LDBG() << " -> Processing mask operation";
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp) {
- LDBG()
- << " -> Unsupported mask op, only 'vector.create_mask' supported";
- return rewriter.notifyMatchFailure(
- transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
- }
- FailureOr<Value> rankReducedCreateMask =
- createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
- if (failed(rankReducedCreateMask)) {
+ FailureOr<Value> rankReducedMaskOp = failure();
+ if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+ rankReducedMaskOp =
+ maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+ else if (auto constantMaskOp =
+ maskOp.getDefiningOp<vector::ConstantMaskOp>())
+ rankReducedMaskOp =
+ maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+ if (failed(rankReducedMaskOp)) {
LDBG() << " -> Failed to reduce mask dimensions";
- return failure();
+ return rewriter.notifyMatchFailure(
+ transferReadOp,
+ "unsupported mask op, only 'vector.create_mask' and "
+ "'vector.constant_mask' are currently supported");
}
- maskOp = *rankReducedCreateMask;
+ maskOp = *rankReducedMaskOp;
LDBG() << " -> Successfully reduced mask dimensions";
}
@@ -636,22 +642,23 @@ class TransferWriteDropUnitDimsPattern
Value maskOp = transferWriteOp.getMask();
if (maskOp) {
LDBG() << " -> Processing mask operation";
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp) {
- LDBG()
- << " -> Unsupported mask op, only 'vector.create_mask' supported";
+ FailureOr<Value> rankReducedMask = failure();
+ if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+ else if (auto constantMaskOp =
+ maskOp.getDefiningOp<vector::ConstantMaskOp>())
+ rankReducedMask =
+ maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
+
+ if (failed(rankReducedMask)) {
+ LDBG() << " -> Failed to reduce mask dimensions";
return rewriter.notifyMatchFailure(
transferWriteOp,
- "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
- }
- FailureOr<Value> rankReducedCreateMask =
- createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
- if (failed(rankReducedCreateMask)) {
- LDBG() << " -> Failed to reduce mask dimensions";
- return failure();
+ "unsupported mask op, only 'vector.create_mask' and "
+ "'vector.constant_mask' are currently supported");
}
- maskOp = *rankReducedCreateMask;
+ maskOp = *rankReducedMask;
LDBG() << " -> Successfully reduced mask dimensions";
}
LDBG() << " -> Creating rank-reduced subview and new transfer_write";
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 8234351302f6b..d30ba64c09159 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -176,7 +176,7 @@ func.func @transfer_read_dynamic_rank_reducing(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>
-func.func @masked_transfer_read_dynamic_rank_reducing_1(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_create_mask(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
%mask_dim0 : index) -> vector<[16]x1xi8> {
%c0 = arith.constant 0 : index
@@ -187,7 +187,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
return %v : vector<[16]x1xi8>
}
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
// CHECK-SAME: %[[MASK_DIM0:.+]]: index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -197,7 +197,22 @@ func.func @masked_transfer_read_dynamic_rank_reducing_1(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
-func.func @masked_transfer_read_dynamic_rank_reducing_2(
+func.func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
+ %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+ memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+ return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing_2_create_mask(
%arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>,
%mask_dim1 : index, %mask_dim4 : index) -> vector<1x[1]x3x1x[16]x1xi8> {
%c0 = arith.constant 0 : index
@@ -209,7 +224,7 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
return %v : vector<1x[1]x3x1x[16]x1xi8>
}
-// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
// CHECK-SAME: %[[MASK_DIM1:.+]]: index, %[[MASK_DIM4:.+]]: index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -223,7 +238,28 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
-func.func @masked_transfer_write_and_vector_rank_reducing(
+func.func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask(
+ %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>) -> vector<1x[1]x3x1x[16]x1xi8> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.constant_mask [1, 1, 2, 1, 16, 1] : vector<1x[1]x3x1x[16]x1xi1>
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true, true, true, true]} :
+ memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
+ return %v : vector<1x[1]x3x1x[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 0 : i8
+// CHECK: %[[MASK:.+]] = vector.constant_mask [1, 2, 16] : vector<[1]x3x[16]xi1>
+// CHECK: %[[DIM1:.+]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+// CHECK: %[[DIM4:.+]] = memref.dim %[[ARG]], %[[C4]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
+// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
+
+func.func @masked_transfer_write_and_vector_rank_reducing_create_mask(
%arg : memref<1x1x3x1x16x1xf32>,
%vec : vector<1x3x1x16x1xf32>,
%mask_dim1 : index,
@@ -235,7 +271,7 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
return
}
-// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
// CHECK-SAME: {{.*}}: vector<1x3x1x16x1xf32>,
// CHECK-SAME: %[[MASKDIM1:.+]]: index,
@@ -245,7 +281,23 @@ func.func @masked_transfer_write_and_vector_rank_reducing(
// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
-func.func @masked_transfer_write_dynamic_rank_reducing(
+func.func @masked_transfer_write_and_vector_rank_reducing_constant_mask(
+ %arg : memref<1x1x3x1x16x1xf32>,
+ %vec : vector<1x3x1x16x1xf32>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [1, 2, 1, 8, 1] : vector<1x3x1x16x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+ vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [2, 8] : vector<3x16xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing_create_mask(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
%vec : vector<[16]x1xi8>,
%mask_dim0 : index) {
@@ -257,7 +309,7 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
return
}
-// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_create_mask
// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
// CHECK-SAME: %{{.*}}: vector<[16]x1xi8>,
// CHECK-SAME: %[[MASK_DIM0:.+]]: index
@@ -267,7 +319,22 @@ func.func @masked_transfer_write_dynamic_rank_reducing(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
-/// Only masks operands of vector.create_mask are currently supported.
+func.func @masked_transfer_write_dynamic_rank_reducing_constant_mask(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %vec : vector<[16]x1xi8>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [16, 1] : vector<[16]x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+ vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing_constant_mask
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [{{.*}}, 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
+/// Only vector.create_mask and vector.constant_mask masks are supported.
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
%mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {
More information about the Mlir-commits
mailing list