[Mlir-commits] [mlir] 51ef80a - [mlir][Vector] Add support for 0-D vectors to vector.insert/extract
Diego Caballero
llvmlistbot at llvm.org
Tue Jul 11 12:30:01 PDT 2023
Author: Diego Caballero
Date: 2023-07-11T19:28:16Z
New Revision: 51ef80a7c20b8a54ef637550a6af8e4293a50407
URL: https://github.com/llvm/llvm-project/commit/51ef80a7c20b8a54ef637550a6af8e4293a50407
DIFF: https://github.com/llvm/llvm-project/commit/51ef80a7c20b8a54ef637550a6af8e4293a50407.diff
LOG: [mlir][Vector] Add support for 0-D vectors to vector.insert/extract
This is part of the process to remove vector.insertelement/extractelement
from the Vector dialect.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops
Differential Revision: https://reviews.llvm.org/D152644
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 22b60d8680fd0e..555ed0bec3c9a6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -577,18 +577,19 @@ def Vector_ExtractOp :
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
- Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>,
+ Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
let description = [{
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
- the proper position. Degenerates to an element type in the 0-D case.
+ the proper position. Degenerates to an element type if n-k is zero.
Example:
```mlir
%1 = vector.extract %0[3]: vector<4x8x16xf32>
%2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
+ %4 = vector.extract %3[]: vector<f32>
```
}];
let builders = [
@@ -694,19 +695,21 @@ def Vector_InsertOp :
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>,
- Results<(outs AnyVector:$res)> {
+ Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>,
+ Results<(outs AnyVectorOfAnyRank:$res)> {
let summary = "insert operation";
let description = [{
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
and inserts the n-D source into the (n+k)-D destination at the proper
- position. Degenerates to a scalar source type when n = 0.
+ position. Degenerates to a scalar or a 0-D vector source type when n = 0.
Example:
```mlir
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32>
+ %8 = vector.insert %6, %7[] : f32 into vector<f32>
+ %11 = vector.insert %9, %10[3, 3, 3] : vector<f32> into vector<4x8x16xf32>
```
}];
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e25e0f2f218269..dcb3d2fbb0c790 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1163,8 +1163,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
- auto n =
- std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
+ auto n = std::min<size_t>(op.getPosition().size(), vectorType.getRank());
inferredReturnTypes.push_back(VectorType::get(
vectorType.getShape().drop_front(n), vectorType.getElementType()));
}
@@ -2328,7 +2327,7 @@ LogicalResult InsertOp::verify() {
auto destVectorType = getDestVectorType();
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
- "expected position attribute of rank smaller than dest vector rank");
+ "expected position attribute of rank no greater than dest vector rank");
auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
if (srcVectorType &&
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 75ccde168c994e..225a20fd3c4ccd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -152,6 +152,13 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
+func.func @extract_0d(%arg0: vector<f32>) {  // A non-empty position on a 0-D vector must be rejected.
+ // expected-error@+1 {{expected position attribute of rank smaller than vector rank}}
+ %1 = vector.extract %arg0[0] : vector<f32>
+}
+
+// -----
+
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error at +1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
%1 = vector.extract %arg0[0, 0, -1] : vector<4x8x16xf32>
@@ -192,7 +199,7 @@ func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
// -----
func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
- // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
+ // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
%1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
}
@@ -226,6 +233,20 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// -----
+func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {  // 0-D source: position rank alone must equal dest rank.
+ // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
+ %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+}
+
+// -----
+
+func.func @insert_0d(%a: f32, %b: vector<f32>) {  // 0-D dest: a non-empty position exceeds the dest rank.
+ // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
+ %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+}
+
+// -----
+
func.func @outerproduct_num_operands(%arg0: f32) {
// expected-error at +1 {{expected at least 2 operands}}
%1 = vector.outerproduct %arg0 : f32, f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 19488c5cbeda06..90dc9a954a2fdb 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -219,6 +219,13 @@ func.func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x1
return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
}
+// CHECK-LABEL: @extract_0d
+func.func @extract_0d(%a: vector<f32>) -> f32 {  // Empty position on a 0-D vector yields its f32 element.
+ // CHECK-NEXT: vector.extract %{{.*}}[] : vector<f32>
+ %0 = vector.extract %a[] : vector<f32>
+ return %0 : f32
+}
+
// CHECK-LABEL: @insert_element_0d
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
@@ -248,6 +255,15 @@ func.func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vecto
return %4 : vector<4x8x16xf32>
}
+// CHECK-LABEL: @insert_0d
+func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {  // Scalar into a 0-D vector, then a 0-D vector into a 2-D vector.
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
+ %1 = vector.insert %a, %b[] : f32 into vector<f32>
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
+ %2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
+ return %1, %2 : vector<f32>, vector<2x3xf32>
+}
+
// CHECK-LABEL: @outerproduct
func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
More information about the Mlir-commits
mailing list