[Mlir-commits] [mlir] [mlir][Vector] Fold vector.constant_mask to SplatElementsAttr (PR #146724)
Kunwar Grover
llvmlistbot at llvm.org
Fri Jul 4 03:26:04 PDT 2025
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/146724
>From 0a7f4c973abc5f946b5fce14aad8d6eaee517836 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 2 Jul 2025 16:18:34 +0100
Subject: [PATCH 1/2] [mlir][Vector] Fold vector.constant_mask to splat
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 ++++++++++
.../SparseTensor/sparse_vector_peeled.mlir | 2 +-
mlir/test/Dialect/Vector/canonicalize.mlir | 40 +++++++++++++++----
.../Dialect/Vector/vector-mem-transforms.mlir | 4 +-
5 files changed, 56 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dfb2756e57bea..ec2c87ca1cf44 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
let hasVerifier = 1;
+ let hasFolder = 1;
}
def Vector_CreateMaskOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..a462b3701ddbb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,6 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
return true;
}
+static Attribute createBoolSplat(ShapedType ty, bool x) {
+ return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
+}
+
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+ ArrayRef<int64_t> bounds = getMaskDimSizes();
+ ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+ // Check the corner case of 0-D vectors first.
+ if (vectorSizes.size() == 0) {
+ assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+ return createBoolSplat(getVectorType(), bounds[0] == 1);
+ }
+ // Fold vector.constant_mask to splat if possible.
+ if (bounds == vectorSizes)
+ return createBoolSplat(getVectorType(), true);
+ if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+ return createBoolSplat(getVectorType(), false);
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index 99d6a3dc390e0..35fd7c33e4cfe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -23,6 +23,7 @@
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK: %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
-// CHECK: %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
// CHECK: %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
// CHECK: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
// CHECK: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..8cda8d47cb908 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
%c-1 = arith.constant -1 : index
- // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+ // CHECK: arith.constant dense<false> : vector<[8]xi1>
%0 = vector.create_mask %c-1 : vector<[8]xi1>
return %0 : vector<[8]xi1>
}
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
%cneg2 = arith.constant -2 : index
%c5 = arith.constant 5 : index
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
%0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
%0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
- // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+ // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
return %10 : vector<8x[16]xi1>
}
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
// -----
+// CHECK-LABEL: constant_mask_to_true_splat
+func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
+ // CHECK: arith.constant dense<true>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
+ return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_false_splat
+func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
+ // CHECK: arith.constant dense<false>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [0, 0] : vector<2x4xi1>
+ return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_true_splat_0d
+func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
+ // CHECK: arith.constant dense<true>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [1] : vector<i1>
+ return %0 : vector<i1>
+}
+
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
: vector<4x3xi1> to vector<2x2xi1>
- // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+ // CHECK: arith.constant dense<true> : vector<2x2xi1>
return %1 : vector<2x2xi1>
}
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
: vector<4x3xi1> to vector<2x2xi1>
- // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+ // CHECK: arith.constant dense<false> : vector<2x2xi1>
return %1 : vector<2x2xi1>
}
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
: vector<4x3xi1> to vector<2x1xi1>
- // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+ // CHECK: arith.constant dense<false> : vector<2x1xi1>
return %1 : vector<2x1xi1>
}
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
: vector<4x3xi1> to vector<2x1xi1>
- // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+ // CHECK: arith.constant dense<true> : vector<2x1xi1>
return %1 : vector<2x1xi1>
}
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..e6593320f1bde 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK-NEXT: vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-NEXT: return
func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
>From 5946db58b58c6414d53a613d06ccd76d845da248 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 4 Jul 2025 11:25:31 +0100
Subject: [PATCH 2/2] Address review comments (make createBoolSplat a local lambda, use ArrayRef::empty())
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a462b3701ddbb..29f71b7cb9246 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,24 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
return true;
}
-static Attribute createBoolSplat(ShapedType ty, bool x) {
- return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
-}
-
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
ArrayRef<int64_t> bounds = getMaskDimSizes();
ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+
+ auto createBoolSplat = [&](bool x) {
+ return SplatElementsAttr::get(getVectorType(),
+ BoolAttr::get(getContext(), x));
+ };
+
// Check the corner case of 0-D vectors first.
- if (vectorSizes.size() == 0) {
+ if (vectorSizes.empty()) {
assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
- return createBoolSplat(getVectorType(), bounds[0] == 1);
+ return createBoolSplat(bounds[0] == 1);
}
// Fold vector.constant_mask to splat if possible.
if (bounds == vectorSizes)
- return createBoolSplat(getVectorType(), true);
+ return createBoolSplat(true);
if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
- return createBoolSplat(getVectorType(), false);
- return {};
+ return createBoolSplat(false);
+ return OpFoldResult();
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list