[Mlir-commits] [mlir] 5eb195f - [mlir][Vector] Fold vector.constant_mask to SplatElementsAttr (#146724)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 4 06:44:39 PDT 2025


Author: Kunwar Grover
Date: 2025-07-04T14:44:36+01:00
New Revision: 5eb195fa90c7d39c601e0bccd90a79250ca86e49

URL: https://github.com/llvm/llvm-project/commit/5eb195fa90c7d39c601e0bccd90a79250ca86e49
DIFF: https://github.com/llvm/llvm-project/commit/5eb195fa90c7d39c601e0bccd90a79250ca86e49.diff

LOG: [mlir][Vector] Fold vector.constant_mask to SplatElementsAttr (#146724)

Adds a folder to vector.constant_mask that folds it to a SplatElementsAttr
when the mask is all-true or all-false.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/vector-mem-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dfb2756e57bea..ec2c87ca1cf44 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
 
   let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Vector_CreateMaskOp :

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..29f71b7cb9246 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,6 +6594,28 @@ bool ConstantMaskOp::isAllOnesMask() {
   return true;
 }
 
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+  ArrayRef<int64_t> bounds = getMaskDimSizes();
+  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+
+  auto createBoolSplat = [&](bool x) {
+    return SplatElementsAttr::get(getVectorType(),
+                                  BoolAttr::get(getContext(), x));
+  };
+
+  // 0-D case: one mask dim size is expected; the splat is true iff it is 1.
+  if (vectorSizes.empty()) {
+    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+    return createBoolSplat(bounds[0] == 1);
+  }
+  // Bounds equal to the shape fold to a true splat; all-zero bounds to false.
+  if (bounds == vectorSizes)
+    return createBoolSplat(true);
+  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+    return createBoolSplat(false);
+  return OpFoldResult(); // Empty result: no fold applied.
+}
+
 //===----------------------------------------------------------------------===//
 // CreateMaskOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index 99d6a3dc390e0..35fd7c33e4cfe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -23,6 +23,7 @@
 // CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG:   %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
 // CHECK:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
 // CHECK:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
 // CHECK:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
 // CHECK:       %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
 // CHECK:       scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
-// CHECK:         %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
 // CHECK:         %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
 // CHECK:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
 // CHECK:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..8cda8d47cb908 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
 func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
   %c-1 = arith.constant -1 : index
-  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  // CHECK: arith.constant dense<false> : vector<[8]xi1>
   %0 = vector.create_mask %c-1 : vector<[8]xi1>
   return %0 : vector<[8]xi1>
 }
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
 func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
   %cneg2 = arith.constant -2 : index
   %c5 = arith.constant 5 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
 func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
   %c16 = arith.constant 16 : index
   %0 = vector.vscale
   %1 = arith.muli %0, %c16 : index
-  // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+  // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
   return %10 : vector<8x[16]xi1>
 }
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: constant_mask_to_true_splat
+func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_false_splat
+func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<false>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [0, 0] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_true_splat_0d
+func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [1] : vector<i1>
+  return %0 : vector<i1>
+}
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+  // CHECK: arith.constant dense<true> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+  // CHECK: arith.constant dense<false> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+  // CHECK: arith.constant dense<false> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+  // CHECK: arith.constant dense<true> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..e6593320f1bde 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
 // CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-NEXT:      return %[[G]] : vector<16xf32>
 func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
 // CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-NEXT:      return
 func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {


        


More information about the Mlir-commits mailing list