[Mlir-commits] [mlir] 5eb195f - [mlir][Vector] Fold vector.constant_mask to SplatElementsAttr (#146724)
llvmlistbot at llvm.org
Fri Jul 4 06:44:39 PDT 2025
Author: Kunwar Grover
Date: 2025-07-04T14:44:36+01:00
New Revision: 5eb195fa90c7d39c601e0bccd90a79250ca86e49
URL: https://github.com/llvm/llvm-project/commit/5eb195fa90c7d39c601e0bccd90a79250ca86e49
DIFF: https://github.com/llvm/llvm-project/commit/5eb195fa90c7d39c601e0bccd90a79250ca86e49.diff
LOG: [mlir][Vector] Fold vector.constant_mask to SplatElementsAttr (#146724)
Adds a folder to vector.constant_mask that folds it to a SplatElementsAttr
when the mask is all true or all false (including the 0-D case).
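For illustration, a minimal before/after sketch of the fold, mirroring the
tests added in canonicalize.mlir below:

  // Before canonicalization:
  %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
  %1 = vector.constant_mask [0, 0] : vector<2x4xi1>

  // After canonicalization (folded to splat constants):
  %0 = arith.constant dense<true> : vector<2x4xi1>
  %1 = arith.constant dense<false> : vector<2x4xi1>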
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-mem-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dfb2756e57bea..ec2c87ca1cf44 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
let hasVerifier = 1;
+ let hasFolder = 1;
}
def Vector_CreateMaskOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..29f71b7cb9246 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,6 +6594,28 @@ bool ConstantMaskOp::isAllOnesMask() {
return true;
}
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+ ArrayRef<int64_t> bounds = getMaskDimSizes();
+ ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+
+ auto createBoolSplat = [&](bool x) {
+ return SplatElementsAttr::get(getVectorType(),
+ BoolAttr::get(getContext(), x));
+ };
+
+ // Check the corner case of 0-D vectors first.
+ if (vectorSizes.empty()) {
+ assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+ return createBoolSplat(bounds[0] == 1);
+ }
+ // Fold vector.constant_mask to splat if possible.
+ if (bounds == vectorSizes)
+ return createBoolSplat(true);
+ if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+ return createBoolSplat(false);
+ return OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index 99d6a3dc390e0..35fd7c33e4cfe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -23,6 +23,7 @@
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK: %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
-// CHECK: %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
// CHECK: %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
// CHECK: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
// CHECK: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..8cda8d47cb908 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
%c-1 = arith.constant -1 : index
- // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+ // CHECK: arith.constant dense<false> : vector<[8]xi1>
%0 = vector.create_mask %c-1 : vector<[8]xi1>
return %0 : vector<[8]xi1>
}
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
%cneg2 = arith.constant -2 : index
%c5 = arith.constant 5 : index
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
%0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
%0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
- // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+ // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
return %10 : vector<8x[16]xi1>
}
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
// -----
+// CHECK-LABEL: constant_mask_to_true_splat
+func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
+ // CHECK: arith.constant dense<true>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
+ return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_false_splat
+func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
+ // CHECK: arith.constant dense<false>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [0, 0] : vector<2x4xi1>
+ return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_true_splat_0d
+func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
+ // CHECK: arith.constant dense<true>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [1] : vector<i1>
+ return %0 : vector<i1>
+}
+
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
: vector<4x3xi1> to vector<2x2xi1>
- // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+ // CHECK: arith.constant dense<true> : vector<2x2xi1>
return %1 : vector<2x2xi1>
}
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
: vector<4x3xi1> to vector<2x2xi1>
- // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+ // CHECK: arith.constant dense<false> : vector<2x2xi1>
return %1 : vector<2x2xi1>
}
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
: vector<4x3xi1> to vector<2x1xi1>
- // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+ // CHECK: arith.constant dense<false> : vector<2x1xi1>
return %1 : vector<2x1xi1>
}
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
: vector<4x3xi1> to vector<2x1xi1>
- // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+ // CHECK: arith.constant dense<true> : vector<2x1xi1>
return %1 : vector<2x1xi1>
}
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..e6593320f1bde 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK-NEXT: vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-NEXT: return
func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {