[Mlir-commits] [mlir] [mlir][vector] Add more tests for ConvertVectorToLLVM (3/n) (PR #102854)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Aug 11 23:58:12 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/102854

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.extractelement
  * vector.extract

I have also renamed some test functions from `@extract_element{}` to
`@extractelement{}` to make a clearer distinction between tests for
`vector.extractelement` (tested by `@extractelement{}`) and
`vector.extract` (tested by `@extract_element{}`).


>From 6588586c6565b76f5863f4cf503b7a038a63ee7b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 11 Aug 2024 17:58:39 +0100
Subject: [PATCH] [mlir][vector] Add more tests for ConvertVectorToLLVM (3/n)

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.extractelement
  * vector.extract

I have also renamed some function names from `@extract_element{}` to
`@extractelement{}` - that's to make a clearer distinction between
tests for `vector.extractelement` (tested by `@extractelement{}`) and
`vector.extract` (tested by `@extract_element{}`).
---
 .../VectorToLLVM/vector-to-llvm.mlir          | 91 +++++++++++++++++--
 1 file changed, 85 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d164e875097968..9b61c4493994c2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1049,8 +1049,8 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3
 
 // -----
 
-// CHECK-LABEL: @extract_element_0d
-func.func @extract_element_0d(%a: vector<f32>) -> f32 {
+// CHECK-LABEL: @extractelement_0d
+func.func @extractelement_0d(%a: vector<f32>) -> f32 {
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
   // CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
   %1 = vector.extractelement %a[] : vector<f32>
@@ -1059,31 +1059,54 @@ func.func @extract_element_0d(%a: vector<f32>) -> f32 {
 
 // -----
 
-func.func @extract_element(%arg0: vector<16xf32>) -> f32 {
+func.func @extractelement(%arg0: vector<16xf32>) -> f32 {
   %0 = arith.constant 15 : i32
   %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
   return %1 : f32
 }
-// CHECK-LABEL: @extract_element(
+// CHECK-LABEL: @extractelement(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 //       CHECK:   %[[c:.*]] = arith.constant 15 : i32
 //       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<16xf32>
 //       CHECK:   return %[[x]] : f32
 
+func.func @extractelement_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = arith.constant 15 : i32
+  %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32>
+  return %1 : f32
+}
+// CHECK-LABEL: @extractelement_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
+//       CHECK:   %[[c:.*]] = arith.constant 15 : i32
+//       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<[16]xf32>
+//       CHECK:   return %[[x]] : f32
+
 // -----
 
-func.func @extract_element_index(%arg0: vector<16xf32>) -> f32 {
+func.func @extractelement_index(%arg0: vector<16xf32>) -> f32 {
   %0 = arith.constant 15 : index
   %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
   return %1 : f32
 }
-// CHECK-LABEL: @extract_element_index(
+// CHECK-LABEL: @extractelement_index(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 //       CHECK:   %[[c:.*]] = arith.constant 15 : index
 //       CHECK:   %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
 //       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<16xf32>
 //       CHECK:   return %[[x]] : f32
 
+func.func @extractelement_index_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = arith.constant 15 : index
+  %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32>
+  return %1 : f32
+}
+// CHECK-LABEL: @extractelement_index_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
+//       CHECK:   %[[c:.*]] = arith.constant 15 : index
+//       CHECK:   %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
+//       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<[16]xf32>
+//       CHECK:   return %[[x]] : f32
+
 // -----
 
 func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
@@ -1095,6 +1118,15 @@ func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
 //       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
 //       CHECK:   return {{.*}} : f32
 
+func.func @extract_element_from_vec_1d_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_element_from_vec_1d_scalable
+//       CHECK:   llvm.mlir.constant(15 : i64) : i64
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
+//       CHECK:   return {{.*}} : f32
+
 // -----
 
 func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
@@ -1109,6 +1141,18 @@ func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
 //       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
 //       CHECK:   return %[[T3]] : index
 
+func.func @extract_index_element_from_vec_1d_scalable(%arg0: vector<[16]xindex>) -> index {
+  %0 = vector.extract %arg0[15]: index from vector<[16]xindex>
+  return %0 : index
+}
+// CHECK-LABEL: @extract_index_element_from_vec_1d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xindex>)
+//       CHECK:   %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[16]xindex> to vector<[16]xi64>
+//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(15 : i64) : i64
+//       CHECK:   %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<[16]xi64>
+//       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+//       CHECK:   return %[[T3]] : index
+
 // -----
 
 func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
@@ -1119,6 +1163,14 @@ func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16x
 //       CHECK:   llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<16xf32>>>
 //       CHECK:   return {{.*}} : vector<3x16xf32>
 
+func.func @extract_vec_2d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> {
+  %0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32>
+  return %0 : vector<3x[16]xf32>
+}
+// CHECK-LABEL: @extract_vec_2d_from_vec_3d_scalable
+//       CHECK:   llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+//       CHECK:   return {{.*}} : vector<3x[16]xf32>
+
 // -----
 
 func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
@@ -1129,6 +1181,14 @@ func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf3
 //       CHECK:   llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<16xf32>>>
 //       CHECK:   return {{.*}} : vector<16xf32>
 
+func.func @extract_vec_1d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<[16]xf32> {
+  %0 = vector.extract %arg0[0, 0]: vector<[16]xf32> from vector<4x3x[16]xf32>
+  return %0 : vector<[16]xf32>
+}
+// CHECK-LABEL: @extract_vec_1d_from_vec_3d_scalable
+//       CHECK:   llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+//       CHECK:   return {{.*}} : vector<[16]xf32>
+
 // -----
 
 func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
@@ -1141,6 +1201,16 @@ func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 //       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
 //       CHECK:   return {{.*}} : f32
 
+func.func @extract_element_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> f32 {
+  %0 = vector.extract %arg0[0, 0, 0]: f32 from vector<4x3x[16]xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_element_from_vec_3d_scalable
+//       CHECK:   llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+//       CHECK:   llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
+//       CHECK:   return {{.*}} : f32
+
 // -----
 
 func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) -> f32 {
@@ -1152,6 +1222,15 @@ func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) ->
 //       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
 //       CHECK:   llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
 
+func.func @extract_element_with_value_1d_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1]: f32 from vector<[16]xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_element_with_value_1d_scalable
+//  CHECK-SAME:   %[[VEC:.+]]: vector<[16]xf32>, %[[INDEX:.+]]: index
+//       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+//       CHECK:   llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<[16]xf32>
+
 // -----
 
 func.func @extract_element_with_value_2d(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {



More information about the Mlir-commits mailing list