[Mlir-commits] [mlir] [mlir][Vector] Improve `vector.mask` verifier (PR #139823)

Diego Caballero llvmlistbot at llvm.org
Tue May 20 15:18:12 PDT 2025


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/139823

>From 2f14c0f69b6f704aa2d21d0a566ce8d93a268f77 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 13 May 2025 23:39:34 +0000
Subject: [PATCH 1/4] [mlir][Vector] Improve `vector.mask` verifier

This PR improves the verifier for the `vector.mask` operation to ensure that
masking semantics are not applied to operations defined outside of the
`vector.mask` region. The documentation is updated to state this more
explicitly and clearly, although the previous wording already covered it.

As part of this change, the logic that ensures a terminator is present in the
mask region has been simplified so that it is less surprising to the user: an
explicitly provided `vector.yield` is now kept as-is instead of being replaced.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  9 +++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 41 +++++++++++--------
 mlir/test/Dialect/Vector/invalid.mlir         | 28 +++++++++++++
 3 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..2820759687293 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2482,8 +2482,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
     masked. Values used within the region are captured from above. Only one
     *maskable* operation can be masked with a `vector.mask` operation at a time.
     An operation is *maskable* if it implements the `MaskableOpInterface`. The
-    terminator yields all results of the maskable operation to the result of
-    this operation.
+    terminator yields all results from the maskable operation to the result of
+    this operation. No other values are allowed to be yielded.
+
+    An empty `vector.mask` operation is considered ill-formed but legal to
+    facilitate optimizations across the `vector.mask` operation. It is considered
+    a no-op regardless of its returned values and will be removed by the
+    canonicalizer.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1b5534d4d94ff..25c3f439e877a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6550,29 +6550,31 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
 }
 
 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
-  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
-      MaskOp>::ensureTerminator(region, builder, loc);
-  // Keep the default yield terminator if the number of masked operations is not
-  // the expected. This case will trigger a verification failure.
-  Block &block = region.front();
-  if (block.getOperations().size() != 2)
+  // Create default terminator if there are no ops to mask.
+  if (region.empty() || region.front().empty()) {
+    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+        MaskOp>::ensureTerminator(region, builder, loc);
     return;
+  }
 
-  // Replace default yield terminator with a new one that returns the results
-  // from the masked operation.
-  OpBuilder opBuilder(builder.getContext());
-  Operation *maskedOp = &block.front();
-  Operation *oldYieldOp = &block.back();
-  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+  // If  region has an explicit terminator, we don't modify it.
+  Block &block = region.front();
+  if (isa<vector::YieldOp>(block.back()))
+    return;
 
-  // Empty vector.mask op.
-  if (maskedOp == oldYieldOp)
+  // Create default terminator if the number of masked operations is not
+  // one. This case will trigger a verification failure.
+  if (block.getOperations().size() != 1) {
+    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+        MaskOp>::ensureTerminator(region, builder, loc);
     return;
+  }
 
-  opBuilder.setInsertionPoint(oldYieldOp);
+  // Create a terminator that yields the results from the masked operation.
+  OpBuilder opBuilder(builder.getContext());
+  Operation *maskedOp = &block.front();
+  opBuilder.setInsertionPointToEnd(&block);
   opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
-  oldYieldOp->dropAllReferences();
-  oldYieldOp->erase();
 }
 
 LogicalResult MaskOp::verify() {
@@ -6607,6 +6609,11 @@ LogicalResult MaskOp::verify() {
     return emitOpError("expects number of results to match maskable operation "
                        "number of results");
 
+  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
+    return emitOpError(
+        "expects all the results from the MaskableOpInterface to "
+        "be returned by the terminator");
+
   if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
     return emitOpError(
         "expects result type to match maskable operation result type");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 740c6b7ae3174..bebb617f8aa09 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1756,6 +1756,20 @@ func.func @vector_mask_empty_passthru_no_return_type(%mask : vector<8xi1>,
 
 // -----
 
+func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index,
+                                                 %m0: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error at +1 {{'vector.mask' op expects all the results from the MaskableOpInterface to be returned by the terminator}}
+  %0 = vector.mask %m0 {
+    %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+    vector.yield %ext : vector<16xf32>
+  } : vector<16xi1> -> vector<16xf32>
+
+  return %0 : vector<16xf32>
+}
+
+// -----
+
 func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
                                                         %passthru : vector<8xi32>) {
   // expected-error at +1 {{'vector.mask' expects a result if passthru operand is provided}}
@@ -1765,6 +1779,20 @@ func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
 
 // -----
 
+func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index,
+                                              %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error at +1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
+  %0:2 = vector.mask %m0 {
+    %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+    vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
+  } : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)
+
+  return %0#0, %0#1 : vector<16xf32>, vector<16xf32>
+}
+
+// -----
+
 func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
   // expected-error at +1 {{op failed to verify that position is a multiple of the source length.}}
   %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>

