[Mlir-commits] [mlir] [mlir][Vector] Fold vector.constant_mask to SplatElementsAttr (PR #146724)

Kunwar Grover llvmlistbot at llvm.org
Fri Jul 4 03:26:04 PDT 2025


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/146724

>From 0a7f4c973abc5f946b5fce14aad8d6eaee517836 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 2 Jul 2025 16:18:34 +0100
Subject: [PATCH 1/2] [mlir][Vector] Fold vector.constant_mask to splat

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 20 ++++++++++
 .../SparseTensor/sparse_vector_peeled.mlir    |  2 +-
 mlir/test/Dialect/Vector/canonicalize.mlir    | 40 +++++++++++++++----
 .../Dialect/Vector/vector-mem-transforms.mlir |  4 +-
 5 files changed, 56 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dfb2756e57bea..ec2c87ca1cf44 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
 
   let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Vector_CreateMaskOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..a462b3701ddbb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,6 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
   return true;
 }
 
+static Attribute createBoolSplat(ShapedType ty, bool x) {
+  return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
+}
+
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+  ArrayRef<int64_t> bounds = getMaskDimSizes();
+  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+  // Check the corner case of 0-D vectors first.
+  if (vectorSizes.size() == 0) {
+    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+    return createBoolSplat(getVectorType(), bounds[0] == 1);
+  }
+  // Fold vector.constant_mask to splat if possible.
+  if (bounds == vectorSizes)
+    return createBoolSplat(getVectorType(), true);
+  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+    return createBoolSplat(getVectorType(), false);
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // CreateMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index 99d6a3dc390e0..35fd7c33e4cfe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -23,6 +23,7 @@
 // CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG:   %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
 // CHECK:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
 // CHECK:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
 // CHECK:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
 // CHECK:       %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
 // CHECK:       scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
-// CHECK:         %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
 // CHECK:         %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
 // CHECK:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
 // CHECK:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..8cda8d47cb908 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
 func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
   %c-1 = arith.constant -1 : index
-  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  // CHECK: arith.constant dense<false> : vector<[8]xi1>
   %0 = vector.create_mask %c-1 : vector<[8]xi1>
   return %0 : vector<[8]xi1>
 }
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
 func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
   %cneg2 = arith.constant -2 : index
   %c5 = arith.constant 5 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
 func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
   %c16 = arith.constant 16 : index
   %0 = vector.vscale
   %1 = arith.muli %0, %c16 : index
-  // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+  // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
   return %10 : vector<8x[16]xi1>
 }
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: constant_mask_to_true_splat
+func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_false_splat
+func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<false>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [0, 0] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_true_splat_0d
+func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [1] : vector<i1>
+  return %0 : vector<i1>
+}
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+  // CHECK: arith.constant dense<true> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+  // CHECK: arith.constant dense<false> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+  // CHECK: arith.constant dense<false> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+  // CHECK: arith.constant dense<true> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..e6593320f1bde 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
 // CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-NEXT:      return %[[G]] : vector<16xf32>
 func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
 // CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-NEXT:      return
 func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {

>From 5946db58b58c6414d53a613d06ccd76d845da248 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 4 Jul 2025 11:25:31 +0100
Subject: [PATCH 2/2] Address review comments (inline bool-splat helper as a lambda)

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a462b3701ddbb..29f71b7cb9246 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,24 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
   return true;
 }
 
-static Attribute createBoolSplat(ShapedType ty, bool x) {
-  return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
-}
-
 OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
   ArrayRef<int64_t> bounds = getMaskDimSizes();
   ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+
+  auto createBoolSplat = [&](bool x) {
+    return SplatElementsAttr::get(getVectorType(),
+                                  BoolAttr::get(getContext(), x));
+  };
+
   // Check the corner case of 0-D vectors first.
-  if (vectorSizes.size() == 0) {
+  if (vectorSizes.empty()) {
     assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
-    return createBoolSplat(getVectorType(), bounds[0] == 1);
+    return createBoolSplat(bounds[0] == 1);
   }
   // Fold vector.constant_mask to splat if possible.
   if (bounds == vectorSizes)
-    return createBoolSplat(getVectorType(), true);
+    return createBoolSplat(true);
   if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
-    return createBoolSplat(getVectorType(), false);
-  return {};
+    return createBoolSplat(false);
+  return OpFoldResult();
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list