[Mlir-commits] [mlir] [mlir][vector] Add more tests for ConvertVectorToLLVM (5/n) (PR #106510)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Aug 29 01:46:15 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/106510

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.extract_strided_slice

For consistency with other tests, I have also removed "vector" from the test
function names for:
  * vector.print
  * vector.type_cast
  * vector.extract_strided_slice

The first two Ops precede `vector.extract_strided_slice` in the test
file, so they would normally be the next ones updated in this series of
patches.
However,
  * For `vector.print`, the tests in this file don't print vectors (that
    would require running VectorToSCF).
  * For `vector.type_cast`, the existing tests assume fixed-width sizes.
    We need to write new tests and I am leaving that as a TODO.

Note that I've also updated the test function names to be more descriptive and
consistent with other tests, e.g.
  * `@extract_strided_slice3` -> `@extract_strided_slice_f32_2d_from_2d`


>From 7d6a1c3a506901e3e2cf6b07b1d7d351c2870ea2 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 29 Aug 2024 08:58:56 +0100
Subject: [PATCH] [mlir][vector] Add more tests for ConvertVectorToLLVM (5/n)

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.extract_strided_slice

For consistency with other tests, I have also removed "vector" from the test
function names for:
  * vector.print
  * vector.type_cast
  * vector.extract_strided_slice

The first two Ops precede `vector.extract_strided_slice` in the test
file, so they would normally be the next ones updated in this series of
patches.
However,
  * For `vector.print`, the tests in this file don't print vectors (that
    would require running VectorToSCF).
  * For `vector.type_cast`, the existing tests assume fixed-width sizes.
    We need to write new tests and I am leaving that as a TODO.

Note that I've also updated the test function names to be more descriptive and
consistent with other tests, e.g.
  * `@extract_strided_slice3` -> `@extract_strided_slice_f32_2d_from_2d`