>From d716bbb0871d653d19d1b57d7baf6f8fcdd4882f Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 14 May 2025 17:55:30 +0000
Subject: [PATCH 2/4] Review feedback

---
 .../mlir/Dialect/Vector/IR/VectorOps.td        |  7 +++----
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp       | 11 ++++++-----
 mlir/test/Dialect/Vector/invalid.mlir          | 18 +++++++++---------
 3 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2820759687293..e4c66a42ea333 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2485,10 +2485,9 @@ def Vector_MaskOp : Vector_Op<"mask", [
     terminator yields all results from the maskable operation to the result of
     this operation. No other values are allowed to be yielded.
 
-    An empty `vector.mask` operation is considered ill-formed but legal to
-    facilitate optimizations across the `vector.mask` operation. It is considered
-    a no-op regardless of its returned values and will be removed by the
-    canonicalizer.
+    An empty `vector.mask` operation is legal to facilitate optimizations across
+    the `vector.mask` operation. However, it is considered a no-op regardless of
+    its returned values and will be removed by the canonicalizer.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 25c3f439e877a..bbb366b01fa6e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6550,18 +6550,20 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
 }
 
 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
-  // Create default terminator if there are no ops to mask.
+  // 1. For an empty `vector.mask`, create a default terminator.
   if (region.empty() || region.front().empty()) {
     OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
         MaskOp>::ensureTerminator(region, builder, loc);
     return;
   }
 
-  // If  region has an explicit terminator, we don't modify it.
+  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
   Block &block = region.front();
   if (isa<vector::YieldOp>(block.back()))
     return;
 
+  // 3. For a non-empty `vector.mask` without an explicit terminator:
+
   // Create default terminator if the number of masked operations is not
   // one. This case will trigger a verification failure.
   if (block.getOperations().size() != 1) {
@@ -6610,9 +6612,8 @@ LogicalResult MaskOp::verify() {
                        "number of results");
 
   if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
-    return emitOpError(
-        "expects all the results from the MaskableOpInterface to "
-        "be returned by the terminator");
+    return emitOpError("expects all the results from the MaskableOpInterface "
+                       "to match all the values returned by the terminator");
 
   if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
     return emitOpError(
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index bebb617f8aa09..04810ed52584f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1756,12 +1756,12 @@ func.func @vector_mask_empty_passthru_no_return_type(%mask : vector<8xi1>,
 
 // -----
 
-func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index,
-                                                 %m0: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
+func.func @vector_mask_non_empty_external_return(%t: tensor<?xf32>, %idx: index,
+                                                 %m: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
   %ft0 = arith.constant 0.0 : f32
-  // expected-error at +1 {{'vector.mask' op expects all the results from the MaskableOpInterface to be returned by the terminator}}
-  %0 = vector.mask %m0 {
-    %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+  // expected-error at +1 {{'vector.mask' op expects all the results from the MaskableOpInterface to match all the values returned by the terminator}}
+  %0 = vector.mask %m {
+    %1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
     vector.yield %ext : vector<16xf32>
   } : vector<16xi1> -> vector<16xf32>
 
@@ -1779,12 +1779,12 @@ func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
 
 // -----
 
-func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index,
-                                              %m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
+func.func @vector_mask_non_empty_mixed_return(%t: tensor<?xf32>, %idx: index,
+                                              %m: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
   %ft0 = arith.constant 0.0 : f32
   // expected-error at +1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
-  %0:2 = vector.mask %m0 {
-    %1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+  %0:2 = vector.mask %m {
+    %1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
     vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
   } : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)
 

>From 13917011878ef930d11b07b58d3982d07887a97f Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Fri, 16 May 2025 22:49:44 +0000
Subject: [PATCH 3/4] Improve doc

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e4c66a42ea333..cd734c796582f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2485,9 +2485,9 @@ def Vector_MaskOp : Vector_Op<"mask", [
     terminator yields all results from the maskable operation to the result of
     this operation. No other values are allowed to be yielded.
 
-    An empty `vector.mask` operation is legal to facilitate optimizations across
-    the `vector.mask` operation. However, it is considered a no-op regardless of
-    its returned values and will be removed by the canonicalizer.
+    An empty `vector.mask` operation is currently legal to enable optimizations
+    across the `vector.mask` region. However, this might change in the future
+    once vector transformations gain better support for `vector.mask`.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones

>From 2da41ace7085ed6bcfd98dc0388aa0169403f9a0 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 20 May 2025 22:17:53 +0000
Subject: [PATCH 4/4] Add TODO

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cd734c796582f..5e8421ed67d66 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2488,6 +2488,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
     An empty `vector.mask` operation is currently legal to enable optimizations
     across the `vector.mask` region. However, this might change in the future
     once vector transformations gain better support for `vector.mask`.
+    TODO: Consider making empty `vector.mask` illegal.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones



More information about the Mlir-commits mailing list