[Mlir-commits] [mlir] [mlir][vector] Restrict vector.insert/vector.extract (PR #121458)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Jun 15 03:55:08 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/121458

From efc29a7ab095bfbd2e2b97b76e929e313939c633 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 29 Dec 2024 18:52:12 +0000
Subject: [PATCH] [mlir][vector] Restrict use of 0-D vectors in
 vector.insert/vector.extract
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch enforces a restriction in the Vector dialect: the non-indexed
arguments of `vector.insert` and `vector.extract` (the value to store and
the result, respectively) must no longer be 0-D vectors. In other words,
rank-0 vector types like `vector<f32>` are disallowed as the source of
`vector.insert` and as the result of `vector.extract`.

EXAMPLES
--------
The following are now **illegal** (note the use of `vector<f32>`):

```mlir
%0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```

Instead, use scalars as the source and result types:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```

This change serves three goals:

1. REDUCED AMBIGUITY
--------------------
By enforcing scalar-only semantics when n-k = 0, we eliminate ambiguity
in interpretation. Prior to this patch, both `f32` and `vector<f32>`
were accepted in practice, though only scalars were intended.

2. MATCH IMPLEMENTATION TO DOCUMENTATION
----------------------------------------
The current behavior contradicts the documented intent. For example,
the description of `vector.extract` states:

> Degenerates to an element type if n-k is zero.

This patch enforces that intent in code.

3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT
---------------------------------------------
With the stricter semantics in place, it’s natural and consistent to
make `vector.insert` behave symmetrically to `vector.extract`, i.e.,
degenerate the source type to a scalar when n = 0.
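
As a rough illustration of the rank rules described in points 1-3, the
following standalone sketch models them outside of MLIR. `ToyType`,
`extractResultType` and `insertDestRank` are hypothetical names used only
for this example; they are not part of the MLIR API and are not the
verifier code added by this patch.

```cpp
#include <cassert>
#include <optional>

// Toy stand-in for an MLIR type (hypothetical, not the MLIR API):
// either a scalar element or a vector of the given rank.
struct ToyType {
  bool isVector;
  int rank; // meaningful only when isVector is true
};

// vector.extract: n-D source and k-D position give an (n - k)-D result,
// which degenerates to a scalar (never a 0-D vector) when n == k.
std::optional<ToyType> extractResultType(int n, int k) {
  if (k < 0 || k > n)
    return std::nullopt; // position rank exceeds the vector rank
  if (n == k)
    return ToyType{false, 0};
  return ToyType{true, n - k};
}

// vector.insert: n-D source and k-D position require an (n + k)-D
// destination. Returns that rank, or nullopt for a 0-D vector source.
std::optional<int> insertDestRank(ToyType source, int k) {
  if (source.isVector && source.rank == 0)
    return std::nullopt; // rejected: a scalar must be used instead
  int n = source.isVector ? source.rank : 0;
  return n + k;
}
```

In this model, extracting at a 2-D position from a 2-D vector yields a
scalar, and inserting a 0-D vector is rejected regardless of the position.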

NOTES FOR REVIEWERS
-------------------
1. The main change is in "VectorOps.cpp", where stricter type checks are
   implemented.
2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to
   remove now-illegal examples.
3. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now
   avoid using `vector.transfer_read` for scalar loads and instead rely on
   `memref.load` / `tensor.extract`.

RELATED RFC
-----------
  * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 16 ++++----
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 39 ++++++++++++++++---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 10 +++++
 mlir/test/Dialect/Vector/invalid.mlir         | 19 +++------
 mlir/test/Dialect/Vector/ops.mlir             |  6 +--
 5 files changed, 59 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..cd6b3e7ad82dc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -691,8 +691,9 @@ def Vector_ExtractOp :
      InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "extract operation";
   let description = [{
-    Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
-    the proper position. Degenerates to an element type if n-k is zero.
+    Extracts an (n − k)-D subvector (the result) from an n-D vector at a
+    specified k-D position. When n = k, the result degenerates to a scalar
+    element.
 
     Static and dynamic indices must be greater or equal to zero and less than
     the size of the corresponding dimension. The result is undefined if any
@@ -704,7 +705,6 @@ def Vector_ExtractOp :
     ```mlir
     %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
     %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
-    %3 = vector.extract %1[]: vector<f32> from vector<f32>
     %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
     %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
     %6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
@@ -886,9 +886,10 @@ def Vector_InsertOp :
      AllTypesMatch<["dest", "result"]>]> {
   let summary = "insert operation";
   let description = [{
-    Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
-    and inserts the n-D source into the (n+k)-D destination at the proper
-    position. Degenerates to a scalar or a 0-d vector source type when n = 0.
+    Inserts an n-D source vector (the value to store) into an (n + k)-D
+    destination vector at a specified k-D position. When n = 0, the source
+    degenerates to a scalar element inserted into the (0 + k)-D destination
+    vector.
 
     Static and dynamic indices must be greater or equal to zero and less than
     the size of the corresponding dimension. The result is undefined if any
@@ -900,8 +901,7 @@ def Vector_InsertOp :
     ```mlir
     %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
     %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
-    %8 = vector.insert %6, %7[] : f32 into vector<f32>
-    %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
+    %11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32>
     %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
     %13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
     ```
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..08f398a1c8ba6 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1294,6 +1294,10 @@ struct UnrollTransferReadConversion
 
   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
   /// accesses, and broadcasts and transposes in permutation maps.
+  ///
+  /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
+  /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
+  /// MemRef and Tensor source, respectively).
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1324,6 +1328,8 @@ struct UnrollTransferReadConversion
     for (int64_t i = 0; i < dimSize; ++i) {
       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
 
+      // FIXME: Rename this lambda - it does much more than just
+      // in-bounds-check generation.
       vec = generateInBoundsCheck(
           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
           /*inBoundsCase=*/
@@ -1338,12 +1344,33 @@ struct UnrollTransferReadConversion
             insertionIndices.push_back(rewriter.getIndexAttr(i));
 
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
-            auto newXferOp = b.create<vector::TransferReadOp>(
-                loc, newXferVecType, xferOp.getBase(), xferIndices,
-                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
-                xferOp.getPadding(), Value(), inBoundsAttr);
-            maybeAssignMask(b, xferOp, newXferOp, i);
-            return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+            // A value that's read after rank-reducing the original
+            // vector.transfer_read Op.
+            Value unpackedReadRes;
+            if (newXferVecType.getRank() != 0) {
+              // Unpacking Vector that's rank >= 2
+              // (use vector.transfer_read to load a rank-reduced vector)
+              unpackedReadRes = b.create<vector::TransferReadOp>(
+                  loc, newXferVecType, xferOp.getBase(), xferIndices,
+                  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+                  xferOp.getPadding(), Value(), inBoundsAttr);
+              maybeAssignMask(b, xferOp,
+                              dyn_cast<vector::TransferReadOp>(
+                                  unpackedReadRes.getDefiningOp()),
+                              i);
+            } else {
+              // Unpacking Vector that's rank == 1
+              // (use memref.load/tensor.extract to load a scalar)
+              unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
+                                    ? b.create<memref::LoadOp>(
+                                           loc, xferOp.getBase(), xferIndices)
+                                          .getResult()
+                                    : b.create<tensor::ExtractOp>(
+                                           loc, xferOp.getBase(), xferIndices)
+                                          .getResult();
+            }
+            return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
                                               insertionIndices);
           },
           /*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a2357319bd23..dc4bcd9b6bd84 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
+  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+    if (resTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the result type");
+
   // Note: This check must come before getMixedPosition() to prevent a crash.
   auto dynamicMarkersCount =
       llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult InsertOp::verify() {
+  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
+    if (srcTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the source operand");
+
   SmallVector<OpFoldResult> position = getMixedPosition();
   auto destVectorType = getDestVectorType();
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..a2622c06fa71c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
-func.func @extract_0d(%arg0: vector<f32>) {
-  // expected-error @+1 {{expected position attribute of rank no greater than vector rank}}
-  %1 = vector.extract %arg0[0] : f32 from vector<f32>
+func.func @extract_0d_result(%arg0: vector<f32>) {
+  // expected-error @+1 {{expected a scalar instead of a 0-d vector as the result type}}
+  %1 = vector.extract %arg0[] : vector<f32> from vector<f32>
 }
 
 // -----
@@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 
 // -----
 
-func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
-  // expected-error @+1 {{expected position attribute rank + source rank to match dest vector rank}}
-  %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
-}
-
-// -----
-
-func.func @insert_0d(%a: f32, %b: vector<f32>) {
-  // expected-error @+1 {{expected position attribute of rank no greater than dest vector rank}}
-  %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+func.func @insert_0d_value_to_store(%a: vector<f32>, %b: vector<4x8x16xf32>) {
+  // expected-error @+1 {{expected a scalar instead of a 0-d vector as the source operand}}
+  %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7d43f2a84dc77 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 }
 
 // CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
   %1 = vector.insert %a,  %b[] : f32 into vector<f32>
-  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
-  %2 = vector.insert %b,  %c[0, 1] : vector<f32> into vector<2x3xf32>
-  return %1, %2 : vector<f32>, vector<2x3xf32>
+  return %1 : vector<f32>
 }
 
 // CHECK-LABEL: @insert_poison_idx
