[Mlir-commits] [mlir] 6cac792 - [mlir][Vector] Improve `vector.mask` verifier (#139823)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 20 15:31:48 PDT 2025
Author: Diego Caballero
Date: 2025-05-20T15:31:45-07:00
New Revision: 6cac792bf9eacb1ed0c80fc7c767fc99c50e2524
URL: https://github.com/llvm/llvm-project/commit/6cac792bf9eacb1ed0c80fc7c767fc99c50e2524
DIFF: https://github.com/llvm/llvm-project/commit/6cac792bf9eacb1ed0c80fc7c767fc99c50e2524.diff
LOG: [mlir][Vector] Improve `vector.mask` verifier (#139823)
This PR improves the `vector.mask` verifier to make sure masking
semantics are not applied to operations defined outside of the
`vector.mask` region. The documentation is updated to state this more
explicitly, although it already mentioned it.
As part of this change, the logic that ensures a terminator is present
in the `vector.mask` region has been simplified so that it is less
surprising to the user when a `vector.yield` is explicitly provided in
the IR.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..5e8421ed67d66 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2482,8 +2482,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
masked. Values used within the region are captured from above. Only one
*maskable* operation can be masked with a `vector.mask` operation at a time.
An operation is *maskable* if it implements the `MaskableOpInterface`. The
- terminator yields all results of the maskable operation to the result of
- this operation.
+ terminator yields all results from the maskable operation to the result of
+ this operation. No other values are allowed to be yielded.
+
+ An empty `vector.mask` operation is currently legal to enable optimizations
+ across the `vector.mask` region. However, this might change in the future
+ once vector transformations gain better support for `vector.mask`.
+ TODO: Consider making empty `vector.mask` illegal.
The vector mask argument holds a bit for each vector lane and determines
which vector lanes should execute the maskable operation and which ones
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1b5534d4d94ff..bbb366b01fa6e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6550,29 +6550,33 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
- OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
- MaskOp>::ensureTerminator(region, builder, loc);
- // Keep the default yield terminator if the number of masked operations is not
- // the expected. This case will trigger a verification failure.
+ // 1. For an empty `vector.mask`, create a default terminator.
+ if (region.empty() || region.front().empty()) {
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+ MaskOp>::ensureTerminator(region, builder, loc);
+ return;
+ }
+
+ // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
Block &block = region.front();
- if (block.getOperations().size() != 2)
+ if (isa<vector::YieldOp>(block.back()))
return;
- // Replace default yield terminator with a new one that returns the results
- // from the masked operation.
- OpBuilder opBuilder(builder.getContext());
- Operation *maskedOp = &block.front();
- Operation *oldYieldOp = &block.back();
- assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+ // 3. For a non-empty `vector.mask` without an explicit terminator:
- // Empty vector.mask op.
- if (maskedOp == oldYieldOp)
+ // Create default terminator if the number of masked operations is not
+ // one. This case will trigger a verification failure.
+ if (block.getOperations().size() != 1) {
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+ MaskOp>::ensureTerminator(region, builder, loc);
return;
+ }
- opBuilder.setInsertionPoint(oldYieldOp);
+ // Create a terminator that yields the results from the masked operation.
+ OpBuilder opBuilder(builder.getContext());
+ Operation *maskedOp = &block.front();
+ opBuilder.setInsertionPointToEnd(&block);
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
- oldYieldOp->dropAllReferences();
- oldYieldOp->erase();
}
LogicalResult MaskOp::verify() {
@@ -6607,6 +6611,10 @@ LogicalResult MaskOp::verify() {
return emitOpError("expects number of results to match maskable operation "
"number of results");
+ if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
+ return emitOpError("expects all the results from the MaskableOpInterface "
+ "to match all the values returned by the terminator");
+
if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
return emitOpError(
"expects result type to match maskable operation result type");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 740c6b7ae3174..04810ed52584f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1756,6 +1756,20 @@ func.func @vector_mask_empty_passthru_no_return_type(%mask : vector<8xi1>,
// -----
+func.func @vector_mask_non_empty_external_return(%t: tensor<?xf32>, %idx: index,
+ %m: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
+ %ft0 = arith.constant 0.0 : f32
+ // expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to match all the values returned by the terminator}}
+ %0 = vector.mask %m {
+ %1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+ vector.yield %ext : vector<16xf32>
+ } : vector<16xi1> -> vector<16xf32>
+
+ return %0 : vector<16xf32>
+}
+
+// -----
+
func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
%passthru : vector<8xi32>) {
// expected-error@+1 {{'vector.mask' expects a result if passthru operand is provided}}
@@ -1765,6 +1779,20 @@ func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
// -----
+func.func @vector_mask_non_empty_mixed_return(%t: tensor<?xf32>, %idx: index,
+ %m: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
+ %ft0 = arith.constant 0.0 : f32
+ // expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
+ %0:2 = vector.mask %m {
+ %1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
+ vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
+ } : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)
+
+ return %0#0, %0#1 : vector<16xf32>, vector<16xf32>
+}
+
+// -----
+
func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
More information about the Mlir-commits
mailing list