[Mlir-commits] [mlir] 51ef80a - [mlir][Vector] Add support for 0-D vectors to vector.insert/extract

Diego Caballero llvmlistbot at llvm.org
Tue Jul 11 12:30:01 PDT 2023


Author: Diego Caballero
Date: 2023-07-11T19:28:16Z
New Revision: 51ef80a7c20b8a54ef637550a6af8e4293a50407

URL: https://github.com/llvm/llvm-project/commit/51ef80a7c20b8a54ef637550a6af8e4293a50407
DIFF: https://github.com/llvm/llvm-project/commit/51ef80a7c20b8a54ef637550a6af8e4293a50407.diff

LOG: [mlir][Vector] Add support for 0-D vectors to vector.insert/extract

This is part of the process to remove vector.insertelement/extractelement
from the Vector dialect.

RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

Differential Revision: https://reviews.llvm.org/D152644

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 22b60d8680fd0e..555ed0bec3c9a6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -577,18 +577,19 @@ def Vector_ExtractOp :
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>,
+    Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
     Results<(outs AnyType)> {
   let summary = "extract operation";
   let description = [{
     Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
-    the proper position. Degenerates to an element type in the 0-D case.
+    the proper position. Degenerates to an element type if n-k is zero.
 
     Example:
 
     ```mlir
     %1 = vector.extract %0[3]: vector<4x8x16xf32>
     %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
+    %4 = vector.extract %3[]: vector<f32>
     ```
   }];
   let builders = [
@@ -694,19 +695,21 @@ def Vector_InsertOp :
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      AllTypesMatch<["dest", "res"]>]>,
-     Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>,
-     Results<(outs AnyVector:$res)> {
+     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>,
+     Results<(outs AnyVectorOfAnyRank:$res)> {
   let summary = "insert operation";
   let description = [{
     Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
     and inserts the n-D source into the (n+k)-D destination at the proper
-    position. Degenerates to a scalar source type when n = 0.
+    position. Degenerates to a scalar or a 0-D vector source type when n = 0.
 
     Example:
 
     ```mlir
     %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
     %5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32>
+    %8 = vector.insert %6, %7[] : f32 into vector<f32>
+    %11 = vector.insert %9, %10[3, 3, 3] : vector<f32> into vector<4x8x16xf32>
     ```
   }];
   let assemblyFormat = [{

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e25e0f2f218269..dcb3d2fbb0c790 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1163,8 +1163,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
   if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
     inferredReturnTypes.push_back(vectorType.getElementType());
   } else {
-    auto n =
-        std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
+    auto n = std::min<size_t>(op.getPosition().size(), vectorType.getRank());
     inferredReturnTypes.push_back(VectorType::get(
         vectorType.getShape().drop_front(n), vectorType.getElementType()));
   }
@@ -2328,7 +2327,7 @@ LogicalResult InsertOp::verify() {
   auto destVectorType = getDestVectorType();
   if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
     return emitOpError(
-        "expected position attribute of rank smaller than dest vector rank");
+        "expected position attribute of rank no greater than dest vector rank");
   auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
   if (srcVectorType &&
       (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 75ccde168c994e..225a20fd3c4ccd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -152,6 +152,13 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @extract_0d(%arg0: vector<f32>) {
+  // expected-error at +1 {{expected position attribute of rank smaller than vector rank}}
+  %1 = vector.extract %arg0[0] : vector<f32>
+}
+
+// -----
+
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error at +1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
   %1 = vector.extract %arg0[0, 0, -1] : vector<4x8x16xf32>
@@ -192,7 +199,7 @@ func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
 // -----
 
 func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected position attribute of rank smaller than dest vector rank}}
+  // expected-error at +1 {{expected position attribute of rank no greater than dest vector rank}}
   %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
 }
 
@@ -226,6 +233,20 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
+  // expected-error at +1 {{expected position attribute rank + source rank to match dest vector rank}}
+  %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+}
+
+// -----
+
+func.func @insert_0d(%a: f32, %b: vector<f32>) {
+  // expected-error at +1 {{expected position attribute of rank no greater than dest vector rank}}
+  %1 = vector.insert %a, %b[0] : f32 into vector<f32>
+}
+
+// -----
+
 func.func @outerproduct_num_operands(%arg0: f32) {
   // expected-error at +1 {{expected at least 2 operands}}
   %1 = vector.outerproduct %arg0 : f32, f32

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 19488c5cbeda06..90dc9a954a2fdb 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -219,6 +219,13 @@ func.func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x1
   return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }
 
+// CHECK-LABEL: @extract_0d
+func.func @extract_0d(%a: vector<f32>) -> f32 {
+  // CHECK-NEXT: vector.extract %{{.*}}[] : vector<f32>
+  %0 = vector.extract %a[] : vector<f32>
+  return %0 : f32
+}
+
 // CHECK-LABEL: @insert_element_0d
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
@@ -248,6 +255,15 @@ func.func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vecto
   return %4 : vector<4x8x16xf32>
 }
 
+// CHECK-LABEL: @insert_0d
+func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
+  %1 = vector.insert %a,  %b[] : f32 into vector<f32>
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
+  %2 = vector.insert %b,  %c[0, 1] : vector<f32> into vector<2x3xf32>
+  return %1, %2 : vector<f32>, vector<2x3xf32>
+}
+
 // CHECK-LABEL: @outerproduct
 func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
   // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>


        


More information about the Mlir-commits mailing list