---
 .../VectorToLLVM/vector-to-llvm.mlir          | 153 ++++++++++--------
 1 file changed, 88 insertions(+), 65 deletions(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 63bcecd863e95d..2c8b1a2a6ff1f6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s
 
+// TODO: Add tests for vector.type_cast that would cover scalable vectors
+
 func.func @bitcast_f32_to_i32_vector_0d(%input: vector<f32>) -> vector<i32> {
   %0 = vector.bitcast %input : vector<f32> to vector<i32>
   return %0 : vector<i32>
@@ -1467,8 +1469,6 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
 //       CHECK:   vector.insert
 
-// -----
-
 func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: f32, %idx: index)
                                         -> vector<1x[16]xf32> {
   %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x[16]xf32>
@@ -1482,11 +1482,11 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
 
 // -----
 
-func.func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
+func.func @type_cast_f32(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
   %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
   return %0 : memref<vector<8x8x8xf32>>
 }
-// CHECK-LABEL: @vector_type_cast
+// CHECK-LABEL: @type_cast_f32
 //       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)>
 //       CHECK:   %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:   llvm.insertvalue %[[allocated]], {{.*}}[0] : !llvm.struct<(ptr, ptr, i64)>
@@ -1495,18 +1495,22 @@ func.func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32
 //       CHECK:   llvm.mlir.constant(0 : index
 //       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64)>
 
+// NOTE: No test for scalable vectors - the input memref is fixed size.
+
 // -----
 
-func.func @vector_index_type_cast(%arg0: memref<8x8x8xindex>) -> memref<vector<8x8x8xindex>> {
+func.func @type_cast_index(%arg0: memref<8x8x8xindex>) -> memref<vector<8x8x8xindex>> {
   %0 = vector.type_cast %arg0: memref<8x8x8xindex> to memref<vector<8x8x8xindex>>
   return %0 : memref<vector<8x8x8xindex>>
 }
-// CHECK-LABEL: @vector_index_type_cast(
+// CHECK-LABEL: @type_cast_index(
 // CHECK-SAME: %[[A:.*]]: memref<8x8x8xindex>)
 //       CHECK:   %{{.*}} = builtin.unrealized_conversion_cast %[[A]] : memref<8x8x8xindex> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 
 //       CHECK:   %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64)> to memref<vector<8x8x8xindex>>
 
+// NOTE: No test for scalable vectors - the input memref is fixed size.
+
 // -----
 
 func.func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> memref<vector<8x8x8xf32>, 3> {
@@ -1522,16 +1526,18 @@ func.func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> m
 //       CHECK:   llvm.mlir.constant(0 : index
 //       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
 
+// NOTE: No test for scalable vectors - the input memref is fixed size.
+
 // -----
 
-func.func @vector_print_scalar_i1(%arg0: i1) {
+func.func @print_scalar_i1(%arg0: i1) {
   vector.print %arg0 : i1
   return
 }
 //
 // Type "boolean" always uses zero extension.
 //
-// CHECK-LABEL: @vector_print_scalar_i1(
+// CHECK-LABEL: @print_scalar_i1(
 // CHECK-SAME: %[[A:.*]]: i1)
 //       CHECK: %[[S:.*]] = arith.extui %[[A]] : i1 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1539,11 +1545,11 @@ func.func @vector_print_scalar_i1(%arg0: i1) {
 
 // -----
 
-func.func @vector_print_scalar_i4(%arg0: i4) {
+func.func @print_scalar_i4(%arg0: i4) {
   vector.print %arg0 : i4
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i4(
+// CHECK-LABEL: @print_scalar_i4(
 // CHECK-SAME: %[[A:.*]]: i4)
 //       CHECK: %[[S:.*]] = arith.extsi %[[A]] : i4 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1551,11 +1557,11 @@ func.func @vector_print_scalar_i4(%arg0: i4) {
 
 // -----
 
-func.func @vector_print_scalar_si4(%arg0: si4) {
+func.func @print_scalar_si4(%arg0: si4) {
   vector.print %arg0 : si4
   return
 }
-// CHECK-LABEL: @vector_print_scalar_si4(
+// CHECK-LABEL: @print_scalar_si4(
 // CHECK-SAME: %[[A:.*]]: si4)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : si4 to i4
 //       CHECK: %[[S:.*]] = arith.extsi %[[C]] : i4 to i64
@@ -1564,11 +1570,11 @@ func.func @vector_print_scalar_si4(%arg0: si4) {
 
 // -----
 
-func.func @vector_print_scalar_ui4(%arg0: ui4) {
+func.func @print_scalar_ui4(%arg0: ui4) {
   vector.print %arg0 : ui4
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui4(
+// CHECK-LABEL: @print_scalar_ui4(
 // CHECK-SAME: %[[A:.*]]: ui4)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui4 to i4
 //       CHECK: %[[S:.*]] = arith.extui %[[C]] : i4 to i64
@@ -1577,11 +1583,11 @@ func.func @vector_print_scalar_ui4(%arg0: ui4) {
 
 // -----
 
-func.func @vector_print_scalar_i32(%arg0: i32) {
+func.func @print_scalar_i32(%arg0: i32) {
   vector.print %arg0 : i32
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i32(
+// CHECK-LABEL: @print_scalar_i32(
 // CHECK-SAME: %[[A:.*]]: i32)
 //       CHECK: %[[S:.*]] = arith.extsi %[[A]] : i32 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1589,11 +1595,11 @@ func.func @vector_print_scalar_i32(%arg0: i32) {
 
 // -----
 
-func.func @vector_print_scalar_ui32(%arg0: ui32) {
+func.func @print_scalar_ui32(%arg0: ui32) {
   vector.print %arg0 : ui32
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui32(
+// CHECK-LABEL: @print_scalar_ui32(
 // CHECK-SAME: %[[A:.*]]: ui32)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui32 to i32
 //       CHECK: %[[S:.*]] = arith.extui %[[C]] : i32 to i64
@@ -1601,11 +1607,11 @@ func.func @vector_print_scalar_ui32(%arg0: ui32) {
 
 // -----
 
-func.func @vector_print_scalar_i40(%arg0: i40) {
+func.func @print_scalar_i40(%arg0: i40) {
   vector.print %arg0 : i40
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i40(
+// CHECK-LABEL: @print_scalar_i40(
 // CHECK-SAME: %[[A:.*]]: i40)
 //       CHECK: %[[S:.*]] = arith.extsi %[[A]] : i40 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1613,11 +1619,11 @@ func.func @vector_print_scalar_i40(%arg0: i40) {
 
 // -----
 
-func.func @vector_print_scalar_si40(%arg0: si40) {
+func.func @print_scalar_si40(%arg0: si40) {
   vector.print %arg0 : si40
   return
 }
-// CHECK-LABEL: @vector_print_scalar_si40(
+// CHECK-LABEL: @print_scalar_si40(
 // CHECK-SAME: %[[A:.*]]: si40)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : si40 to i40
 //       CHECK: %[[S:.*]] = arith.extsi %[[C]] : i40 to i64
@@ -1626,11 +1632,11 @@ func.func @vector_print_scalar_si40(%arg0: si40) {
 
 // -----
 
-func.func @vector_print_scalar_ui40(%arg0: ui40) {
+func.func @print_scalar_ui40(%arg0: ui40) {
   vector.print %arg0 : ui40
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui40(
+// CHECK-LABEL: @print_scalar_ui40(
 // CHECK-SAME: %[[A:.*]]: ui40)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui40 to i40
 //       CHECK: %[[S:.*]] = arith.extui %[[C]] : i40 to i64
@@ -1639,22 +1645,22 @@ func.func @vector_print_scalar_ui40(%arg0: ui40) {
 
 // -----
 
-func.func @vector_print_scalar_i64(%arg0: i64) {
+func.func @print_scalar_i64(%arg0: i64) {
   vector.print %arg0 : i64
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i64(
+// CHECK-LABEL: @print_scalar_i64(
 // CHECK-SAME: %[[A:.*]]: i64)
 //       CHECK:    llvm.call @printI64(%[[A]]) : (i64) -> ()
 //       CHECK:    llvm.call @printNewline() : () -> ()
 
 // -----
 
-func.func @vector_print_scalar_ui64(%arg0: ui64) {
+func.func @print_scalar_ui64(%arg0: ui64) {
   vector.print %arg0 : ui64
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui64(
+// CHECK-LABEL: @print_scalar_ui64(
 // CHECK-SAME: %[[A:.*]]: ui64)
 //       CHECK:    %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui64 to i64
 //       CHECK:    llvm.call @printU64(%[[C]]) : (i64) -> ()
@@ -1662,11 +1668,11 @@ func.func @vector_print_scalar_ui64(%arg0: ui64) {
 
 // -----
 
-func.func @vector_print_scalar_index(%arg0: index) {
+func.func @print_scalar_index(%arg0: index) {
   vector.print %arg0 : index
   return
 }
-// CHECK-LABEL: @vector_print_scalar_index(
+// CHECK-LABEL: @print_scalar_index(
 // CHECK-SAME: %[[A:.*]]: index)
 //       CHECK:    %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
 //       CHECK:    llvm.call @printU64(%[[C]]) : (i64) -> ()
@@ -1674,22 +1680,22 @@ func.func @vector_print_scalar_index(%arg0: index) {
 
 // -----
 
-func.func @vector_print_scalar_f32(%arg0: f32) {
+func.func @print_scalar_f32(%arg0: f32) {
   vector.print %arg0 : f32
   return
 }
-// CHECK-LABEL: @vector_print_scalar_f32(
+// CHECK-LABEL: @print_scalar_f32(
 // CHECK-SAME: %[[A:.*]]: f32)
 //       CHECK:    llvm.call @printF32(%[[A]]) : (f32) -> ()
 //       CHECK:    llvm.call @printNewline() : () -> ()
 
 // -----
 
-func.func @vector_print_scalar_f64(%arg0: f64) {
+func.func @print_scalar_f64(%arg0: f64) {
   vector.print %arg0 : f64
   return
 }
-// CHECK-LABEL: @vector_print_scalar_f64(
+// CHECK-LABEL: @print_scalar_f64(
 // CHECK-SAME: %[[A:.*]]: f64)
 //       CHECK:    llvm.call @printF64(%[[A]]) : (f64) -> ()
 //       CHECK:    llvm.call @printNewline() : () -> ()
@@ -1699,46 +1705,50 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
 // CHECK-LABEL: module {
 // CHECK: llvm.func @printString(!llvm.ptr)
 // CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}})
-// CHECK: @vector_print_string
+// CHECK: @print_string
 //       CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
 //       CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr
 //       CHECK-NEXT: llvm.call @printString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
-func.func @vector_print_string() {
+func.func @print_string() {
   vector.print str "Hello, World!"
   return
 }
 
 // -----
 
-func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
+func.func @extract_strided_slice_f32_1d_from_1d(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   return %0 : vector<2xf32>
 }
-// CHECK-LABEL: @extract_strided_slice1(
+// CHECK-LABEL: @extract_strided_slice_f32_1d_from_1d(
 //  CHECK-SAME:    %[[A:.*]]: vector<4xf32>)
 //       CHECK:    %[[T0:.*]] = llvm.shufflevector %[[A]], %[[A]] [2, 3] : vector<4xf32>
 //       CHECK:    return %[[T0]] : vector<2xf32>
 
+// NOTE: For scalable vectors we could only extract vector<[4]xf32> from vector<[4]xf32>, but that would be a NOP.
+
 // -----
 
-func.func @extract_strided_index_slice1(%arg0: vector<4xindex>) -> vector<2xindex> {
+func.func @extract_strided_slice_index_1d_from_1d(%arg0: vector<4xindex>) -> vector<2xindex> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xindex> to vector<2xindex>
   return %0 : vector<2xindex>
 }
-// CHECK-LABEL: @extract_strided_index_slice1(
+// CHECK-LABEL: @extract_strided_slice_index_1d_from_1d(
 //  CHECK-SAME:    %[[A:.*]]: vector<4xindex>)
 //       CHECK:    %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4xindex> to vector<4xi64>
 //       CHECK:    %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T0]] [2, 3] : vector<4xi64>
 //       CHECK:    %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2xi64> to vector<2xindex>
 //       CHECK:    return %[[T3]] : vector<2xindex>
 
+// NOTE: For scalable vectors we could only extract vector<[4]xindex> from vector<[4]xindex>, but that would be a NOP.
+
 // -----
 
-func.func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
+func.func @extract_strided_slice_f32_1d_from_2d(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
   return %0 : vector<2x8xf32>
 }
-// CHECK-LABEL: @extract_strided_slice2(
+// CHECK-LABEL: @extract_strided_slice_f32_1d_from_2d(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x8xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
 //       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<8xf32>>
@@ -1749,13 +1759,28 @@ func.func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
 //       CHECK:    %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<8xf32>> to vector<2x8xf32>
 //       CHECK:    return %[[T5]]
 
+func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+  return %0 : vector<2x[8]xf32>
+}
+// CHECK-LABEL:   func.func @extract_strided_slice_f32_1d_from_2d_scalable(
+//  CHECK-SAME:    %[[ARG:.*]]: vector<4x[8]xf32>)
+//       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+//       CHECK:    return %[[T5]]
+
 // -----
 
-func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @extract_strided_slice_f32_2d_from_2d(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
   return %0 : vector<2x2xf32>
 }
-// CHECK-LABEL: @extract_strided_slice3(
+// CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x8xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
 //       CHECK:    %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
@@ -1769,27 +1794,25 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
 //       CHECK:    %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[T7]] : !llvm.array<2 x vector<2xf32>> to vector<2x2xf32>
 //       CHECK:    return %[[VAL_12]] : vector<2x2xf32>
 
-// -----
-
-func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
-  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
-  return %0 : vector<1x1x[4]xi32>
-}
-
-// CHECK-LABEL:   func.func @extract_strided_slice_scalable(
-// CHECK-SAME:      %[[ARG_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
-
-//      CHECK:      %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
-//      CHECK:      %[[CST:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
-//      CHECK:      %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
-//      CHECK:      %[[CST_1:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
-//      CHECK:      %[[CAST_3:.*]] = builtin.unrealized_conversion_cast %[[CST_1]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
-
-//      CHECK:      %[[EXT:.*]] = llvm.extractvalue %[[CAST_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
-//      CHECK:      %[[INS_1:.*]] = llvm.insertvalue %[[EXT]], %[[CAST_3]][0] : !llvm.array<1 x vector<[4]xi32>>
-//      CHECK:      %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
-
-//      CHECK:      builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+// NOTE: For scalable vectors, we can only extract "full" scalable dimensions
+// (e.g. [8] from [8], but not [4] from [8]).
+
+func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+  return %0 : vector<2x[8]xf32>
+}
+// CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
+//  CHECK-SAME:     %[[ARG:.*]]: vector<4x[8]xf32>)
+// CHECK:           %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
+// CHECK:           %[[T2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+// CHECK:           %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+// CHECK:           %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK:           %[[T6:.*]] = llvm.insertvalue %[[T5]], %[[T4]][0] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK:           %[[T7:.*]] = llvm.extractvalue %[[T1]][3] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK:           %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T6]][1] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK:           %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T8]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+// CHECK:           return %[[T9]] : vector<2x[8]xf32>
 
 // -----
 



More information about the Mlir-commits mailing list