[Mlir-commits] [mlir] e9e25f0 - [mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors (#121458)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 26 01:47:10 PDT 2025


Author: Andrzej Warzyński
Date: 2025-06-26T09:47:06+01:00
New Revision: e9e25f02e6e10c75224aad646bdd1705f1d9d8b1

URL: https://github.com/llvm/llvm-project/commit/e9e25f02e6e10c75224aad646bdd1705f1d9d8b1
DIFF: https://github.com/llvm/llvm-project/commit/e9e25f02e6e10c75224aad646bdd1705f1d9d8b1.diff

LOG: [mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors (#121458)

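This patch enforces a restriction in the Vector dialect: the non-indexed
arguments of `vector.insert` and `vector.extract` (the value-to-store and
the result, respectively) may no longer be 0-D vectors. In other words,
rank-0 vector types like `vector<f32>` are disallowed as the value-to-store
of `vector.insert` and as the result of `vector.extract`. The indexed vector
(the destination of `vector.insert` or the source of `vector.extract`) may
still be 0-D.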

EXAMPLES
--------
The following are now **illegal** (note the use of `vector<f32>`):

```mlir
%0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```

Instead, use scalars as the source and result types:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```

This change serves three goals, summarised below.

## 1. REDUCED AMBIGUITY
By enforcing scalar-only semantics when the result (`vector.extract`)
or the value-to-store (`vector.insert`) is rank-0, we eliminate ambiguity
in interpretation. Prior to this patch, both `f32` and `vector<f32>`
were accepted in these positions.

## 2. MATCH IMPLEMENTATION TO DOCUMENTATION
The previous behaviour contradicted the documented intent. For example,
the `vector.extract` documentation states:

> Degenerates to an element type if n-k is zero.

This patch enforces that intent in code.

## 3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT
With the stricter semantics in place, it’s natural and consistent to
make `vector.insert` behave symmetrically to `vector.extract`, i.e., the
value-to-store degenerates to a scalar when the position indexes every
dimension of the destination vector.

NOTES FOR REVIEWERS
-------------------
1. The main change is in "VectorOps.cpp", where stricter type checks are
   implemented.
2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to
   remove now-illegal examples.
3. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now
   require an additional `vector.extract` when a preceding
   `vector.transfer_read` generates a rank-0 vector.

RELATED RFC
-----------
* https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d58ee84bee63d..e6b85de5a522a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -703,8 +703,9 @@ def Vector_ExtractOp :
      InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "extract operation";
   let description = [{
-    Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
-    the proper position. Degenerates to an element type if n-k is zero.
+    Extracts an (n − k)-D result sub-vector from an n-D source vector at a
+    specified k-D position. When n = k, the result degenerates to a scalar
+    element.
 
     Static and dynamic indices must be greater or equal to zero and less than
     the size of the corresponding dimension. The result is undefined if any
@@ -716,7 +717,6 @@ def Vector_ExtractOp :
     ```mlir
     %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
     %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
-    %3 = vector.extract %1[]: vector<f32> from vector<f32>
     %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
     %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
     %6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
@@ -949,9 +949,9 @@ def Vector_InsertOp :
      AllTypesMatch<["dest", "result"]>]> {
   let summary = "insert operation";
   let description = [{
-    Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
-    and inserts the n-D source into the (n+k)-D destination at the proper
-    position. Degenerates to a scalar or a 0-d vector source type when n = 0.
+    Inserts an (n - k)-D sub-vector (value-to-store) into an n-D destination
+    vector at a specified k-D position. When n = 0, value-to-store degenerates
+    to a scalar element inserted into the n-D destination vector.
 
     Static and dynamic indices must be greater or equal to zero and less than
     the size of the corresponding dimension. The result is undefined if any
@@ -963,8 +963,7 @@ def Vector_InsertOp :
     ```mlir
     %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
     %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
-    %8 = vector.insert %6, %7[] : f32 into vector<f32>
-    %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
+    %11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32>
     %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
     %13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
     ```

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..002dfebd2b602 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1324,6 +1324,8 @@ struct UnrollTransferReadConversion
     for (int64_t i = 0; i < dimSize; ++i) {
       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
 
+      // FIXME: Rename this lambda - it does much more than just
+      // in-bounds-check generation.
       vec = generateInBoundsCheck(
           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
           /*inBoundsCase=*/
@@ -1338,12 +1340,21 @@ struct UnrollTransferReadConversion
             insertionIndices.push_back(rewriter.getIndexAttr(i));
 
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
+
             auto newXferOp = b.create<vector::TransferReadOp>(
                 loc, newXferVecType, xferOp.getBase(), xferIndices,
                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                 xferOp.getPadding(), Value(), inBoundsAttr);
             maybeAssignMask(b, xferOp, newXferOp, i);
-            return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+            Value valToInser = newXferOp.getResult();
+            if (newXferVecType.getRank() == 0) {
+              // vector.insert does not accept rank-0 as the non-indexed
+              // argument. Extract the scalar before inserting.
+              valToInser = b.create<vector::ExtractOp>(loc, valToInser,
+                                                       SmallVector<int64_t>());
+            }
+            return b.create<vector::InsertOp>(loc, valToInser, vec,
                                               insertionIndices);
           },
           /*outOfBoundsCase=*/

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..a11dbe2589205 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1384,6 +1384,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
+  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+    if (resTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the result type");
+
   // Note: This check must come before getMixedPosition() to prevent a crash.
   auto dynamicMarkersCount =
       llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3211,6 +3216,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult InsertOp::verify() {
+  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
+    if (srcTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the source operand");
+
   SmallVector<OpFoldResult> position = getMixedPosition();
   auto destVectorType = getDestVectorType();
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index b6025cee31a6d..5038646e1f026 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
-func.func @extract_0d(%arg0: vector<f32>) {
-  // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
-  %1 = vector.extract %arg0[0] : f32 from vector<f32>
+func.func @extract_0d_result(%arg0: vector<f32>) {
+  // expected-error at +1 {{expected a scalar instead of a 0-d vector as the result type}}
+  %1 = vector.extract %arg0[] : vector<f32> from vector<f32>
 }
 
 // -----
@@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 
 // -----
 
-func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
-  %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
-}
-
-// -----
-
-func.func @insert_0d(%a: f32, %b: vector<f32>) {
-  // expected-error at +1 {{expected position attribute of rank no greater than dest vector rank}}
-  %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+func.func @insert_0d_value_to_store(%a: vector<f32>, %b: vector<4x8x16xf32>) {
+  // expected-error at +1 {{expected a scalar instead of a 0-d vector as the source operand}}
+  %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 0121bcdbbba45..10bf0f1620568 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -300,12 +300,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 }
 
 // CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
   %1 = vector.insert %a,  %b[] : f32 into vector<f32>
-  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
-  %2 = vector.insert %b,  %c[0, 1] : vector<f32> into vector<2x3xf32>
-  return %1, %2 : vector<f32>, vector<2x3xf32>
+  return %1 : vector<f32>
 }
 
 // CHECK-LABEL: @insert_poison_idx


        


More information about the Mlir-commits mailing list