[Mlir-commits] [mlir] [mlir][Vector] Improve `vector.mask` verifier (PR #139823)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed May 14 08:02:47 PDT 2025
================
@@ -6543,29 +6543,31 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
- OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
- MaskOp>::ensureTerminator(region, builder, loc);
- // Keep the default yield terminator if the number of masked operations is not
- // the expected. This case will trigger a verification failure.
- Block &block = region.front();
- if (block.getOperations().size() != 2)
+ // Create default terminator if there are no ops to mask.
+ if (region.empty() || region.front().empty()) {
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+ MaskOp>::ensureTerminator(region, builder, loc);
return;
+ }
- // Replace default yield terminator with a new one that returns the results
- // from the masked operation.
- OpBuilder opBuilder(builder.getContext());
- Operation *maskedOp = &block.front();
- Operation *oldYieldOp = &block.back();
- assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+ // If region has an explicit terminator, we don't modify it.
+ Block &block = region.front();
+ if (isa<vector::YieldOp>(block.back()))
+ return;
- // Empty vector.mask op.
- if (maskedOp == oldYieldOp)
+ // Create default terminator if the number of masked operations is not
+ // one. This case will trigger a verification failure.
+ if (block.getOperations().size() != 1) {
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+ MaskOp>::ensureTerminator(region, builder, loc);
return;
+ }
- opBuilder.setInsertionPoint(oldYieldOp);
+ // Create a terminator that yields the results from the masked operation.
+ OpBuilder opBuilder(builder.getContext());
+ Operation *maskedOp = &block.front();
+ opBuilder.setInsertionPointToEnd(&block);
----------------
banach-space wrote:
It's not immediately clear what the cases are. IIUC, it's something like this:
```cpp
// 1. For an empty `vector.mask`, create a default terminator.
if (region.empty() || region.front().empty()) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  return;
}

// 2. For a non-empty `vector.mask` _with_ an existing terminator, do nothing.
Block &block = region.front();
if (isa<vector::YieldOp>(block.back()))
  return;

// 3. For a non-empty `vector.mask` _without_ a terminator, split into two
//    cases.
// 3.1. If the number of masked operations is != 1, create the default
//      terminator (this case is invalid and will be flagged by the Op
//      verifier).
if (block.getOperations().size() != 1) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  return;
}

// 3.2. Otherwise, create a terminator that yields all the results from the
//      masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
```
Did I get that right?
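To sanity-check the case analysis, here is a minimal, self-contained sketch that models it without MLIR. `ToyOp`, `ToyBlock`, `ToyRegion`, `Action`, `classify` and `ensureTerminator` are hypothetical stand-ins, not MLIR APIs: ops are matched by name rather than with `isa<vector::YieldOp>`, and the default terminator is assumed to be an operand-less `vector.yield` appended to the (possibly newly created) block.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Operation: a name plus operand/result
// counts, which is all the case analysis inspects.
struct ToyOp {
  std::string name;
  int numOperands = 0;
  int numResults = 0;
};

// A block is a list of ops; a region is a list of blocks (at most one here).
using ToyBlock = std::vector<ToyOp>;
using ToyRegion = std::vector<ToyBlock>;

enum class Action { DefaultYield, KeepExisting, YieldMaskedResults };

// Classifies the region following cases 1, 2, 3.1 and 3.2 above.
Action classify(const ToyRegion &region) {
  // 1. Empty `vector.mask`: no block, or a block without ops.
  if (region.empty() || region.front().empty())
    return Action::DefaultYield;
  const ToyBlock &block = region.front();
  // 2. An explicit terminator is already present: leave it alone.
  if (block.back().name == "vector.yield")
    return Action::KeepExisting;
  // 3.1. No terminator and != 1 masked op: invalid, left to the verifier.
  if (block.size() != 1)
    return Action::DefaultYield;
  // 3.2. No terminator and exactly one masked op: yield its results.
  return Action::YieldMaskedResults;
}

// Appends the terminator chosen by `classify` (toy model only; assumes the
// default path creates a block when the region has none).
void ensureTerminator(ToyRegion &region) {
  switch (classify(region)) {
  case Action::KeepExisting:
    return;
  case Action::DefaultYield:
    if (region.empty())
      region.emplace_back();
    region.front().push_back(ToyOp{"vector.yield", 0, 0});
    return;
  case Action::YieldMaskedResults: {
    ToyBlock &block = region.front();
    // Copy the count before push_back, which may reallocate the block.
    int numResults = block.front().numResults;
    block.push_back(ToyOp{"vector.yield", numResults, 0});
    return;
  }
  }
}
```

For example, a block holding a single op with one result gets a `vector.yield` with one operand, while a block holding two ops gets an operand-less yield that the verifier is then expected to reject.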
https://github.com/llvm/llvm-project/pull/139823