[Mlir-commits] [mlir] [MLIR][Vector] Make vector.extract verifier more strict for partial extractions (PR #186197)
llvmlistbot at llvm.org
Thu Mar 12 10:49:38 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
The verifier for `vector.extract` did not check that, when fewer indices than the source vector rank are provided (i.e., a partial extraction that yields a sub-vector), the result type is a `VectorType` rather than a scalar.
This allowed invalid IR to pass verification, e.g. `vector.extract %v[0] : i32 from vector<5x1xi32>` (one index into a rank-2 source, with a scalar result). `ExtractFromInsertTransposeChainState::fold()` would then return a value of the mathematically correct type (`vector<1xi32>`) while the declared result type was `i32`, triggering an assertion in `checkFoldResultTypes`.
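Below is a minimal standalone C++ sketch of the shape rule behind this mismatch. It assumes the usual `vector.extract` semantics: each index consumes one leading source dimension, and the remaining trailing dimensions form the result. `ExtractType` and `inferExtractType` are hypothetical names for illustration only, not MLIR APIs. Real code would operate on `VectorType` directly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR type: an empty shape means a scalar,
// otherwise a vector with the given dimensions.
struct ExtractType {
  std::vector<int64_t> shape;
  std::string elementType;

  bool isScalar() const { return shape.empty(); }

  // Render in MLIR-like syntax, e.g. "vector<1xi32>" or "i32".
  std::string str() const {
    if (isScalar())
      return elementType;
    std::string s = "vector<";
    for (int64_t d : shape)
      s += std::to_string(d) + "x";
    return s + elementType + ">";
  }
};

// Each index consumes one leading source dimension; the trailing dimensions
// that remain are the shape of the extracted value.
ExtractType inferExtractType(const ExtractType &source,
                             std::size_t numIndices) {
  // The existing verifier check already rejects more indices than the rank.
  assert(numIndices <= source.shape.size() && "too many indices");
  ExtractType result;
  result.elementType = source.elementType;
  result.shape.assign(
      source.shape.begin() + static_cast<std::ptrdiff_t>(numIndices),
      source.shape.end());
  return result;
}
```

With a `vector<5x1xi32>` source, `inferExtractType(source, 1).str()` yields `"vector<1xi32>"`, which is the type the folder produced. Two indices yield `"i32"`. Only the full extraction matches a scalar declared result.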
Fix by adding a verifier check that rejects scalar result types when fewer indices than the source vector rank are provided.
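The rank checks in the patch can also be modeled outside MLIR as a plain predicate. This is a simplified sketch that assumes a hypothetical `ExtractOpView` recording only the source shape, the index count, and whether the declared result is a vector. Like the patch, it checks only that the result is a vector, not what shape that vector has.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified view of a vector.extract op: the source vector
// shape, the number of position indices, and whether the declared result
// type is a vector (as opposed to a scalar element type).
struct ExtractOpView {
  std::vector<int64_t> sourceShape;
  std::size_t numIndices;
  bool resultIsVector;
};

// Mirrors the two rank checks in the patch: returns the diagnostic for an
// invalid op, or std::nullopt when both checks pass.
std::optional<std::string> verifyExtractRank(const ExtractOpView &op) {
  const std::size_t rank = op.sourceShape.size();
  if (op.numIndices > rank)
    return "expected position attribute of rank no greater than vector rank";
  // New check: a partial extraction yields a sub-vector, so a scalar result
  // type is rejected.
  if (op.numIndices < rank && !op.resultIsVector)
    return "expected the result type to be a vector when using "
           "fewer indices than the source vector rank";
  return std::nullopt;
}
```

For example, `verifyExtractRank({{4, 8, 16}, 1, false})` reports the new error, matching the `@extract_partial_scalar_result` test case. Three indices with a scalar result pass both checks.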
Fixes #115294
Assisted-by: Claude Code
---
Full diff: https://github.com/llvm/llvm-project/pull/186197.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+7)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 927d35342cfdd..55deeb708b428 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1385,6 +1385,13 @@ LogicalResult vector::ExtractOp::verify() {
   if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
         "expected position attribute of rank no greater than vector rank");
+  // A partial extraction (fewer indices than the source vector rank) extracts a
+  // sub-vector, so the result type must be a VectorType. Extracting a scalar
+  // requires providing one index per source vector dimension.
+  if (position.size() < static_cast<size_t>(getSourceVectorType().getRank()) &&
+      !isa<VectorType>(getResult().getType()))
+    return emitOpError("expected the result type to be a vector when using "
+                       "fewer indices than the source vector rank");
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (auto attr = dyn_cast<Attribute>(pos)) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3957455ccc76e..5eb84c04c360c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -148,6 +148,13 @@ func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
 // -----
+func.func @extract_partial_scalar_result(%arg0: vector<4x8x16xf32>) {
+  // expected-error @+1 {{expected the result type to be a vector when using fewer indices than the source vector rank}}
+  %1 = "vector.extract" (%arg0) <{static_position = array<i64: 0>}> : (vector<4x8x16xf32>) -> f32
+}
+
+// -----
+
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error @+1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
   %1 = vector.extract %arg0[0, 43, 0] : f32 from vector<4x8x16xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/186197