[Mlir-commits] [mlir] [mlir][vector] Restrict vector.insert/vector.extract (PR #121458)
Andrzej Warzyński
llvmlistbot at llvm.org
Sat Feb 8 07:24:16 PST 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/121458
>From eb62bf0b0530456319d45e4b6121cf597374629e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 29 Dec 2024 18:52:12 +0000
Subject: [PATCH] [mlir][vector] Restrict vector.insert/vector.extract
This patch restricts the use of vector.insert and vector.extract Ops in
the Vector dialect. Specifically:
* The non-indexed operands of `vector.insert` and `vector.extract`
can no longer be 0-D vectors — use scalars instead.
The following are now illegal. Note that the source and result types
(i.e. non-indexed args) are rank-0 vectors:
```mlir
%0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %arg0[0, 0] : vector<f32> from vector<2x2xf32>
```
Instead, use scalars as the source and result types:
```mlir
%0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```
Put differently, this PR removes the ambiguity in how the non-indexed
operands of `vector.insert` and `vector.extract` may be expressed.
Previously both the scalar and the 0-D-vector forms were accepted;
requiring the scalar form exclusively simplifies the Op semantics.
For more context, see the related RFC:
* https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
---
.../Conversion/VectorToSCF/VectorToSCF.cpp | 40 ++++++++++++++++---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 +++++
mlir/test/Dialect/Vector/invalid.mlir | 4 +-
mlir/test/Dialect/Vector/ops.mlir | 6 +--
4 files changed, 48 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 22bf27d229ce5d8..5dcb4c77c92f381 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1287,6 +1287,10 @@ struct UnrollTransferReadConversion
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
/// accesses, and broadcasts and transposes in permutation maps.
+ ///
+ /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
+ /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
+ /// MemRef and Tensor source, respectively).
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1319,6 +1323,8 @@ struct UnrollTransferReadConversion
for (int64_t i = 0; i < dimSize; ++i) {
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ // FIXME: Rename this lambda - it does much more than just
+ // in-bounds-check generation.
vec = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
/*inBoundsCase=*/
@@ -1333,12 +1339,34 @@ struct UnrollTransferReadConversion
insertionIndices.push_back(rewriter.getIndexAttr(i));
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
- auto newXferOp = b.create<vector::TransferReadOp>(
- loc, newXferVecType, xferOp.getSource(), xferIndices,
- AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
- xferOp.getPadding(), Value(), inBoundsAttr);
- maybeAssignMask(b, xferOp, newXferOp, i);
- return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+ // A value that's read after rank-reducing the original
+ // vector.transfer_read Op.
+ Value unpackedReadRes;
+ if (newXferVecType.getRank() != 0) {
+ // Unpacking Vector that's rank > 2
+ // (use vector.transfer_read to load a rank-reduced vector)
+ unpackedReadRes = b.create<vector::TransferReadOp>(
+ loc, newXferVecType, xferOp.getSource(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+ xferOp.getPadding(), Value(), inBoundsAttr);
+ maybeAssignMask(b, xferOp,
+ dyn_cast<vector::TransferReadOp>(
+ unpackedReadRes.getDefiningOp()),
+ i);
+ } else {
+ // Unpacking Vector that's rank == 1
+ // (use memref.load/tensor.extract to load a scalar)
+ unpackedReadRes =
+ dyn_cast<MemRefType>(xferOp.getSource().getType())
+ ? b.create<memref::LoadOp>(loc, xferOp.getSource(),
+ xferIndices)
+ .getResult()
+ : b.create<tensor::ExtractOp>(loc, xferOp.getSource(),
+ xferIndices)
+ .getResult();
+ }
+ return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
insertionIndices);
},
/*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b4a5461f4405dcf..dce676875247aed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1351,6 +1351,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
+ if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+ if (resTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the result type");
+
// Note: This check must come before getMixedPosition() to prevent a crash.
auto dynamicMarkersCount =
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -2929,6 +2934,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult InsertOp::verify() {
+ if (auto srcTy = dyn_cast<VectorType>(getSourceType()))
+ if (srcTy.getRank() == 0)
+ return emitError(
+ "expected a scalar instead of a 0-d vector as the source operand");
+
SmallVector<OpFoldResult> position = getMixedPosition();
auto destVectorType = getDestVectorType();
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d59912c..1f2e32c500be7fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// -----
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
- %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+ // expected-error at +1 {{expected a scalar instead of a 0-d vector as the source operand}}
+ %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456dca..5ea9acd052cd2a2 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
}
// CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
%1 = vector.insert %a, %b[] : f32 into vector<f32>
- // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
- %2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
- return %1, %2 : vector<f32>, vector<2x3xf32>
+ return %1 : vector<f32>
}
// CHECK-LABEL: @insert_poison_idx
More information about the Mlir-commits
mailing list