[Mlir-commits] [mlir] [mlir][vector] Add verification for incorrect vector.extract (PR #115824)

Diego Caballero llvmlistbot at llvm.org
Mon Nov 11 22:34:25 PST 2024


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/115824

This PR fixes the `vector.extract` verifier so that extracting a scalar requires providing as many indices as the source vector has dimensions. For example, the following is now rejected as invalid:

```
  %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
```

A similar check already exists for `vector.insert`.

>From 9bf7f2a09cbf14f0e726acf4b6d5cc99739624c2 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Mon, 11 Nov 2024 22:25:45 -0800
Subject: [PATCH] [mlir][vector] Add verification for incorrect vector.extract

This PR fixes the `vector.extract` verifier so that we have to provide
as many indices as vector dimensions to extract a scalar. E.g., the
following example is incorrect:

```
  %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
```

A similar check already exists for `vector.insert`.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 ++++++++++++++++++++++--
 mlir/test/Dialect/Vector/invalid.mlir    |  7 +++++++
 mlir/test/Dialect/Vector/ops.mlir        |  7 +++++++
 3 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..9b19b5cd0cd221 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1349,13 +1349,33 @@ LogicalResult vector::ExtractOp::verify() {
         "corresponding dynamic position) -- this can only happen due to an "
         "incorrect fold/rewrite");
   auto position = getMixedPosition();
-  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
+  VectorType srcVecType = getSourceVectorType();
+  int64_t srcRank = srcVecType.getRank();
+  if (position.size() > static_cast<unsigned>(srcRank))
     return emitOpError(
         "expected position attribute of rank no greater than vector rank");
+
+  VectorType dstVecType = dyn_cast<VectorType>(getResult().getType());
+  if (dstVecType) {
+    int64_t srcRankMinusIndices = srcRank - getNumIndices();
+    int64_t dstRank = dstVecType.getRank();
+    if ((srcRankMinusIndices == 0 && dstRank != 1) ||
+        (srcRankMinusIndices != 0 && srcRankMinusIndices != dstRank)) {
+      return emitOpError(
+          "expected source rank minus number of indices to match "
+          "destination vector rank");
+    }
+  } else {
+    // Scalar result.
+    if (srcRank != getNumIndices())
+      return emitOpError("expected source rank to match number of indices "
+                         "for scalar result");
+  }
+
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (pos.is<Attribute>()) {
       int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
-      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+      if (constIdx < 0 || constIdx >= srcVecType.getDimSize(idx)) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..e09b031dd49648 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -192,6 +192,13 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
+  // expected-error at +1 {{expected source rank to match number of indices for scalar result}}
+  %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
+}
+
+// -----
+
 func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
   %c = arith.constant 3 : i32
   // expected-error at +1 {{expected position to be empty with 0-D vector}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..8595b0635cc94a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -240,6 +240,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL: @extract_single_element_vector
+func.func @extract_single_element_vector(%arg0: vector<4x8x3xf32>) -> vector<1xf32> {
+  // CHECK: vector.extract {{.*}}[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
+  %1 = vector.extract %arg0[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
+  return %1 : vector<1xf32>
+}
+
 // CHECK-LABEL: @insert_element_0d
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>



More information about the Mlir-commits mailing list