[Mlir-commits] [mlir] e9e25f0 - [mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors (#121458)
llvmlistbot at llvm.org
Thu Jun 26 01:47:10 PDT 2025
Author: Andrzej Warzyński
Date: 2025-06-26T09:47:06+01:00
New Revision: e9e25f02e6e10c75224aad646bdd1705f1d9d8b1
URL: https://github.com/llvm/llvm-project/commit/e9e25f02e6e10c75224aad646bdd1705f1d9d8b1
DIFF: https://github.com/llvm/llvm-project/commit/e9e25f02e6e10c75224aad646bdd1705f1d9d8b1.diff
LOG: [mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors (#121458)
This patch enforces a restriction in the Vector dialect: the non-indexed
operand of `vector.insert` (the value to store) and the result of
`vector.extract` can no longer be 0-D vectors. In other words, rank-0
vector types such as `vector<f32>` are disallowed as the `vector.insert`
source or the `vector.extract` result.
EXAMPLES
--------
The following are now **illegal** (note the use of `vector<f32>`):
```mlir
%0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```
Instead, use scalars as the source and result types:
```mlir
%0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```
This change serves three goals, summarised below.
## 1. REDUCED AMBIGUITY
By enforcing scalar-only semantics when the result (`vector.extract`)
or source (`vector.insert`) is rank-0, we eliminate ambiguity
in interpretation. Prior to this patch, both `f32` and `vector<f32>`
were accepted in these positions.
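To make the ambiguity concrete, here is a minimal sketch of the new check in plain C++. It assumes a hypothetical `ToyType` that is either a scalar element type or a vector with a (possibly empty) shape; it does not use the real MLIR `Type`/`VectorType` classes. Only the error text is copied from the diff: scalars and vectors of rank >= 1 are accepted, rank-0 vectors are rejected.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy stand-in for an MLIR type (not mlir::Type): either a scalar
// element type (no shape) or a vector type with a possibly empty shape.
struct ToyType {
  std::string elementType;                    // e.g. "f32"
  std::optional<std::vector<int64_t>> shape;  // std::nullopt => scalar

  bool isVector() const { return shape.has_value(); }
  std::size_t rank() const { return shape ? shape->size() : 0; }
};

// Toy version of the new rule for the non-indexed type of vector.extract
// (result) and vector.insert (value to store): scalars and vectors of rank >= 1
// are accepted, rank-0 vectors are rejected. Returns an error message, or an
// empty string when the type is accepted.
std::string checkNonIndexedType(const ToyType &ty, const std::string &role) {
  assert(!role.empty() && "role should name the operand or result");
  if (ty.isVector() && ty.rank() == 0)
    return "expected a scalar instead of a 0-d vector as the " + role;
  return "";
}
```

With this check in place, `f32` is the only accepted spelling for a rank-0 extract result or insert source, so the two previously equivalent forms collapse into one.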
## 2. MATCH IMPLEMENTATION TO DOCUMENTATION
The previous behaviour contradicted the documented intent. For example,
the documentation of `vector.extract` states:
> Degenerates to an element type if n-k is zero.
This patch enforces that intent in code.
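The documented rule can be sketched as a small type-inference function. `inferExtractResultType` is a hypothetical helper that builds the result type as a string from a static source shape; it is only an illustration and does not reflect how the actual `ExtractOp` code infers its result type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper (not the MLIR ExtractOp code): given the static shape of
// an n-D source vector, its element type and the number k of position indices,
// return the textual result type of vector.extract.
std::string inferExtractResultType(const std::vector<int64_t> &srcShape,
                                   const std::string &elemTy, std::size_t k) {
  assert(k <= srcShape.size() && "more indices than source dimensions");
  // n - k == 0: degenerate to the scalar element type, never vector<elemTy>.
  if (k == srcShape.size())
    return elemTy;
  // Otherwise keep the trailing (n - k) dimensions.
  std::string ty = "vector<";
  for (std::size_t i = k; i < srcShape.size(); ++i)
    ty += std::to_string(srcShape[i]) + "x";
  return ty + elemTy + ">";
}
```

For a `vector<4x8x16xf32>` source, one index yields `vector<8x16xf32>` and three indices yield `f32`, matching the examples in the op documentation.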
## 3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT
With the stricter semantics in place, it’s natural and consistent to
make `vector.insert` behave symmetrically to `vector.extract`, i.e.,
the type of the value to store degenerates to a scalar whenever it
would otherwise be a rank-0 vector.
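The insert side can be sketched with a hypothetical check, `isValidValueToStore`, that compares the value to store against the trailing dimensions of the destination. It is an illustration under simplified assumptions (static shapes only, element types ignored), not the real `InsertOp::verify`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical toy check, not the MLIR InsertOp verifier. The value to store
// is described by its shape: std::nullopt for a scalar, otherwise the shape
// of the vector (an empty shape means a rank-0 vector such as vector<f32>).
bool isValidValueToStore(const std::optional<std::vector<int64_t>> &valShape,
                         const std::vector<int64_t> &destShape, std::size_t k) {
  assert(k <= destShape.size() && "more indices than destination dimensions");
  // Scalars are only valid when the position indexes every dimension (n == k).
  if (!valShape)
    return k == destShape.size();
  // Rank-0 vectors are rejected outright, mirroring vector.extract.
  if (valShape->empty())
    return false;
  // Otherwise the value must match the trailing (n - k) destination dims.
  std::vector<int64_t> trailing(
      destShape.begin() + static_cast<std::ptrdiff_t>(k), destShape.end());
  return *valShape == trailing;
}
```

Inserting into `vector<4x8x16xf32>` at a 3-D position therefore accepts only `f32`, exactly as extracting from it at a 3-D position yields only `f32`.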
NOTES FOR REVIEWERS
-------------------
1. The main change is in "VectorOps.cpp", where stricter type checks are
implemented.
2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to
remove now-illegal examples.
3. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now
require an additional `vector.extract` when a preceding
`vector.transfer_read` generates a rank-0 vector.
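The extra `vector.extract` mentioned in the last note above can be illustrated with a hypothetical emitter, `emitInsertOfRead`, that produces MLIR text for a single unrolled step. It is not the actual `UnrollTransferReadConversion` code; the SSA names `%s` and `%r` are placeholders, and a read is treated as rank-0 when its type is exactly `vector<elemTy>`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical sketch, not the real UnrollTransferReadConversion pattern:
// returns the MLIR text that inserts the result of one unrolled
// vector.transfer_read (`readVal : readTy`) into the accumulator `acc : accTy`
// at position `idx`. `%s` and `%r` are placeholder SSA names.
std::vector<std::string> emitInsertOfRead(const std::string &readVal,
                                          const std::string &readTy,
                                          const std::string &elemTy,
                                          const std::string &acc,
                                          const std::string &accTy, int idx) {
  assert(idx >= 0 && "insertion index must be non-negative");
  std::vector<std::string> ops;
  std::string toInsert = readVal;
  std::string toInsertTy = readTy;
  // A read is treated as rank-0 when its type is exactly vector<elemTy>.
  if (readTy == "vector<" + elemTy + ">") {
    // vector.insert rejects a 0-d value to store, so first unwrap the
    // scalar with an index-free vector.extract.
    ops.push_back("%s = vector.extract " + readVal + "[] : " + elemTy +
                  " from " + readTy);
    toInsert = "%s";
    toInsertTy = elemTy;
  }
  ops.push_back("%r = vector.insert " + toInsert + ", " + acc + "[" +
                std::to_string(idx) + "] : " + toInsertTy + " into " + accTy);
  return ops;
}
```

For a rank-0 read the sketch emits two ops instead of one, which is the cost described above; reads of rank >= 1 are inserted directly.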
RELATED RFC
-----------
* https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d58ee84bee63d..e6b85de5a522a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -703,8 +703,9 @@ def Vector_ExtractOp :
InferTypeOpAdaptorWithIsCompatible]> {
let summary = "extract operation";
let description = [{
- Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
- the proper position. Degenerates to an element type if n-k is zero.
+ Extracts an (n − k)-D result sub-vector from an n-D source vector at a
+ specified k-D position. When n = k, the result degenerates to a scalar
+ element.
Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
@@ -716,7 +717,6 @@ def Vector_ExtractOp :
```mlir
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
- %3 = vector.extract %1[]: vector<f32> from vector<f32>
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
@@ -949,9 +949,9 @@ def Vector_InsertOp :
AllTypesMatch<["dest", "result"]>]> {
let summary = "insert operation";
let description = [{
- Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
- and inserts the n-D source into the (n+k)-D destination at the proper
- position. Degenerates to a scalar or a 0-d vector source type when n = 0.
+ Inserts an (n - k)-D sub-vector (value-to-store) into an n-D destination
+ vector at a specified k-D position. When n = 0, value-to-store degenerates
+ to a scalar element inserted into the n-D destination vector.
Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
@@ -963,8 +963,7 @@ def Vector_InsertOp :
```mlir
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
- %8 = vector.insert %6, %7[] : f32 into vector<f32>
- %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
+ %11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
```
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..002dfebd2b602 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1324,6 +1324,8 @@ struct UnrollTransferReadConversion
for (int64_t i = 0; i < dimSize; ++i) {
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ // FIXME: Rename this lambda - it does much more than just
+ // in-bounds-check generation.
vec = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
/*inBoundsCase=*/
@@ -1338,12 +1340,21 @@ struct UnrollTransferReadConversion
insertionIndices.push_back(rewriter.getIndexAttr(i));
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
+
auto newXferOp = b.create<vector::TransferReadOp>(
loc, newXferVecType, xferOp.getBase(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
xferOp.getPadding(), Value(), inBoundsAttr);
maybeAssignMask(b, xferOp, newXferOp, i);
- return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+ Value valToInser = newXferOp.getResult();
+ if (newXferVecType.getRank() == 0) {
+ // vector.insert does not accept rank-0 as the non-indexed
+ // argument. Extract the scalar before inserting.
+ valToInser = b.create<vector::ExtractOp>(loc, valToInser,
+ SmallVector<int64_t>());
+ }
+ return b.create<vector::InsertOp>(loc, valToInser, vec,
insertionIndices);
},
/*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..a11dbe2589205 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1384,6 +1384,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
+ if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+ if (resTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the result type");
+
// Note: This check must come before getMixedPosition() to prevent a crash.
auto dynamicMarkersCount =
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3211,6 +3216,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult InsertOp::verify() {
+ if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
+ if (srcTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the source operand");
+
SmallVector<OpFoldResult> position = getMixedPosition();
auto destVectorType = getDestVectorType();
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index b6025cee31a6d..5038646e1f026 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
-func.func @extract_0d(%arg0: vector<f32>) {
- // expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
- %1 = vector.extract %arg0[0] : f32 from vector<f32>
+func.func @extract_0d_result(%arg0: vector<f32>) {
+ // expected-error@+1 {{expected a scalar instead of a 0-d vector as the result type}}
+ %1 = vector.extract %arg0[] : vector<f32> from vector<f32>
}
// -----
@@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// -----
-func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
- // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
- %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
-}
-
-// -----
-
-func.func @insert_0d(%a: f32, %b: vector<f32>) {
- // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
- %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+func.func @insert_0d_value_to_store(%a: vector<f32>, %b: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
+ %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 0121bcdbbba45..10bf0f1620568 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -300,12 +300,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
}
// CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
%1 = vector.insert %a, %b[] : f32 into vector<f32>
- // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
- %2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
- return %1, %2 : vector<f32>, vector<2x3xf32>
+ return %1 : vector<f32>
}
// CHECK-LABEL: @insert_poison_idx