[Mlir-commits] [mlir] [mlir][vector] Restrict vector.insert/vector.extract (PR #121458)
Andrzej Warzyński
llvmlistbot at llvm.org
Sun Jun 15 03:55:08 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/121458
>From efc29a7ab095bfbd2e2b97b76e929e313939c633 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 29 Dec 2024 18:52:12 +0000
Subject: [PATCH] [mlir][vector] Restrict use of 0-D vectors in
vector.insert/vector.extract
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
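This patch enforces a restriction in the Vector dialect: the value
inserted by `vector.insert` and the value produced by `vector.extract`
can no longer be 0-D vectors. In other words, rank-0 vector types such as
`vector<f32>` are disallowed as the source of `vector.insert` and as the
result of `vector.extract`. 0-D vectors remain valid as the destination
of `vector.insert`, e.g. `vector.insert %a, %b[] : f32 into vector<f32>`.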
EXAMPLES
--------
The following are now **illegal** (note the use of `vector<f32>`):
```mlir
%0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```
Instead, use scalars as the source and result types:
```mlir
%0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```
This change serves three goals:
1. REDUCED AMBIGUITY
--------------------
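By enforcing scalar-only semantics whenever the inserted or extracted
value would otherwise be rank-0 (n = 0 for `vector.insert`, n - k = 0 for
`vector.extract`), we eliminate ambiguity in interpretation. Prior to this
patch, both `f32` and `vector<f32>` were accepted in practice, though only
scalars were intended.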
2. MATCH IMPLEMENTATION TO DOCUMENTATION
----------------------------------------
The current behavior contradicts the documented intent. For example, the
documentation of `vector.extract` states:
> Degenerates to an element type if n-k is zero.
This patch enforces that intent in code.
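The documented rule can be sketched as a small shape computation. This is a hypothetical, simplified model rather than MLIR's actual type inference: `extractResultShape` is an illustrative name, shapes are plain `std::vector`s, and `std::nullopt` stands for "the result is a scalar element".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not an MLIR API): computes the result shape of
// vector.extract from the shape of the n-D source vector and the number k
// of position indices. Returns std::nullopt when n == k, i.e. the result is
// a scalar element, never a 0-D vector such as vector<f32>.
std::optional<std::vector<std::int64_t>>
extractResultShape(const std::vector<std::int64_t> &srcShape,
                   std::size_t numIndices) {
  // The verifier rejects positions longer than the source rank.
  assert(numIndices <= srcShape.size() && "position rank exceeds vector rank");
  if (numIndices == srcShape.size())
    return std::nullopt; // n - k == 0: degenerate to the element type
  // n - k > 0: drop the k leading (indexed) dimensions.
  return std::vector<std::int64_t>(
      srcShape.begin() + static_cast<std::ptrdiff_t>(numIndices),
      srcShape.end());
}
```

For the `vector<4x8x16xf32>` source used in the op documentation, one index yields `8x16`, two yield `16`, and three yield a scalar (`f32`).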
3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT
---------------------------------------------
With the stricter semantics in place, it’s natural and consistent to
make `vector.insert` behave symmetrically to `vector.extract`, i.e.,
degenerate the source type to a scalar when n = 0.
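The symmetric rule for `vector.insert` can be sketched as a rank-only check. This is a hypothetical model loosely based on `InsertOp::verify()`: the new 0-D check runs first, followed by simplified position-rank checks. `InsertSource`, `InsertCheck` and `verifyInsertRanks` are illustrative names, and shapes and element types are ignored.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, rank-only model of the value stored by vector.insert: either
// a scalar or a vector of the given rank (rank 0 == a 0-D vector such as
// vector<f32>). Shapes and element types are deliberately ignored.
struct InsertSource {
  bool isVector;
  std::size_t rank; // only meaningful when isVector is true
};

// Possible outcomes of the rank checks (names are illustrative).
enum class InsertCheck { Ok, ZeroDSource, PositionTooLong, RankMismatch };

// Simplified: the 0-D check comes first, then the position-rank checks.
InsertCheck verifyInsertRanks(InsertSource src, std::size_t destRank,
                              std::size_t numIndices) {
  if (src.isVector && src.rank == 0)
    return InsertCheck::ZeroDSource; // new: 0-D sources are rejected
  if (numIndices > destRank)
    return InsertCheck::PositionTooLong;
  // n + k must equal the destination rank; a scalar source means n = 0.
  std::size_t n = src.isVector ? src.rank : 0;
  if (n + numIndices != destRank)
    return InsertCheck::RankMismatch;
  return InsertCheck::Ok;
}
```

Without the first check, a 0-D source with three indices into a rank-3 destination passes the rank test (0 + 3 == 3). That is why the removed documentation example `vector<f32> into vector<4x8x16xf32>` used to verify.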
NOTES FOR REVIEWERS
-------------------
1. The main change is in "VectorOps.cpp", where the stricter type checks
are implemented.
2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups that
remove now-illegal examples and exercise the new verifier errors.
3. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now
avoid using `vector.transfer_read` for scalar loads and instead rely on
`memref.load` / `tensor.extract`.
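The rank-0 step of that lowering can be pictured with a plain C++ analogue. This is a hypothetical sketch, not MLIR code: `buffer` stands in for the memref/tensor, `padding` for the `vector.transfer_read` padding value, masks and permutation maps are ignored, and `unrolledRead1D` is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical analogue of unrolling a 1-D transfer_read into per-element
// scalar reads once the unpacked rank reaches 0.
std::vector<float> unrolledRead1D(const std::vector<float> &buffer,
                                  std::size_t base, std::size_t vecSize,
                                  float padding) {
  // Start from a vector filled with the padding value.
  std::vector<float> result(vecSize, padding);
  for (std::size_t i = 0; i < vecSize; ++i) {
    std::size_t idx = base + i;
    if (idx < buffer.size())   // in-bounds check
      result[i] = buffer[idx]; // scalar load (cf. memref.load / tensor.extract)
    // Out-of-bounds lanes keep the padding value.
  }
  return result;
}
```

Each in-bounds lane is a scalar read followed by an insert into the result vector, so no 0-D `vector.transfer_read` (and hence no 0-D `vector.insert` source) is needed.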
RELATED RFC
-----------
* https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 16 ++++----
.../Conversion/VectorToSCF/VectorToSCF.cpp | 39 ++++++++++++++++---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 +++++
mlir/test/Dialect/Vector/invalid.mlir | 19 +++------
mlir/test/Dialect/Vector/ops.mlir | 6 +--
5 files changed, 59 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..cd6b3e7ad82dc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -691,8 +691,9 @@ def Vector_ExtractOp :
InferTypeOpAdaptorWithIsCompatible]> {
let summary = "extract operation";
let description = [{
- Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
- the proper position. Degenerates to an element type if n-k is zero.
+ Extracts an (n − k)-D subvector (the result) from an n-D vector at a
+ specified k-D position. When n = k, the result degenerates to a scalar
+ element.
Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
@@ -704,7 +705,6 @@ def Vector_ExtractOp :
```mlir
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
- %3 = vector.extract %1[]: vector<f32> from vector<f32>
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
@@ -886,9 +886,10 @@ def Vector_InsertOp :
AllTypesMatch<["dest", "result"]>]> {
let summary = "insert operation";
let description = [{
- Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
- and inserts the n-D source into the (n+k)-D destination at the proper
- position. Degenerates to a scalar or a 0-d vector source type when n = 0.
+ Inserts an n-D source vector (the value to store) into an (n + k)-D
+ destination vector at a specified k-D position. When n = 0, the source
+ degenerates to a scalar element inserted into the (0 + k)-D destination
+ vector.
Static and dynamic indices must be greater or equal to zero and less than
the size of the corresponding dimension. The result is undefined if any
@@ -900,8 +901,7 @@ def Vector_InsertOp :
```mlir
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
- %8 = vector.insert %6, %7[] : f32 into vector<f32>
- %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
+ %11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
```
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..08f398a1c8ba6 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1294,6 +1294,10 @@ struct UnrollTransferReadConversion
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
/// accesses, and broadcasts and transposes in permutation maps.
+ ///
+ /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
+ /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
+ /// MemRef and Tensor source, respectively).
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1324,6 +1328,8 @@ struct UnrollTransferReadConversion
for (int64_t i = 0; i < dimSize; ++i) {
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ // FIXME: Rename this lambda - it does much more than just
+ // in-bounds-check generation.
vec = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
/*inBoundsCase=*/
@@ -1338,12 +1344,33 @@ struct UnrollTransferReadConversion
insertionIndices.push_back(rewriter.getIndexAttr(i));
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
- auto newXferOp = b.create<vector::TransferReadOp>(
- loc, newXferVecType, xferOp.getBase(), xferIndices,
- AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
- xferOp.getPadding(), Value(), inBoundsAttr);
- maybeAssignMask(b, xferOp, newXferOp, i);
- return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+ // A value that's read after rank-reducing the original
+ // vector.transfer_read Op.
+ Value unpackedReadRes;
+ if (newXferVecType.getRank() != 0) {
+ // Unpacking a vector of rank > 1
+ // (use vector.transfer_read to load a rank-reduced vector)
+ unpackedReadRes = b.create<vector::TransferReadOp>(
+ loc, newXferVecType, xferOp.getBase(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+ xferOp.getPadding(), Value(), inBoundsAttr);
+ maybeAssignMask(b, xferOp,
+ dyn_cast<vector::TransferReadOp>(
+ unpackedReadRes.getDefiningOp()),
+ i);
+ } else {
+ // Unpacking a vector of rank 1
+ // (use memref.load/tensor.extract to load a scalar)
+ unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
+ ? b.create<memref::LoadOp>(
+ loc, xferOp.getBase(), xferIndices)
+ .getResult()
+ : b.create<tensor::ExtractOp>(
+ loc, xferOp.getBase(), xferIndices)
+ .getResult();
+ }
+ return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
insertionIndices);
},
/*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a2357319bd23..dc4bcd9b6bd84 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
+ if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+ if (resTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the result type");
+
// Note: This check must come before getMixedPosition() to prevent a crash.
auto dynamicMarkersCount =
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult InsertOp::verify() {
+ if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
+ if (srcTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the source operand");
+
SmallVector<OpFoldResult> position = getMixedPosition();
auto destVectorType = getDestVectorType();
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..a2622c06fa71c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
-func.func @extract_0d(%arg0: vector<f32>) {
- // expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
- %1 = vector.extract %arg0[0] : f32 from vector<f32>
+func.func @extract_0d_result(%arg0: vector<f32>) {
+ // expected-error@+1 {{expected a scalar instead of a 0-d vector as the result type}}
+ %1 = vector.extract %arg0[] : vector<f32> from vector<f32>
}
// -----
@@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// -----
-func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
- // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
- %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
-}
-
-// -----
-
-func.func @insert_0d(%a: f32, %b: vector<f32>) {
- // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
- %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+func.func @insert_0d_value_to_store(%a: vector<f32>, %b: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
+ %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7d43f2a84dc77 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
}
// CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
%1 = vector.insert %a, %b[] : f32 into vector<f32>
- // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
- %2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
- return %1, %2 : vector<f32>, vector<2x3xf32>
+ return %1 : vector<f32>
}
// CHECK-LABEL: @insert_poison_idx