[Mlir-commits] [mlir] [mlir][vector] Restrict vector.insert/vector.extract (PR #121458)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jan 2 00:57:41 PST 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/121458

This patch restricts the use of vector.insert and vector.extract Ops in
the Vector dialect. Specifically:
  * The non-indexed operands of `vector.insert` (the source) and
    `vector.extract` (the result) must no longer be 0-D vectors.

The following are now illegal. Note that the source and result types
(i.e. non-indexed args) are rank-0 vectors:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```
Instead, use scalars as the source and result types:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```

Put differently, this PR removes the ambiguity in how the non-indexed
operands of `vector.insert` and `vector.extract` are expressed. Only the
scalar form is now accepted; the 0-D vector form is rejected by the
verifier, which simplifies the semantics.

For more context, see the related RFC:
  * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect


>From 0962edb6ae0d93d253195cacf042961d3fafbcbe Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 29 Dec 2024 18:52:12 +0000
Subject: [PATCH] [mlir][vector] Restrict vector.insert/vector.extract

This patch restricts the use of vector.insert and vector.extract Ops in
the Vector dialect. Specifically:
  * The non-indexed operands of `vector.insert` (the source) and
    `vector.extract` (the result) must no longer be 0-D vectors.

The following are now illegal. Note that the source and result types
(i.e. non-indexed args) are rank-0 vectors:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```
Instead, use scalars as the source and result types:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```

Put differently, this PR removes the ambiguity in how the non-indexed
operands of `vector.insert` and `vector.extract` are expressed. Only the
scalar form is now accepted; the 0-D vector form is rejected by the
verifier, which simplifies the semantics.

For more context, see the related RFC:
  * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
---
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 22 ++++++++++++++-----
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 10 +++++++++
 mlir/test/Dialect/Vector/invalid.mlir         |  4 ++--
 mlir/test/Dialect/Vector/ops.mlir             |  6 ++---
 4 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 01bc65c841e94c..fb2b2ba5e8e047 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1333,12 +1333,22 @@ struct UnrollTransferReadConversion
             insertionIndices.push_back(rewriter.getIndexAttr(i));
 
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
-            auto newXferOp = b.create<vector::TransferReadOp>(
-                loc, newXferVecType, xferOp.getSource(), xferIndices,
-                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
-                xferOp.getPadding(), Value(), inBoundsAttr);
-            maybeAssignMask(b, xferOp, newXferOp, i);
-            return b.create<vector::InsertOp>(loc, newXferOp, vec,
+            Value newXferOrLoadOp;
+            if (newXferVecType.getRank() != 0) {
+              newXferOrLoadOp = b.create<vector::TransferReadOp>(
+                  loc, newXferVecType, xferOp.getSource(), xferIndices,
+                  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+                  xferOp.getPadding(), Value(), inBoundsAttr);
+              maybeAssignMask(b, xferOp,
+                              dyn_cast<vector::TransferReadOp>(
+                                  newXferOrLoadOp.getDefiningOp()),
+                              i);
+            } else {
+              // TODO: Generalize so that this also works for Tensors.
+              newXferOrLoadOp = b.create<memref::LoadOp>(
+                  loc, xferOp.getSource(), xferIndices);
+            }
+            return b.create<vector::InsertOp>(loc, newXferOrLoadOp, vec,
                                               insertionIndices);
           },
           /*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..14c688250b3913 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,6 +1340,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
+  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+    if (resTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the result type");
+
   // Note: This check must come before getMixedPosition() to prevent a crash.
   auto dynamicMarkersCount =
       llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -2864,6 +2869,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult InsertOp::verify() {
+  if (auto srcTy = dyn_cast<VectorType>(getSourceType()))
+    if (srcTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the source operand");
+
   SmallVector<OpFoldResult> position = getMixedPosition();
   auto destVectorType = getDestVectorType();
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1a70791fae1257..7c8edb3f3eeccb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 // -----
 
 func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
-  %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+  // expected-error at +1 {{expected a scalar instead of a 0-d vector as the source operand}}
+  %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 961f1b5ffeabec..7eb77ec54ffcd4 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -284,12 +284,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 }
 
 // CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
   %1 = vector.insert %a,  %b[] : f32 into vector<f32>
-  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
-  %2 = vector.insert %b,  %c[0, 1] : vector<f32> into vector<2x3xf32>
-  return %1, %2 : vector<f32>, vector<2x3xf32>
+  return %1 : vector<f32>
 }
 
 // CHECK-LABEL: @outerproduct



More information about the Mlir-commits mailing list