[Mlir-commits] [mlir] [mlir][vector] Add verification for incorrect vector.extract (PR #115824)

Diego Caballero llvmlistbot at llvm.org
Sat Nov 16 12:57:36 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/115824

>From 9bf7f2a09cbf14f0e726acf4b6d5cc99739624c2 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Mon, 11 Nov 2024 22:25:45 -0800
Subject: [PATCH 1/3] [mlir][vector] Add verification for incorrect
 vector.extract

This PR fixes the `vector.extract` verifier so that extracting a scalar
requires providing as many indices as there are vector dimensions. I.e.,
the following example is incorrect:

```
  %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
```

A similar check already exists for `vector.insert`.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 ++++++++++++++++++++++--
 mlir/test/Dialect/Vector/invalid.mlir    |  7 +++++++
 mlir/test/Dialect/Vector/ops.mlir        |  7 +++++++
 3 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..9b19b5cd0cd221 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1349,13 +1349,33 @@ LogicalResult vector::ExtractOp::verify() {
         "corresponding dynamic position) -- this can only happen due to an "
         "incorrect fold/rewrite");
   auto position = getMixedPosition();
-  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
+  VectorType srcVecType = getSourceVectorType();
+  int64_t srcRank = srcVecType.getRank();
+  if (position.size() > static_cast<unsigned>(srcRank))
     return emitOpError(
         "expected position attribute of rank no greater than vector rank");
+
+  VectorType dstVecType = dyn_cast<VectorType>(getResult().getType());
+  if (dstVecType) {
+    int64_t srcRankMinusIndices = srcRank - getNumIndices();
+    int64_t dstRank = dstVecType.getRank();
+    if ((srcRankMinusIndices == 0 && dstRank != 1) ||
+        (srcRankMinusIndices != 0 && srcRankMinusIndices != dstRank)) {
+      return emitOpError(
+          "expected source rank minus number of indices to match "
+          "destination vector rank");
+    }
+  } else {
+    // Scalar result.
+    if (srcRank != getNumIndices())
+      return emitOpError("expected source rank to match number of indices "
+                         "for scalar result");
+  }
+
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (pos.is<Attribute>()) {
       int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
-      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+      if (constIdx < 0 || constIdx >= srcVecType.getDimSize(idx)) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..e09b031dd49648 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -192,6 +192,13 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
+  // expected-error at +1 {{expected source rank to match number of indices for scalar result}}
+  %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
+}
+
+// -----
+
 func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
   %c = arith.constant 3 : i32
   // expected-error at +1 {{expected position to be empty with 0-D vector}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..8595b0635cc94a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -240,6 +240,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL: @extract_single_element_vector
+func.func @extract_single_element_vector(%arg0: vector<4x8x3xf32>) -> vector<1xf32> {
+  // CHECK: vector.extract {{.*}}[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
+  %1 = vector.extract %arg0[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
+  return %1 : vector<1xf32>
+}
+
 // CHECK-LABEL: @insert_element_0d
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>

>From 6acae5282485f1cc9236aae1d0ececf77c1db72c Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 12 Nov 2024 15:16:49 -0800
Subject: [PATCH 2/3] Feedback

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 97 +++++++++++++----------
 mlir/test/Dialect/Vector/invalid.mlir    | 98 ++++++++++++++++++++----
 mlir/test/Dialect/Vector/ops.mlir        | 59 ++++++++++----
 3 files changed, 183 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9b19b5cd0cd221..2988d4d8194837 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1339,6 +1339,50 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return l == r;
 }
 
+// Common verification rules for `InsertOp` and `ExtractOp` involving indices.
+// `indexedType` is the vector type being indexed in the operation, i.e., the
+// destination type in InsertOp and the source type in ExtractOp.
+// `vecOrScalarType` is the type that is not indexed in the op and can be
+// either a scalar or a vector, i.e., the source type in InsertOp and the
+// return type in ExtractOp.
+static LogicalResult verifyInsertExtractIndices(Operation *op,
+                                                VectorType indexedType,
+                                                int64_t numIndices,
+                                                Type vecOrScalarType) {
+  int64_t indexedRank = indexedType.getRank();
+  if (numIndices > indexedRank)
+    return op->emitOpError(
+        "expected a number of indices no greater than the indexed vector rank");
+
+  if (auto nonIndexedVecType = dyn_cast<VectorType>(vecOrScalarType)) {
+    // Vector case, including:
+    //  * 0-D vector:
+    //    * vector.extract %src[2]: vector<f32> from vector<8xf32)
+    //    * vector.insert %src, %dst[3]: vector<f32> into vector<8xf32>
+    //  * One-element vector:
+    //    * vector.extract %src[4]: vector<1xf32> from vector<8xf32>
+    //    * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>
+    //    * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>
+    //    * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>
+    int64_t indexedRankMinusIndices = indexedRank - numIndices;
+    int64_t nonIndexedRank = nonIndexedVecType.getRank();
+    bool isOneElemVec =
+        nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1;
+    if (indexedRankMinusIndices != nonIndexedRank &&
+        (!isOneElemVec || indexedRankMinusIndices != 0)) {
+      return op->emitOpError(
+          "expected indexed vector rank minus number of indices to match "
+          "the rank of the non-indexed vector rank");
+    }
+  } else if (indexedRank != numIndices) {
+    // Scalar case.
+    return op->emitOpError("expected indexed vector rank to match the number "
+                           "of indices for scalar cases");
+  }
+
+  return success();
+}
+
 LogicalResult vector::ExtractOp::verify() {
   // Note: This check must come before getMixedPosition() to prevent a crash.
   auto dynamicMarkersCount =
@@ -1348,31 +1392,12 @@ LogicalResult vector::ExtractOp::verify() {
         "mismatch between dynamic and static positions (kDynamic marker but no "
         "corresponding dynamic position) -- this can only happen due to an "
         "incorrect fold/rewrite");
-  auto position = getMixedPosition();
-  VectorType srcVecType = getSourceVectorType();
-  int64_t srcRank = srcVecType.getRank();
-  if (position.size() > static_cast<unsigned>(srcRank))
-    return emitOpError(
-        "expected position attribute of rank no greater than vector rank");
-
-  VectorType dstVecType = dyn_cast<VectorType>(getResult().getType());
-  if (dstVecType) {
-    int64_t srcRankMinusIndices = srcRank - getNumIndices();
-    int64_t dstRank = dstVecType.getRank();
-    if ((srcRankMinusIndices == 0 && dstRank != 1) ||
-        (srcRankMinusIndices != 0 && srcRankMinusIndices != dstRank)) {
-      return emitOpError(
-          "expected source rank minus number of indices to match "
-          "destination vector rank");
-    }
-  } else {
-    // Scalar result.
-    if (srcRank != getNumIndices())
-      return emitOpError("expected source rank to match number of indices "
-                         "for scalar result");
-  }
+  auto srcVecType = getSourceVectorType();
+  if (failed(verifyInsertExtractIndices(*this, srcVecType, getNumIndices(),
+                                        getResult().getType())))
+    return failure();
 
-  for (auto [idx, pos] : llvm::enumerate(position)) {
+  for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
     if (pos.is<Attribute>()) {
       int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
       if (constIdx < 0 || constIdx >= srcVecType.getDimSize(idx)) {
@@ -2881,25 +2906,15 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult InsertOp::verify() {
-  SmallVector<OpFoldResult> position = getMixedPosition();
-  auto destVectorType = getDestVectorType();
-  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
-    return emitOpError(
-        "expected position attribute of rank no greater than dest vector rank");
-  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
-  if (srcVectorType &&
-      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
-       static_cast<unsigned>(destVectorType.getRank())))
-    return emitOpError("expected position attribute rank + source rank to "
-                       "match dest vector rank");
-  if (!srcVectorType &&
-      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
-    return emitOpError(
-        "expected position attribute rank to match the dest vector rank");
-  for (auto [idx, pos] : llvm::enumerate(position)) {
+  auto dstVecType = getDestVectorType();
+  if (failed(verifyInsertExtractIndices(*this, dstVecType, getNumIndices(),
+                                        getSourceType())))
+    return failure();
+
+  for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
     if (auto attr = pos.dyn_cast<Attribute>()) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
-      if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+      if (constIdx < 0 || constIdx >= dstVecType.getDimSize(idx)) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e09b031dd49648..71a36b58308c09 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -151,14 +151,14 @@ func.func @extract_vector_type(%arg0: index) {
 // -----
 
 func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
+  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
   %1 = vector.extract %arg0[0, 0, 0, 0] : f32 from vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
+  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
   %1 = "vector.extract" (%arg0) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
 }
 
@@ -178,22 +178,94 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
-func.func @extract_0d(%arg0: vector<f32>) {
-  // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
+func.func @extract_from_0d_to_scalar_wrong_index(%arg0: vector<f32>) {
+  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
   %1 = vector.extract %arg0[0] : f32 from vector<f32>
 }
 
 // -----
 
+func.func @extract_from_0d_to_0d_wrong_index(%arg0: vector<f32>) {
+  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+  %2 = vector.extract %arg0[0] : vector<f32> from vector<f32>
+}
+
+// -----
+
+func.func @extract_from_0d_to_1d_wrong_index(%arg0: vector<f32>) {
+  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+  %3 = vector.extract %arg0[0] : vector<1xf32> from vector<f32>
+}
+
+// -----
+
+func.func @extract_from_1d_to_scalar_wrong_index(%arg0: vector<1xf32>) {
+  // expected-error at +1 {{expected indexed vector rank to match the number of indices for scalar cases}}
+  %1 = vector.extract %arg0[] : f32 from vector<1xf32>
+}
+
+// -----
+
+func.func @extract_from_1d_to_0d_wrong_index(%arg0: vector<1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+  %2 = vector.extract %arg0[] : vector<f32> from vector<1xf32>
+}
+
+// -----
+
+func.func @extract_from_1d_to_0d(%arg0: vector<1xf32>) {
+  // expected-error at +2 {{'vector.extract' op inferred type(s) 'f32' are incompatible with return type(s) of operation 'vector<f32>'}}
+  // expected-error at +1 {{failed to infer returned types}}
+  %4 = vector.extract %arg0[0] : vector<f32> from vector<1xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_scalar(%arg0: vector<4x1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+  %6 = vector.extract %arg0[2] : f32 from vector<4x1xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_0d(%arg0: vector<4x1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+  %7 = vector.extract %arg0[2] : vector<f32> from vector<4x1xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_scalar_wrong_index(%arg0: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+  %8 = vector.extract %arg0[3] : f32 from vector<4x8xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_0d(%arg0: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+  %9 = vector.extract %arg0[3] : vector<f32> from vector<4x8xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_1d(%arg0: vector<4x8xf32>) {
+  // expected-error at +2 {{'vector.extract' op inferred type(s) 'vector<8xf32>' are incompatible with return type(s) of operation 'vector<1xf32>'}}
+  // expected-error at +1 {{failed to infer returned types}}
+  %10 = vector.extract %arg0[3] : vector<1xf32> from vector<4x8xf32>
+}
+
+// -----
+
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
+  // expected-error at +1 {{'vector.extract' op expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
   %1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
-  // expected-error at +1 {{expected source rank to match number of indices for scalar result}}
+  // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
   %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
 }
 
@@ -201,7 +273,7 @@ func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
 
 func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
   %c = arith.constant 3 : i32
-  // expected-error at +1 {{expected position to be empty with 0-D vector}}
+  // expected-error at +1 {{'vector.insertelement' op expected position to be empty with 0-D vector}}
   %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
 }
 
@@ -209,7 +281,7 @@ func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
 
 func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
   %c = arith.constant 3 : i32
-  // expected-error at +1 {{expected position for 1-D vector}}
+  // expected-error at +1 {{'vector.insertelement' op expected position for 1-D vector}}
   %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
 }
 
@@ -232,21 +304,21 @@ func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
 // -----
 
 func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute of rank no greater than dest vector rank}}
+  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
   %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @insert_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
+  // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
   %1 = vector.insert %a, %b[3] : vector<4xf32> into vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute rank to match the dest vector rank}}
+  // expected-error at +1 {{'vector.insert' op expected indexed vector rank to match the number of indices for scalar cases}}
   %1 = vector.insert %a, %b[3, 3] : f32 into vector<4x8x16xf32>
 }
 
@@ -267,14 +339,14 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 // -----
 
 func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
+  // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
   %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @insert_0d(%a: f32, %b: vector<f32>) {
-  // expected-error at +1 {{expected position attribute of rank no greater than dest vector rank}}
+  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
   %1 = vector.insert %a, %b[0] : f32 into vector<f32>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8595b0635cc94a..8d842afceca206 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -224,7 +224,7 @@ func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
 //  CHECK-SAME:   %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index
 func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
                            -> (vector<8x16xf32>, vector<16xf32>, f32) {
-  // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<8x16xf32> from vector<4x8x16xf32>
+  //      CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<8x16xf32> from vector<4x8x16xf32>
   %0 = vector.extract %arg0[%idx] : vector<8x16xf32> from vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<16xf32> from vector<4x8x16xf32>
   %1 = vector.extract %arg0[%idx, %idx] : vector<16xf32> from vector<4x8x16xf32>
@@ -234,17 +234,24 @@ func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
 }
 
 // CHECK-LABEL: @extract_0d
-func.func @extract_0d(%a: vector<f32>) -> f32 {
-  // CHECK-NEXT: vector.extract %{{.*}}[] : f32 from vector<f32>
-  %0 = vector.extract %a[] : f32 from vector<f32>
-  return %0 : f32
-}
-
-// CHECK-LABEL: @extract_single_element_vector
-func.func @extract_single_element_vector(%arg0: vector<4x8x3xf32>) -> vector<1xf32> {
-  // CHECK: vector.extract {{.*}}[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
-  %1 = vector.extract %arg0[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
-  return %1 : vector<1xf32>
+func.func @extract_0d(%arg0: vector<f32>) -> (f32, vector<1xf32>) {
+  //      CHECK: vector.extract %{{.*}}[] : f32 from vector<f32>
+  %0 = vector.extract %arg0[] : f32 from vector<f32>
+  // CHECK-NEXT: vector.extract %{{.*}}[] : vector<1xf32> from vector<f32>
+  %1 = vector.extract %arg0[] : vector<1xf32> from vector<f32>
+  return %0, %1 : f32, vector<1xf32>
+}
+
+// CHECK-LABEL: @extract_1d
+func.func @extract_1d(%arg0: vector<1xf32>, %arg1: vector<4x1xf32>)
+                      -> (f32, vector<1xf32>, vector<1xf32>) {
+  //      CHECK: vector.extract %{{.*}}[0] : f32 from vector<1xf32>
+  %0 = vector.extract %arg0[0] : f32 from vector<1xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[0] : vector<1xf32> from vector<1xf32>
+  %1 = vector.extract %arg0[0] : vector<1xf32> from vector<1xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[2] : vector<1xf32> from vector<4x1xf32>
+  %2 = vector.extract %arg1[2] : vector<1xf32> from vector<4x1xf32>
+  return %0, %1, %2 : f32, vector<1xf32>, vector<1xf32>
 }
 
 // CHECK-LABEL: @insert_element_0d
@@ -291,12 +298,30 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 }
 
 // CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
-  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
-  %1 = vector.insert %a,  %b[] : f32 into vector<f32>
+func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<1xf32>, %d: vector<2x3xf32>)
+                     -> (vector<f32>, vector<f32>, vector<2x3xf32>) {
+  //      CHECK: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
+  %1 = vector.insert %a, %b[] : f32 into vector<f32>
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : vector<1xf32> into vector<f32>
+  %2 = vector.insert %c, %b[] : vector<1xf32> into vector<f32>
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
-  %2 = vector.insert %b,  %c[0, 1] : vector<f32> into vector<2x3xf32>
-  return %1, %2 : vector<f32>, vector<2x3xf32>
+  %3 = vector.insert %b, %d[0, 1] : vector<f32> into vector<2x3xf32>
+  return %1, %2, %3 : vector<f32>, vector<f32>, vector<2x3xf32>
+}
+
+// CHECK-LABEL: @insert_1d
+func.func @insert_1d(%a: f32, %b: vector<1xf32>, %c: vector<4x1xf32>,
+                     %d: vector<2x3xf32>)
+                     -> (vector<1xf32>, vector<1xf32>, vector<4x1xf32>, vector<2x3xf32>) {
+  //      CHECK: vector.insert %{{.*}}, %{{.*}}[0] : f32 into vector<1xf32>
+  %0 = vector.insert %a, %b[0] : f32 into vector<1xf32>
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0] : vector<1xf32> into vector<1xf32>
+  %1 = vector.insert %0, %b[0] : vector<1xf32> into vector<1xf32>
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[2] : vector<1xf32> into vector<4x1xf32>
+  %2 = vector.insert %b, %c[2] : vector<1xf32> into vector<4x1xf32>
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<1xf32> into vector<2x3xf32>
+  %3 = vector.insert %b, %d[0, 1] : vector<1xf32> into vector<2x3xf32>
+  return %0, %1, %2, %3 : vector<1xf32>, vector<1xf32>, vector<4x1xf32>, vector<2x3xf32>
 }
 
 // CHECK-LABEL: @outerproduct

>From c55e252d57089fc956365372e772268a92dbc571 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 15 Nov 2024 19:25:58 -0800
Subject: [PATCH 3/3] Improvements

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp |  91 +++++++++-----
 mlir/test/Dialect/Vector/invalid.mlir    | 151 +++++++++++++++++------
 mlir/test/Dialect/Vector/ops.mlir        |  20 ++-
 3 files changed, 183 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2988d4d8194837..eb006233201dc6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1339,23 +1339,28 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return l == r;
 }
 
-// Common verification rules for `InsertOp` and `ExtractOp` involving indices.
-// `indexedType` is the vector type being indexed in the operation, i.e., the
-// destination type in InsertOp and the source type in ExtractOp.
-// `vecOrScalarType` is the type that is not indexed in the op and can be
-// either a scalar or a vector, i.e., the source type in InsertOp and the
-// return type in ExtractOp.
-static LogicalResult verifyInsertExtractIndices(Operation *op,
-                                                VectorType indexedType,
-                                                int64_t numIndices,
-                                                Type vecOrScalarType) {
+// Common verification rules for `InsertOp` and `ExtractOp` involving indices
+// and shapes. `indexedType` is the vector type being indexed by the operation,
+// i.e., the destination type in `InsertOp` and the source type in `ExtractOp`.
+// `nonIndexedType` is the type inserted by an `InsertOp` or extracted by an
+// `ExtractOp`, respectively.
+static LogicalResult verifyInsertExtractIndicesAndShapes(Operation *op,
+                                                         VectorType indexedType,
+                                                         int64_t numIndices,
+                                                         Type nonIndexedType) {
+  assert((isa<InsertOp>(op) || isa<ExtractOp>(op)) &&
+         "Expected InsertOp or ExtractOp");
+
+  std::string nonIndexedStr = isa<InsertOp>(op) ? "inserted" : "extracted";
+  std::string indexedStr = isa<InsertOp>(op) ? "destination" : "source";
   int64_t indexedRank = indexedType.getRank();
   if (numIndices > indexedRank)
-    return op->emitOpError(
-        "expected a number of indices no greater than the indexed vector rank");
+    return op->emitOpError()
+           << "expected a number of indices no greater than the " << indexedStr
+           << " vector rank";
 
-  if (auto nonIndexedVecType = dyn_cast<VectorType>(vecOrScalarType)) {
-    // Vector case, including:
+  if (auto nonIndexedVecType = dyn_cast<VectorType>(nonIndexedType)) {
+    // Vector case, including meaningful cases such as:
     //  * 0-D vector:
     //    * vector.extract %src[2]: vector<f32> from vector<8xf32)
     //    * vector.insert %src, %dst[3]: vector<f32> into vector<8xf32>
@@ -1364,20 +1369,46 @@ static LogicalResult verifyInsertExtractIndices(Operation *op,
     //    * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>
     //    * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>
     //    * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>
-    int64_t indexedRankMinusIndices = indexedRank - numIndices;
     int64_t nonIndexedRank = nonIndexedVecType.getRank();
-    bool isOneElemVec =
-        nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1;
+    bool isSingleElem1DNonIndexedVec =
+        (nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1);
+    bool isSingleElem1DIndexedVec =
+        (indexedRank == 1 && indexedType.getDimSize(0) == 1);
+    // Reject 0-D <-> single-element 1-D cases; they are not supported.
+    if ((indexedRank == 0 && isSingleElem1DNonIndexedVec) ||
+        (nonIndexedRank == 0 && isSingleElem1DIndexedVec)) {
+      return op->emitOpError("expected source and destination vectors with "
+                             "different number of elements");
+    }
+
+    // Verify indices for all the cases.
+    int64_t indexedRankMinusIndices = indexedRank - numIndices;
     if (indexedRankMinusIndices != nonIndexedRank &&
-        (!isOneElemVec || indexedRankMinusIndices != 0)) {
-      return op->emitOpError(
-          "expected indexed vector rank minus number of indices to match "
-          "the rank of the non-indexed vector rank");
+        (!isSingleElem1DNonIndexedVec || indexedRankMinusIndices != 0)) {
+      return op->emitOpError()
+             << "expected " << indexedStr
+             << " vector rank minus number of indices to match the rank of the "
+             << nonIndexedStr << " vector";
     }
-  } else if (indexedRank != numIndices) {
-    // Scalar case.
-    return op->emitOpError("expected indexed vector rank to match the number "
-                           "of indices for scalar cases");
+    // Check that if we are inserting or extracting a sub-vector, the
+    // corresponding source and destination shapes match.
+    if (indexedRankMinusIndices > 0) {
+      auto indexedShape = indexedType.getShape();
+      if (indexedShape.drop_front(numIndices) != nonIndexedVecType.getShape())
+        return op->emitOpError() << "expected " << nonIndexedStr
+                                 << " vector shape to match the sub-vector "
+                                    "shape of the "
+                                 << indexedStr << " vector";
+    }
+
+    return success();
+  }
+
+  // Scalar case.
+  if (indexedRank != numIndices) {
+    return op->emitOpError()
+           << "expected " << indexedStr
+           << " vector rank to match the number of indices for scalar cases";
   }
 
   return success();
@@ -1393,9 +1424,10 @@ LogicalResult vector::ExtractOp::verify() {
         "corresponding dynamic position) -- this can only happen due to an "
         "incorrect fold/rewrite");
   auto srcVecType = getSourceVectorType();
-  if (failed(verifyInsertExtractIndices(*this, srcVecType, getNumIndices(),
-                                        getResult().getType())))
+  if (failed(verifyInsertExtractIndicesAndShapes(
+          *this, srcVecType, getNumIndices(), getResult().getType()))) {
     return failure();
+  }
 
   for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
     if (pos.is<Attribute>()) {
@@ -2907,9 +2939,10 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
 
 LogicalResult InsertOp::verify() {
   auto dstVecType = getDestVectorType();
-  if (failed(verifyInsertExtractIndices(*this, dstVecType, getNumIndices(),
-                                        getSourceType())))
+  if (failed(verifyInsertExtractIndicesAndShapes(
+          *this, dstVecType, getNumIndices(), getSourceType()))) {
     return failure();
+  }
 
   for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
     if (auto attr = pos.dyn_cast<Attribute>()) {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 71a36b58308c09..b8c932a93c8a25 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -151,107 +151,112 @@ func.func @extract_vector_type(%arg0: index) {
 // -----
 
 func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+  // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
   %1 = vector.extract %arg0[0, 0, 0, 0] : f32 from vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+  // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
   %1 = "vector.extract" (%arg0) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
 }
 
 // -----
 
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
+  // expected-error at +1 {{'vector.extract' op expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
   %1 = vector.extract %arg0[0, 43, 0] : f32 from vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
+  // expected-error at +1 {{'vector.extract' op expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
   %1 = vector.extract %arg0[3, 7, 16] : f32 from vector<4x8x16xf32>
 }
 
 // -----
 
-func.func @extract_from_0d_to_scalar_wrong_index(%arg0: vector<f32>) {
-  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+func.func @extract_scalar_from_0d_wrong_index(%arg0: vector<f32>) {
+  // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
   %1 = vector.extract %arg0[0] : f32 from vector<f32>
 }
 
 // -----
 
-func.func @extract_from_0d_to_0d_wrong_index(%arg0: vector<f32>) {
-  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+func.func @extract_0d_from_0d_wrong_index(%arg0: vector<f32>) {
+  // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
   %2 = vector.extract %arg0[0] : vector<f32> from vector<f32>
 }
 
 // -----
 
-func.func @extract_from_0d_to_1d_wrong_index(%arg0: vector<f32>) {
-  // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+func.func @extract_1d_from_0d_wrong_index(%arg0: vector<f32>) {
+  // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
   %3 = vector.extract %arg0[0] : vector<1xf32> from vector<f32>
 }
 
 // -----
 
-func.func @extract_from_1d_to_scalar_wrong_index(%arg0: vector<1xf32>) {
-  // expected-error at +1 {{expected indexed vector rank to match the number of indices for scalar cases}}
+func.func @extract_scalar_from_1d_wrong_index(%arg0: vector<1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
   %1 = vector.extract %arg0[] : f32 from vector<1xf32>
 }
 
 // -----
 
-func.func @extract_from_1d_to_0d_wrong_index(%arg0: vector<1xf32>) {
-  // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+func.func @extract_0d_from_1d_wrong_index(%arg0: vector<1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected source and destination vectors with different number of elements}}
   %2 = vector.extract %arg0[] : vector<f32> from vector<1xf32>
 }
 
 // -----
 
-func.func @extract_from_1d_to_0d(%arg0: vector<1xf32>) {
-  // expected-error at +2 {{'vector.extract' op inferred type(s) 'f32' are incompatible with return type(s) of operation 'vector<f32>'}}
-  // expected-error at +1 {{failed to infer returned types}}
+func.func @extract_0d_from_1d(%arg0: vector<1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected source and destination vectors with different number of elements}}
   %4 = vector.extract %arg0[0] : vector<f32> from vector<1xf32>
 }
 
 // -----
 
-func.func @extract_from_2d_to_scalar(%arg0: vector<4x1xf32>) {
-  // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+func.func @extract_1d_from_0d(%arg0: vector<f32>) {
+  // expected-error at +1 {{'vector.extract' op expected source and destination vectors with different number of elements}}
+  %4 = vector.extract %arg0[] : vector<1xf32> from vector<f32>
+}
+
+// -----
+
+func.func @extract_scalar_from_2d(%arg0: vector<4x1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
   %6 = vector.extract %arg0[2] : f32 from vector<4x1xf32>
 }
 
 // -----
 
-func.func @extract_from_2d_to_0d(%arg0: vector<4x1xf32>) {
-  // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+func.func @extract_0d_from_2d(%arg0: vector<4x1xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected source vector rank minus number of indices to match the rank of the extracted vector}}
   %7 = vector.extract %arg0[2] : vector<f32> from vector<4x1xf32>
 }
 
 // -----
 
-func.func @extract_from_2d_to_scalar_wrong_index(%arg0: vector<4x8xf32>) {
-  // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+func.func @extract_scalar_from_2d_wrong_index(%arg0: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
   %8 = vector.extract %arg0[3] : f32 from vector<4x8xf32>
 }
 
 // -----
 
-func.func @extract_from_2d_to_0d(%arg0: vector<4x8xf32>) {
-  // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+func.func @extract_0d_from_2d(%arg0: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected source vector rank minus number of indices to match the rank of the extracted vector}}
   %9 = vector.extract %arg0[3] : vector<f32> from vector<4x8xf32>
 }
 
 // -----
 
-func.func @extract_from_2d_to_1d(%arg0: vector<4x8xf32>) {
-  // expected-error at +2 {{'vector.extract' op inferred type(s) 'vector<8xf32>' are incompatible with return type(s) of operation 'vector<1xf32>'}}
-  // expected-error at +1 {{failed to infer returned types}}
+func.func @extract_1d_from_2d(%arg0: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.extract' op expected extracted vector shape to match the sub-vector shape of the source vector}}
   %10 = vector.extract %arg0[3] : vector<1xf32> from vector<4x8xf32>
 }
 
@@ -265,7 +270,7 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
 // -----
 
 func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
-  // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+  // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
   %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
 }
 
@@ -304,21 +309,21 @@ func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
 // -----
 
 func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
+  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
   %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @insert_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+  // expected-error at +1 {{'vector.insert' op expected destination vector rank minus number of indices to match the rank of the inserted vector}}
   %1 = vector.insert %a, %b[3] : vector<4xf32> into vector<4x8x16xf32>
 }
 
 // -----
 
 func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{'vector.insert' op expected indexed vector rank to match the number of indices for scalar cases}}
+  // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
   %1 = vector.insert %a, %b[3, 3] : f32 into vector<4x8x16xf32>
 }
 
@@ -338,16 +343,86 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 
 // -----
 
-func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
-  %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+func.func @insert_scalar_into_0d_wrong_index(%arg0: f32, %arg1: vector<f32>) {
+  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
+  %1 = vector.insert %arg0, %arg1[0] : f32 into vector<f32>
+}
+
+// -----
+
+func.func @insert_0d_into_0d_wrong_index(%arg0: vector<f32>, %arg1: vector<f32>) {
+  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
+  %2 = vector.insert %arg0, %arg1[0] : vector<f32> into vector<f32>
+}
+
+// -----
+
+func.func @insert_1d_into_0d_wrong_index(%arg0: vector<1xf32>, %arg1: vector<f32>) {
+  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
+  %3 = vector.insert %arg0, %arg1[0] : vector<1xf32> into vector<f32>
+}
+
+// -----
+
+func.func @insert_scalar_into_1d_wrong_index(%arg0: f32, %arg1: vector<1xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
+  %1 = vector.insert %arg0, %arg1[] : f32 into vector<1xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_1d_wrong_index(%arg0: vector<f32>, %arg1: vector<1xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected source and destination vectors with different number of elements}}
+  %2 = vector.insert %arg0, %arg1[] : vector<f32> into vector<1xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_1d(%arg0: vector<f32>, %arg1: vector<1xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected source and destination vectors with different number of elements}}
+  %4 = vector.insert %arg0, %arg1[0] : vector<f32> into vector<1xf32>
+}
+
+// -----
+
+func.func @insert_1d_into_0d(%arg0: vector<1xf32>, %arg1: vector<f32>) {
+  // expected-error at +1 {{'vector.insert' op expected source and destination vectors with different number of elements}}
+  %4 = vector.insert %arg0, %arg1[] : vector<1xf32> into vector<f32>
+}
+
+// -----
+
+func.func @insert_scalar_into_2d(%arg0: f32, %arg1: vector<4x1xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
+  %6 = vector.insert %arg0, %arg1[2] : f32 into vector<4x1xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_2d(%arg0: vector<f32>, %arg1: vector<4x1xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected destination vector rank minus number of indices to match the rank of the inserted vector}}
+  %7 = vector.insert %arg0, %arg1[2] : vector<f32> into vector<4x1xf32>
+}
+
+// -----
+
+func.func @insert_scalar_into_2d_wrong_index(%arg0: f32, %arg1: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
+  %8 = vector.insert %arg0, %arg1[3] : f32 into vector<4x8xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_2d(%arg0: vector<f32>, %arg1: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected destination vector rank minus number of indices to match the rank of the inserted vector}}
+  %9 = vector.insert %arg0, %arg1[3] : vector<f32> into vector<4x8xf32>
 }
 
 // -----
 
-func.func @insert_0d(%a: f32, %b: vector<f32>) {
-  // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
-  %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+func.func @insert_1d_into_2d(%arg0: vector<1xf32>, %arg1: vector<4x8xf32>) {
+  // expected-error at +1 {{'vector.insert' op expected inserted vector shape to match the sub-vector shape of the destination vector}}
+  %10 = vector.insert %arg0, %arg1[3] : vector<1xf32> into vector<4x8xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8d842afceca206..00a433424d6c46 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -234,24 +234,20 @@ func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
 }
 
 // CHECK-LABEL: @extract_0d
-func.func @extract_0d(%arg0: vector<f32>) -> (f32, vector<1xf32>) {
+func.func @extract_0d(%arg0: vector<f32>) -> f32 {
   //      CHECK: vector.extract %{{.*}}[] : f32 from vector<f32>
   %0 = vector.extract %arg0[] : f32 from vector<f32>
-  // CHECK-NEXT: vector.extract %{{.*}}[] : vector<1xf32> from vector<f32>
-  %1 = vector.extract %arg0[] : vector<1xf32> from vector<f32>
-  return %0, %1 : f32, vector<1xf32>
+  return %0 : f32
 }
 
 // CHECK-LABEL: @extract_1d
 func.func @extract_1d(%arg0: vector<1xf32>, %arg1: vector<4x1xf32>)
-                      -> (f32, vector<1xf32>, vector<1xf32>) {
+                      -> (f32, vector<1xf32>) {
   //      CHECK: vector.extract %{{.*}}[0] : f32 from vector<1xf32>
   %0 = vector.extract %arg0[0] : f32 from vector<1xf32>
-  // CHECK-NEXT: vector.extract %{{.*}}[0] : vector<1xf32> from vector<1xf32>
-  %1 = vector.extract %arg0[0] : vector<1xf32> from vector<1xf32>
   // CHECK-NEXT: vector.extract %{{.*}}[2] : vector<1xf32> from vector<4x1xf32>
-  %2 = vector.extract %arg1[2] : vector<1xf32> from vector<4x1xf32>
-  return %0, %1, %2 : f32, vector<1xf32>, vector<1xf32>
+  %1 = vector.extract %arg1[2] : vector<1xf32> from vector<4x1xf32>
+  return %0, %1 : f32, vector<1xf32>
 }
 
 // CHECK-LABEL: @insert_element_0d
@@ -302,8 +298,8 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<1xf32>, %d: vector<2x3
                      -> (vector<f32>, vector<f32>, vector<2x3xf32>) {
   //      CHECK: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
   %1 = vector.insert %a, %b[] : f32 into vector<f32>
-  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : vector<1xf32> into vector<f32>
-  %2 = vector.insert %c, %b[] : vector<1xf32> into vector<f32>
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : vector<f32> into vector<f32>
+  %2 = vector.insert %b, %b[] : vector<f32> into vector<f32>
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
   %3 = vector.insert %b, %d[0, 1] : vector<f32> into vector<2x3xf32>
   return %1, %2, %3 : vector<f32>, vector<f32>, vector<2x3xf32>
@@ -316,7 +312,7 @@ func.func @insert_1d(%a: f32, %b: vector<1xf32>, %c: vector<4x1xf32>,
   //      CHECK: vector.insert %{{.*}}, %{{.*}}[0] : f32 into vector<1xf32>
   %0 = vector.insert %a, %b[0] : f32 into vector<1xf32>
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0] : vector<1xf32> into vector<1xf32>
-  %1 = vector.insert %0, %b[0] : vector<1xf32> into vector<1xf32>
+  %1 = vector.insert %b, %b[0] : vector<1xf32> into vector<1xf32>
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[2] : vector<1xf32> into vector<4x1xf32>
   %2 = vector.insert %b, %c[2] : vector<1xf32> into vector<4x1xf32>
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<1xf32> into vector<2x3xf32>



More information about the Mlir-commits mailing list