[Mlir-commits] [mlir] [mlir][vector] Add verification for incorrect vector.extract (PR #115824)
Diego Caballero
llvmlistbot at llvm.org
Sat Nov 16 12:57:36 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/115824
>From 9bf7f2a09cbf14f0e726acf4b6d5cc99739624c2 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Mon, 11 Nov 2024 22:25:45 -0800
Subject: [PATCH 1/3] [mlir][vector] Add verification for incorrect
vector.extract
This PR fixes the `vector.extract` verifier so that extracting a scalar
requires as many indices as the source vector has dimensions. For example,
the following is invalid:
```
%1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
```
A similar check already exists for `vector.insert`.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 ++++++++++++++++++++++--
mlir/test/Dialect/Vector/invalid.mlir | 7 +++++++
mlir/test/Dialect/Vector/ops.mlir | 7 +++++++
3 files changed, 36 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..9b19b5cd0cd221 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1349,13 +1349,33 @@ LogicalResult vector::ExtractOp::verify() {
"corresponding dynamic position) -- this can only happen due to an "
"incorrect fold/rewrite");
auto position = getMixedPosition();
- if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
+ VectorType srcVecType = getSourceVectorType();
+ int64_t srcRank = srcVecType.getRank();
+ if (position.size() > static_cast<unsigned>(srcRank))
return emitOpError(
"expected position attribute of rank no greater than vector rank");
+
+ VectorType dstVecType = dyn_cast<VectorType>(getResult().getType());
+ if (dstVecType) {
+ int64_t srcRankMinusIndices = srcRank - getNumIndices();
+ int64_t dstRank = dstVecType.getRank();
+ if ((srcRankMinusIndices == 0 && dstRank != 1) ||
+ (srcRankMinusIndices != 0 && srcRankMinusIndices != dstRank)) {
+ return emitOpError(
+ "expected source rank minus number of indices to match "
+ "destination vector rank");
+ }
+ } else {
+ // Scalar result.
+ if (srcRank != getNumIndices())
+ return emitOpError("expected source rank to match number of indices "
+ "for scalar result");
+ }
+
for (auto [idx, pos] : llvm::enumerate(position)) {
if (pos.is<Attribute>()) {
int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
- if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+ if (constIdx < 0 || constIdx >= srcVecType.getDimSize(idx)) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..e09b031dd49648 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -192,6 +192,13 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
+func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
+ // expected-error at +1 {{expected source rank to match number of indices for scalar result}}
+ %1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
+}
+
+// -----
+
func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
%c = arith.constant 3 : i32
// expected-error at +1 {{expected position to be empty with 0-D vector}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..8595b0635cc94a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -240,6 +240,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
return %0 : f32
}
+// CHECK-LABEL: @extract_single_element_vector
+func.func @extract_single_element_vector(%arg0: vector<4x8x3xf32>) -> vector<1xf32> {
+ // CHECK: vector.extract {{.*}}[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
+ %1 = vector.extract %arg0[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
+ return %1 : vector<1xf32>
+}
+
// CHECK-LABEL: @insert_element_0d
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
>From 6acae5282485f1cc9236aae1d0ececf77c1db72c Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 12 Nov 2024 15:16:49 -0800
Subject: [PATCH 2/3] Feedback
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 97 +++++++++++++----------
mlir/test/Dialect/Vector/invalid.mlir | 98 ++++++++++++++++++++----
mlir/test/Dialect/Vector/ops.mlir | 59 ++++++++++----
3 files changed, 183 insertions(+), 71 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9b19b5cd0cd221..2988d4d8194837 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1339,6 +1339,50 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return l == r;
}
+// Common verification rules for `InsertOp` and `ExtractOp` involving indices.
+// `indexedType` is the vector type being indexed in the operation, i.e., the
+// destination type in InsertOp and the source type in ExtractOp.
+// `vecOrScalarType` is the type that is not indexed in the op and can be
+// either a scalar or a vector, i.e., the source type in InsertOp and the
+// return type in ExtractOp.
+static LogicalResult verifyInsertExtractIndices(Operation *op,
+ VectorType indexedType,
+ int64_t numIndices,
+ Type vecOrScalarType) {
+ int64_t indexedRank = indexedType.getRank();
+ if (numIndices > indexedRank)
+ return op->emitOpError(
+ "expected a number of indices no greater than the indexed vector rank");
+
+ if (auto nonIndexedVecType = dyn_cast<VectorType>(vecOrScalarType)) {
+ // Vector case, including:
+ // * 0-D vector:
+ // * vector.extract %src[2]: vector<f32> from vector<8xf32)
+ // * vector.insert %src, %dst[3]: vector<f32> into vector<8xf32>
+ // * One-element vector:
+ // * vector.extract %src[4]: vector<1xf32> from vector<8xf32>
+ // * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>
+ // * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>
+ // * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>
+ int64_t indexedRankMinusIndices = indexedRank - numIndices;
+ int64_t nonIndexedRank = nonIndexedVecType.getRank();
+ bool isOneElemVec =
+ nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1;
+ if (indexedRankMinusIndices != nonIndexedRank &&
+ (!isOneElemVec || indexedRankMinusIndices != 0)) {
+ return op->emitOpError(
+ "expected indexed vector rank minus number of indices to match "
+ "the rank of the non-indexed vector rank");
+ }
+ } else if (indexedRank != numIndices) {
+ // Scalar case.
+ return op->emitOpError("expected indexed vector rank to match the number "
+ "of indices for scalar cases");
+ }
+
+ return success();
+}
+
LogicalResult vector::ExtractOp::verify() {
// Note: This check must come before getMixedPosition() to prevent a crash.
auto dynamicMarkersCount =
@@ -1348,31 +1392,12 @@ LogicalResult vector::ExtractOp::verify() {
"mismatch between dynamic and static positions (kDynamic marker but no "
"corresponding dynamic position) -- this can only happen due to an "
"incorrect fold/rewrite");
- auto position = getMixedPosition();
- VectorType srcVecType = getSourceVectorType();
- int64_t srcRank = srcVecType.getRank();
- if (position.size() > static_cast<unsigned>(srcRank))
- return emitOpError(
- "expected position attribute of rank no greater than vector rank");
-
- VectorType dstVecType = dyn_cast<VectorType>(getResult().getType());
- if (dstVecType) {
- int64_t srcRankMinusIndices = srcRank - getNumIndices();
- int64_t dstRank = dstVecType.getRank();
- if ((srcRankMinusIndices == 0 && dstRank != 1) ||
- (srcRankMinusIndices != 0 && srcRankMinusIndices != dstRank)) {
- return emitOpError(
- "expected source rank minus number of indices to match "
- "destination vector rank");
- }
- } else {
- // Scalar result.
- if (srcRank != getNumIndices())
- return emitOpError("expected source rank to match number of indices "
- "for scalar result");
- }
+ auto srcVecType = getSourceVectorType();
+ if (failed(verifyInsertExtractIndices(*this, srcVecType, getNumIndices(),
+ getResult().getType())))
+ return failure();
- for (auto [idx, pos] : llvm::enumerate(position)) {
+ for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
if (pos.is<Attribute>()) {
int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
if (constIdx < 0 || constIdx >= srcVecType.getDimSize(idx)) {
@@ -2881,25 +2906,15 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult InsertOp::verify() {
- SmallVector<OpFoldResult> position = getMixedPosition();
- auto destVectorType = getDestVectorType();
- if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
- return emitOpError(
- "expected position attribute of rank no greater than dest vector rank");
- auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
- if (srcVectorType &&
- (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
- static_cast<unsigned>(destVectorType.getRank())))
- return emitOpError("expected position attribute rank + source rank to "
- "match dest vector rank");
- if (!srcVectorType &&
- (position.size() != static_cast<unsigned>(destVectorType.getRank())))
- return emitOpError(
- "expected position attribute rank to match the dest vector rank");
- for (auto [idx, pos] : llvm::enumerate(position)) {
+ auto dstVecType = getDestVectorType();
+ if (failed(verifyInsertExtractIndices(*this, dstVecType, getNumIndices(),
+ getSourceType())))
+ return failure();
+
+ for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
if (auto attr = pos.dyn_cast<Attribute>()) {
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
- if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+ if (constIdx < 0 || constIdx >= dstVecType.getDimSize(idx)) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e09b031dd49648..71a36b58308c09 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -151,14 +151,14 @@ func.func @extract_vector_type(%arg0: index) {
// -----
func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
+ // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
%1 = vector.extract %arg0[0, 0, 0, 0] : f32 from vector<4x8x16xf32>
}
// -----
func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
+ // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
%1 = "vector.extract" (%arg0) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
}
@@ -178,22 +178,94 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
-func.func @extract_0d(%arg0: vector<f32>) {
- // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
+func.func @extract_from_0d_to_scalar_wrong_index(%arg0: vector<f32>) {
+ // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
%1 = vector.extract %arg0[0] : f32 from vector<f32>
}
// -----
+func.func @extract_from_0d_to_0d_wrong_index(%arg0: vector<f32>) {
+ // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+ %2 = vector.extract %arg0[0] : vector<f32> from vector<f32>
+}
+
+// -----
+
+func.func @extract_from_0d_to_1d_wrong_index(%arg0: vector<f32>) {
+ // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+ %3 = vector.extract %arg0[0] : vector<1xf32> from vector<f32>
+}
+
+// -----
+
+func.func @extract_from_1d_to_scalar_wrong_index(%arg0: vector<1xf32>) {
+ // expected-error at +1 {{expected indexed vector rank to match the number of indices for scalar cases}}
+ %1 = vector.extract %arg0[] : f32 from vector<1xf32>
+}
+
+// -----
+
+func.func @extract_from_1d_to_0d_wrong_index(%arg0: vector<1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+ %2 = vector.extract %arg0[] : vector<f32> from vector<1xf32>
+}
+
+// -----
+
+func.func @extract_from_1d_to_0d(%arg0: vector<1xf32>) {
+ // expected-error at +2 {{'vector.extract' op inferred type(s) 'f32' are incompatible with return type(s) of operation 'vector<f32>'}}
+ // expected-error at +1 {{failed to infer returned types}}
+ %4 = vector.extract %arg0[0] : vector<f32> from vector<1xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_scalar(%arg0: vector<4x1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+ %6 = vector.extract %arg0[2] : f32 from vector<4x1xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_0d(%arg0: vector<4x1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+ %7 = vector.extract %arg0[2] : vector<f32> from vector<4x1xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_scalar_wrong_index(%arg0: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+ %8 = vector.extract %arg0[3] : f32 from vector<4x8xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_0d(%arg0: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+ %9 = vector.extract %arg0[3] : vector<f32> from vector<4x8xf32>
+}
+
+// -----
+
+func.func @extract_from_2d_to_1d(%arg0: vector<4x8xf32>) {
+ // expected-error at +2 {{'vector.extract' op inferred type(s) 'vector<8xf32>' are incompatible with return type(s) of operation 'vector<1xf32>'}}
+ // expected-error at +1 {{failed to infer returned types}}
+ %10 = vector.extract %arg0[3] : vector<1xf32> from vector<4x8xf32>
+}
+
+// -----
+
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
+ // expected-error at +1 {{'vector.extract' op expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
%1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
}
// -----
func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
- // expected-error at +1 {{expected source rank to match number of indices for scalar result}}
+ // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
%1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
}
@@ -201,7 +273,7 @@ func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
%c = arith.constant 3 : i32
- // expected-error at +1 {{expected position to be empty with 0-D vector}}
+ // expected-error at +1 {{'vector.insertelement' op expected position to be empty with 0-D vector}}
%0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
}
@@ -209,7 +281,7 @@ func.func @insert_element(%arg0: f32, %arg1: vector<f32>) {
func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
%c = arith.constant 3 : i32
- // expected-error at +1 {{expected position for 1-D vector}}
+ // expected-error at +1 {{'vector.insertelement' op expected position for 1-D vector}}
%0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
}
@@ -232,21 +304,21 @@ func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
// -----
func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute of rank no greater than dest vector rank}}
+ // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
%1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
}
// -----
func.func @insert_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
+ // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
%1 = vector.insert %a, %b[3] : vector<4xf32> into vector<4x8x16xf32>
}
// -----
func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute rank to match the dest vector rank}}
+ // expected-error at +1 {{'vector.insert' op expected indexed vector rank to match the number of indices for scalar cases}}
%1 = vector.insert %a, %b[3, 3] : f32 into vector<4x8x16xf32>
}
@@ -267,14 +339,14 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// -----
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
+ // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
%1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
}
// -----
func.func @insert_0d(%a: f32, %b: vector<f32>) {
- // expected-error at +1 {{expected position attribute of rank no greater than dest vector rank}}
+ // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
%1 = vector.insert %a, %b[0] : f32 into vector<f32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8595b0635cc94a..8d842afceca206 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -224,7 +224,7 @@ func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index
func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
-> (vector<8x16xf32>, vector<16xf32>, f32) {
- // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<8x16xf32> from vector<4x8x16xf32>
+ // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<8x16xf32> from vector<4x8x16xf32>
%0 = vector.extract %arg0[%idx] : vector<8x16xf32> from vector<4x8x16xf32>
// CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<16xf32> from vector<4x8x16xf32>
%1 = vector.extract %arg0[%idx, %idx] : vector<16xf32> from vector<4x8x16xf32>
@@ -234,17 +234,24 @@ func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
}
// CHECK-LABEL: @extract_0d
-func.func @extract_0d(%a: vector<f32>) -> f32 {
- // CHECK-NEXT: vector.extract %{{.*}}[] : f32 from vector<f32>
- %0 = vector.extract %a[] : f32 from vector<f32>
- return %0 : f32
-}
-
-// CHECK-LABEL: @extract_single_element_vector
-func.func @extract_single_element_vector(%arg0: vector<4x8x3xf32>) -> vector<1xf32> {
- // CHECK: vector.extract {{.*}}[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
- %1 = vector.extract %arg0[0, 0, 0] : vector<1xf32> from vector<4x8x3xf32>
- return %1 : vector<1xf32>
+func.func @extract_0d(%arg0: vector<f32>) -> (f32, vector<1xf32>) {
+ // CHECK: vector.extract %{{.*}}[] : f32 from vector<f32>
+ %0 = vector.extract %arg0[] : f32 from vector<f32>
+ // CHECK-NEXT: vector.extract %{{.*}}[] : vector<1xf32> from vector<f32>
+ %1 = vector.extract %arg0[] : vector<1xf32> from vector<f32>
+ return %0, %1 : f32, vector<1xf32>
+}
+
+// CHECK-LABEL: @extract_1d
+func.func @extract_1d(%arg0: vector<1xf32>, %arg1: vector<4x1xf32>)
+ -> (f32, vector<1xf32>, vector<1xf32>) {
+ // CHECK: vector.extract %{{.*}}[0] : f32 from vector<1xf32>
+ %0 = vector.extract %arg0[0] : f32 from vector<1xf32>
+ // CHECK-NEXT: vector.extract %{{.*}}[0] : vector<1xf32> from vector<1xf32>
+ %1 = vector.extract %arg0[0] : vector<1xf32> from vector<1xf32>
+ // CHECK-NEXT: vector.extract %{{.*}}[2] : vector<1xf32> from vector<4x1xf32>
+ %2 = vector.extract %arg1[2] : vector<1xf32> from vector<4x1xf32>
+ return %0, %1, %2 : f32, vector<1xf32>, vector<1xf32>
}
// CHECK-LABEL: @insert_element_0d
@@ -291,12 +298,30 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
}
// CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
- // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
- %1 = vector.insert %a, %b[] : f32 into vector<f32>
+func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<1xf32>, %d: vector<2x3xf32>)
+ -> (vector<f32>, vector<f32>, vector<2x3xf32>) {
+ // CHECK: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
+ %1 = vector.insert %a, %b[] : f32 into vector<f32>
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : vector<1xf32> into vector<f32>
+ %2 = vector.insert %c, %b[] : vector<1xf32> into vector<f32>
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
- %2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
- return %1, %2 : vector<f32>, vector<2x3xf32>
+ %3 = vector.insert %b, %d[0, 1] : vector<f32> into vector<2x3xf32>
+ return %1, %2, %3 : vector<f32>, vector<f32>, vector<2x3xf32>
+}
+
+// CHECK-LABEL: @insert_1d
+func.func @insert_1d(%a: f32, %b: vector<1xf32>, %c: vector<4x1xf32>,
+ %d: vector<2x3xf32>)
+ -> (vector<1xf32>, vector<1xf32>, vector<4x1xf32>, vector<2x3xf32>) {
+ // CHECK: vector.insert %{{.*}}, %{{.*}}[0] : f32 into vector<1xf32>
+ %0 = vector.insert %a, %b[0] : f32 into vector<1xf32>
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0] : vector<1xf32> into vector<1xf32>
+ %1 = vector.insert %0, %b[0] : vector<1xf32> into vector<1xf32>
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[2] : vector<1xf32> into vector<4x1xf32>
+ %2 = vector.insert %b, %c[2] : vector<1xf32> into vector<4x1xf32>
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<1xf32> into vector<2x3xf32>
+ %3 = vector.insert %b, %d[0, 1] : vector<1xf32> into vector<2x3xf32>
+ return %0, %1, %2, %3 : vector<1xf32>, vector<1xf32>, vector<4x1xf32>, vector<2x3xf32>
}
// CHECK-LABEL: @outerproduct
>From c55e252d57089fc956365372e772268a92dbc571 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 15 Nov 2024 19:25:58 -0800
Subject: [PATCH 3/3] Improvements
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 91 +++++++++-----
mlir/test/Dialect/Vector/invalid.mlir | 151 +++++++++++++++++------
mlir/test/Dialect/Vector/ops.mlir | 20 ++-
3 files changed, 183 insertions(+), 79 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2988d4d8194837..eb006233201dc6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1339,23 +1339,28 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return l == r;
}
-// Common verification rules for `InsertOp` and `ExtractOp` involving indices.
-// `indexedType` is the vector type being indexed in the operation, i.e., the
-// destination type in InsertOp and the source type in ExtractOp.
-// `vecOrScalarType` is the type that is not indexed in the op and can be
-// either a scalar or a vector, i.e., the source type in InsertOp and the
-// return type in ExtractOp.
-static LogicalResult verifyInsertExtractIndices(Operation *op,
- VectorType indexedType,
- int64_t numIndices,
- Type vecOrScalarType) {
+// Common verification rules for `InsertOp` and `ExtractOp` involving indices
+// and shapes. `indexedType` is the vector type being indexed by the operation,
+// i.e., the destination type in `InsertOp` and the source type in `ExtractOp`.
+// `nonIndexedType` is the type inserted by an `InsertOp` or extracted by an
+// `ExtractOp`, respectively.
+static LogicalResult verifyInsertExtractIndicesAndShapes(Operation *op,
+ VectorType indexedType,
+ int64_t numIndices,
+ Type nonIndexedType) {
+ assert((isa<InsertOp>(op) || isa<ExtractOp>(op)) &&
+ "Expected InsertOp or ExtractOp");
+
+ std::string nonIndexedStr = isa<InsertOp>(op) ? "inserted" : "extracted";
+ std::string indexedStr = isa<InsertOp>(op) ? "destination" : "source";
int64_t indexedRank = indexedType.getRank();
if (numIndices > indexedRank)
- return op->emitOpError(
- "expected a number of indices no greater than the indexed vector rank");
+ return op->emitOpError()
+ << "expected a number of indices no greater than the " << indexedStr
+ << " vector rank";
- if (auto nonIndexedVecType = dyn_cast<VectorType>(vecOrScalarType)) {
- // Vector case, including:
+ if (auto nonIndexedVecType = dyn_cast<VectorType>(nonIndexedType)) {
+ // Vector case, including meaningful cases such as:
// * 0-D vector:
// * vector.extract %src[2]: vector<f32> from vector<8xf32)
// * vector.insert %src, %dst[3]: vector<f32> into vector<8xf32>
@@ -1364,20 +1369,46 @@ static LogicalResult verifyInsertExtractIndices(Operation *op,
// * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>
// * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>
// * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>
- int64_t indexedRankMinusIndices = indexedRank - numIndices;
int64_t nonIndexedRank = nonIndexedVecType.getRank();
- bool isOneElemVec =
- nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1;
+ bool isSingleElem1DNonIndexedVec =
+ (nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1);
+ bool isSingleElem1DIndexedVec =
+ (indexedRank == 1 && indexedType.getDimSize(0) == 1);
+    // Reject 0-D <-> single-element 1-D vector cases in either direction.
+ if ((indexedRank == 0 && isSingleElem1DNonIndexedVec) ||
+ (nonIndexedRank == 0 && isSingleElem1DIndexedVec)) {
+ return op->emitOpError("expected source and destination vectors with "
+ "different number of elements");
+ }
+
+ // Verify indices for all the cases.
+ int64_t indexedRankMinusIndices = indexedRank - numIndices;
if (indexedRankMinusIndices != nonIndexedRank &&
- (!isOneElemVec || indexedRankMinusIndices != 0)) {
- return op->emitOpError(
- "expected indexed vector rank minus number of indices to match "
- "the rank of the non-indexed vector rank");
+ (!isSingleElem1DNonIndexedVec || indexedRankMinusIndices != 0)) {
+ return op->emitOpError()
+ << "expected " << indexedStr
+ << " vector rank minus number of indices to match the rank of the "
+ << nonIndexedStr << " vector";
}
- } else if (indexedRank != numIndices) {
- // Scalar case.
- return op->emitOpError("expected indexed vector rank to match the number "
- "of indices for scalar cases");
+ // Check that if we are inserting or extracting a sub-vector, the
+ // corresponding source and destination shapes match.
+ if (indexedRankMinusIndices > 0) {
+ auto indexedShape = indexedType.getShape();
+ if (indexedShape.drop_front(numIndices) != nonIndexedVecType.getShape())
+ return op->emitOpError() << "expected " << nonIndexedStr
+ << " vector shape to match the sub-vector "
+ "shape of the "
+ << indexedStr << " vector";
+ }
+
+ return success();
+ }
+
+ // Scalar case.
+ if (indexedRank != numIndices) {
+ return op->emitOpError()
+ << "expected " << indexedStr
+ << " vector rank to match the number of indices for scalar cases";
}
return success();
@@ -1393,9 +1424,10 @@ LogicalResult vector::ExtractOp::verify() {
"corresponding dynamic position) -- this can only happen due to an "
"incorrect fold/rewrite");
auto srcVecType = getSourceVectorType();
- if (failed(verifyInsertExtractIndices(*this, srcVecType, getNumIndices(),
- getResult().getType())))
+ if (failed(verifyInsertExtractIndicesAndShapes(
+ *this, srcVecType, getNumIndices(), getResult().getType()))) {
return failure();
+ }
for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
if (pos.is<Attribute>()) {
@@ -2907,9 +2939,10 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
LogicalResult InsertOp::verify() {
auto dstVecType = getDestVectorType();
- if (failed(verifyInsertExtractIndices(*this, dstVecType, getNumIndices(),
- getSourceType())))
+ if (failed(verifyInsertExtractIndicesAndShapes(
+ *this, dstVecType, getNumIndices(), getSourceType()))) {
return failure();
+ }
for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
if (auto attr = pos.dyn_cast<Attribute>()) {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 71a36b58308c09..b8c932a93c8a25 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -151,107 +151,112 @@ func.func @extract_vector_type(%arg0: index) {
// -----
func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+ // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
%1 = vector.extract %arg0[0, 0, 0, 0] : f32 from vector<4x8x16xf32>
}
// -----
func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+ // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
%1 = "vector.extract" (%arg0) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
}
// -----
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
+ // expected-error at +1 {{'vector.extract' op expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
%1 = vector.extract %arg0[0, 43, 0] : f32 from vector<4x8x16xf32>
}
// -----
func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
+ // expected-error at +1 {{'vector.extract' op expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
%1 = vector.extract %arg0[3, 7, 16] : f32 from vector<4x8x16xf32>
}
// -----
-func.func @extract_from_0d_to_scalar_wrong_index(%arg0: vector<f32>) {
- // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+func.func @extract_scalar_from_0d_wrong_index(%arg0: vector<f32>) {
+ // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
%1 = vector.extract %arg0[0] : f32 from vector<f32>
}
// -----
-func.func @extract_from_0d_to_0d_wrong_index(%arg0: vector<f32>) {
- // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+func.func @extract_0d_from_0d_wrong_index(%arg0: vector<f32>) {
+ // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
%2 = vector.extract %arg0[0] : vector<f32> from vector<f32>
}
// -----
-func.func @extract_from_0d_to_1d_wrong_index(%arg0: vector<f32>) {
- // expected-error at +1 {{expected a number of indices no greater than the indexed vector rank}}
+func.func @extract_1d_from_0d_wrong_index(%arg0: vector<f32>) {
+ // expected-error at +1 {{'vector.extract' op expected a number of indices no greater than the source vector rank}}
%3 = vector.extract %arg0[0] : vector<1xf32> from vector<f32>
}
// -----
-func.func @extract_from_1d_to_scalar_wrong_index(%arg0: vector<1xf32>) {
- // expected-error at +1 {{expected indexed vector rank to match the number of indices for scalar cases}}
+func.func @extract_scalar_from_1d_wrong_index(%arg0: vector<1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
%1 = vector.extract %arg0[] : f32 from vector<1xf32>
}
// -----
-func.func @extract_from_1d_to_0d_wrong_index(%arg0: vector<1xf32>) {
- // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+func.func @extract_0d_from_1d_wrong_index(%arg0: vector<1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected source and destination vectors with different number of elements}}
%2 = vector.extract %arg0[] : vector<f32> from vector<1xf32>
}
// -----
-func.func @extract_from_1d_to_0d(%arg0: vector<1xf32>) {
- // expected-error at +2 {{'vector.extract' op inferred type(s) 'f32' are incompatible with return type(s) of operation 'vector<f32>'}}
- // expected-error at +1 {{failed to infer returned types}}
+func.func @extract_0d_from_1d(%arg0: vector<1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected source and destination vectors with different number of elements}}
%4 = vector.extract %arg0[0] : vector<f32> from vector<1xf32>
}
// -----
-func.func @extract_from_2d_to_scalar(%arg0: vector<4x1xf32>) {
- // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+func.func @extract_1d_from_0d(%arg0: vector<f32>) {
+ // expected-error at +1 {{'vector.extract' op expected source and destination vectors with different number of elements}}
+ %4 = vector.extract %arg0[] : vector<1xf32> from vector<f32>
+}
+
+// -----
+
+func.func @extract_scalar_from_2d(%arg0: vector<4x1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
%6 = vector.extract %arg0[2] : f32 from vector<4x1xf32>
}
// -----
-func.func @extract_from_2d_to_0d(%arg0: vector<4x1xf32>) {
- // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+func.func @extract_0d_from_2d(%arg0: vector<4x1xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected source vector rank minus number of indices to match the rank of the extracted vector}}
%7 = vector.extract %arg0[2] : vector<f32> from vector<4x1xf32>
}
// -----
-func.func @extract_from_2d_to_scalar_wrong_index(%arg0: vector<4x8xf32>) {
- // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+func.func @extract_scalar_from_2d_wrong_index(%arg0: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
%8 = vector.extract %arg0[3] : f32 from vector<4x8xf32>
}
// -----
-func.func @extract_from_2d_to_0d(%arg0: vector<4x8xf32>) {
- // expected-error at +1 {{'vector.extract' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+func.func @extract_0d_from_2d(%arg0: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected source vector rank minus number of indices to match the rank of the extracted vector}}
%9 = vector.extract %arg0[3] : vector<f32> from vector<4x8xf32>
}
// -----
-func.func @extract_from_2d_to_1d(%arg0: vector<4x8xf32>) {
- // expected-error at +2 {{'vector.extract' op inferred type(s) 'vector<8xf32>' are incompatible with return type(s) of operation 'vector<1xf32>'}}
- // expected-error at +1 {{failed to infer returned types}}
+func.func @extract_1d_from_2d(%arg0: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.extract' op expected extracted vector shape to match the sub-vector shape of the source vector}}
%10 = vector.extract %arg0[3] : vector<1xf32> from vector<4x8xf32>
}
@@ -265,7 +270,7 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
func.func @extract_scalar_missing_indices(%arg0: vector<4x8x1xf32>) {
- // expected-error at +1 {{'vector.extract' op expected indexed vector rank to match the number of indices for scalar cases}}
+ // expected-error at +1 {{'vector.extract' op expected source vector rank to match the number of indices for scalar cases}}
%1 = vector.extract %arg0[0, 0] : f32 from vector<4x8x1xf32>
}
@@ -304,21 +309,21 @@ func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
// -----
func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
+ // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
%1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
}
// -----
func.func @insert_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
+ // expected-error at +1 {{'vector.insert' op expected destination vector rank minus number of indices to match the rank of the inserted vector}}
%1 = vector.insert %a, %b[3] : vector<4xf32> into vector<4x8x16xf32>
}
// -----
func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{'vector.insert' op expected indexed vector rank to match the number of indices for scalar cases}}
+ // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
%1 = vector.insert %a, %b[3, 3] : f32 into vector<4x8x16xf32>
}
@@ -338,16 +343,86 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// -----
-func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
- // expected-error at +1 {{'vector.insert' op expected indexed vector rank minus number of indices to match the rank of the non-indexed vector rank}}
- %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+func.func @insert_scalar_into_0d_wrong_index(%arg0: f32, %arg1: vector<f32>) {
+ // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
+ %1 = vector.insert %arg0, %arg1[0] : f32 into vector<f32>
+}
+
+// -----
+
+func.func @insert_0d_into_0d_wrong_index(%arg0: vector<f32>, %arg1: vector<f32>) {
+ // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
+ %2 = vector.insert %arg0, %arg1[0] : vector<f32> into vector<f32>
+}
+
+// -----
+
+func.func @insert_1d_into_0d_wrong_index(%arg0: vector<1xf32>, %arg1: vector<f32>) {
+ // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the destination vector rank}}
+ %3 = vector.insert %arg0, %arg1[0] : vector<1xf32> into vector<f32>
+}
+
+// -----
+
+func.func @insert_scalar_into_1d_wrong_index(%arg0: f32, %arg1: vector<1xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
+ %1 = vector.insert %arg0, %arg1[] : f32 into vector<1xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_1d_wrong_index(%arg0: vector<f32>, %arg1: vector<1xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected source and destination vectors with different number of elements}}
+ %2 = vector.insert %arg0, %arg1[] : vector<f32> into vector<1xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_1d(%arg0: vector<f32>, %arg1: vector<1xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected source and destination vectors with different number of elements}}
+ %4 = vector.insert %arg0, %arg1[0] : vector<f32> into vector<1xf32>
+}
+
+// -----
+
+func.func @insert_1d_into_0d(%arg0: vector<1xf32>, %arg1: vector<f32>) {
+ // expected-error at +1 {{'vector.insert' op expected source and destination vectors with different number of elements}}
+ %4 = vector.insert %arg0, %arg1[] : vector<1xf32> into vector<f32>
+}
+
+// -----
+
+func.func @insert_scalar_into_2d(%arg0: f32, %arg1: vector<4x1xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
+ %6 = vector.insert %arg0, %arg1[2] : f32 into vector<4x1xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_2d(%arg0: vector<f32>, %arg1: vector<4x1xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected destination vector rank minus number of indices to match the rank of the inserted vector}}
+ %7 = vector.insert %arg0, %arg1[2] : vector<f32> into vector<4x1xf32>
+}
+
+// -----
+
+func.func @insert_scalar_into_2d_wrong_index(%arg0: f32, %arg1: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected destination vector rank to match the number of indices for scalar cases}}
+ %8 = vector.insert %arg0, %arg1[3] : f32 into vector<4x8xf32>
+}
+
+// -----
+
+func.func @insert_0d_into_2d(%arg0: vector<f32>, %arg1: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected destination vector rank minus number of indices to match the rank of the inserted vector}}
+ %9 = vector.insert %arg0, %arg1[3] : vector<f32> into vector<4x8xf32>
}
// -----
-func.func @insert_0d(%a: f32, %b: vector<f32>) {
- // expected-error at +1 {{'vector.insert' op expected a number of indices no greater than the indexed vector rank}}
- %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+func.func @insert_1d_into_2d(%arg0: vector<1xf32>, %arg1: vector<4x8xf32>) {
+ // expected-error at +1 {{'vector.insert' op expected inserted vector shape to match the sub-vector shape of the destination vector}}
+ %10 = vector.insert %arg0, %arg1[3] : vector<1xf32> into vector<4x8xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8d842afceca206..00a433424d6c46 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -234,24 +234,20 @@ func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
}
// CHECK-LABEL: @extract_0d
-func.func @extract_0d(%arg0: vector<f32>) -> (f32, vector<1xf32>) {
+func.func @extract_0d(%arg0: vector<f32>) -> f32 {
// CHECK: vector.extract %{{.*}}[] : f32 from vector<f32>
%0 = vector.extract %arg0[] : f32 from vector<f32>
- // CHECK-NEXT: vector.extract %{{.*}}[] : vector<1xf32> from vector<f32>
- %1 = vector.extract %arg0[] : vector<1xf32> from vector<f32>
- return %0, %1 : f32, vector<1xf32>
+ return %0 : f32
}
// CHECK-LABEL: @extract_1d
func.func @extract_1d(%arg0: vector<1xf32>, %arg1: vector<4x1xf32>)
- -> (f32, vector<1xf32>, vector<1xf32>) {
+ -> (f32, vector<1xf32>) {
// CHECK: vector.extract %{{.*}}[0] : f32 from vector<1xf32>
%0 = vector.extract %arg0[0] : f32 from vector<1xf32>
- // CHECK-NEXT: vector.extract %{{.*}}[0] : vector<1xf32> from vector<1xf32>
- %1 = vector.extract %arg0[0] : vector<1xf32> from vector<1xf32>
// CHECK-NEXT: vector.extract %{{.*}}[2] : vector<1xf32> from vector<4x1xf32>
- %2 = vector.extract %arg1[2] : vector<1xf32> from vector<4x1xf32>
- return %0, %1, %2 : f32, vector<1xf32>, vector<1xf32>
+ %1 = vector.extract %arg1[2] : vector<1xf32> from vector<4x1xf32>
+ return %0, %1 : f32, vector<1xf32>
}
// CHECK-LABEL: @insert_element_0d
@@ -302,8 +298,8 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<1xf32>, %d: vector<2x3
-> (vector<f32>, vector<f32>, vector<2x3xf32>) {
// CHECK: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
%1 = vector.insert %a, %b[] : f32 into vector<f32>
- // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : vector<1xf32> into vector<f32>
- %2 = vector.insert %c, %b[] : vector<1xf32> into vector<f32>
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : vector<f32> into vector<f32>
+ %2 = vector.insert %b, %b[] : vector<f32> into vector<f32>
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
%3 = vector.insert %b, %d[0, 1] : vector<f32> into vector<2x3xf32>
return %1, %2, %3 : vector<f32>, vector<f32>, vector<2x3xf32>
@@ -316,7 +312,7 @@ func.func @insert_1d(%a: f32, %b: vector<1xf32>, %c: vector<4x1xf32>,
// CHECK: vector.insert %{{.*}}, %{{.*}}[0] : f32 into vector<1xf32>
%0 = vector.insert %a, %b[0] : f32 into vector<1xf32>
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0] : vector<1xf32> into vector<1xf32>
- %1 = vector.insert %0, %b[0] : vector<1xf32> into vector<1xf32>
+ %1 = vector.insert %b, %b[0] : vector<1xf32> into vector<1xf32>
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[2] : vector<1xf32> into vector<4x1xf32>
%2 = vector.insert %b, %c[2] : vector<1xf32> into vector<4x1xf32>
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<1xf32> into vector<2x3xf32>
More information about the Mlir-commits
mailing list