[Mlir-commits] [mlir] [mlir][vector] Restrict vector.insert/vector.extract (PR #121458)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jan 6 01:01:47 PST 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/121458

>From 473bd0f2a5edbf753b33edf5589d74378c67f31a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 29 Dec 2024 18:52:12 +0000
Subject: [PATCH] [mlir][vector] Restrict vector.insert/vector.extract

This patch restricts the use of vector.insert and vector.extract Ops in
the Vector dialect. Specifically:
  * The non-indexed operands of `vector.insert` (the source) and
    `vector.extract` (the result) must now be either scalars or vectors
    of rank >= 1, i.e. they can no longer be 0-D vectors.

The following are now illegal, because their non-indexed arguments (the
`vector.insert` source and the `vector.extract` result) are 0-D vectors:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```
Instead, use scalars as the source and result types:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```

Put differently, this PR removes an ambiguity in the non-indexed operands
of `vector.insert` and `vector.extract`. Previously, a single element could
be represented either as a scalar or as a 0-D vector. Requiring the scalar
form removes the redundant alternative and simplifies the semantics.

For more context, see the related RFC:
  * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
---
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 38 ++++++++++++++++---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 10 +++++
 mlir/test/Dialect/Vector/invalid.mlir         |  4 +-
 mlir/test/Dialect/Vector/ops.mlir             |  6 +--
 4 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 01bc65c841e94c..24501305ee8010 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1287,6 +1287,10 @@ struct UnrollTransferReadConversion
 
   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
   /// accesses, and broadcasts and transposes in permutation maps.
+  ///
+  /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
+  /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
+  /// MemRef and Tensor source, respectively).
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1319,6 +1323,8 @@ struct UnrollTransferReadConversion
     for (int64_t i = 0; i < dimSize; ++i) {
       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
 
+      // FIXME: Rename this function - it does much more than just
+      // in-bounds-check generation.
       vec = generateInBoundsCheck(
           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
           /*inBoundsCase=*/
@@ -1333,12 +1339,32 @@ struct UnrollTransferReadConversion
             insertionIndices.push_back(rewriter.getIndexAttr(i));
 
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
-            auto newXferOp = b.create<vector::TransferReadOp>(
-                loc, newXferVecType, xferOp.getSource(), xferIndices,
-                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
-                xferOp.getPadding(), Value(), inBoundsAttr);
-            maybeAssignMask(b, xferOp, newXferOp, i);
-            return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+            // A value that's read after rank-reducing the original
+            // vector.transfer_read Op.
+            Value unpackedReadRes;
+            if (newXferVecType.getRank() != 0) {
+              // Unpacking a vector of rank >= 2
+              // (use vector.transfer_read to load a rank-reduced vector)
+              unpackedReadRes = b.create<vector::TransferReadOp>(
+                  loc, newXferVecType, xferOp.getSource(), xferIndices,
+                  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+                  xferOp.getPadding(), Value(), inBoundsAttr);
+              maybeAssignMask(b, xferOp,
+                              dyn_cast<vector::TransferReadOp>(
+                                  unpackedReadRes.getDefiningOp()),
+                              i);
+            } else {
+              // Unpacking Vector that's rank == 1
+              // (use memref.load/tensor.extract to load a scalar)
+              unpackedReadRes =
+                  dyn_cast<MemRefType>(xferOp.getSource().getType())
+                      ? b.create<memref::LoadOp>(loc, xferOp.getSource(),
+                                                 xferIndices).getResult()
+                      : b.create<tensor::ExtractOp>(loc, xferOp.getSource(),
+                                                 xferIndices).getResult();
+            }
+            return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
                                               insertionIndices);
           },
           /*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..14c688250b3913 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,6 +1340,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
+  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+    if (resTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the result type");
+
   // Note: This check must come before getMixedPosition() to prevent a crash.
   auto dynamicMarkersCount =
       llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -2864,6 +2869,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult InsertOp::verify() {
+  if (auto srcTy = dyn_cast<VectorType>(getSourceType()))
+    if (srcTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the source operand");
+
   SmallVector<OpFoldResult> position = getMixedPosition();
   auto destVectorType = getDestVectorType();
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1a70791fae1257..7c8edb3f3eeccb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 // -----
 
 func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
-  %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+  // expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
+  %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 961f1b5ffeabec..7eb77ec54ffcd4 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -284,12 +284,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 }
 
 // CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
   %1 = vector.insert %a,  %b[] : f32 into vector<f32>
-  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
-  %2 = vector.insert %b,  %c[0, 1] : vector<f32> into vector<2x3xf32>
-  return %1, %2 : vector<f32>, vector<2x3xf32>
+  return %1 : vector<f32>
 }
 
 // CHECK-LABEL: @outerproduct



More information about the Mlir-commits mailing list