[Mlir-commits] [mlir] [mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) (PR #138725)
James Newling
llvmlistbot at llvm.org
Fri May 9 14:56:30 PDT 2025
================
@@ -178,6 +178,45 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
// -----
+// Test of insert_strided_slice -> shuffle.
+// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements.
+// CHECK-LABEL: insert_strided_slice_2D_into_4D
+func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> {
+
+// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<4xi8>
+// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<12xi8>
+// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]]
+// CHECK-SAME: [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8>
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8>
+
+// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8>
+// CHECK: return %[[RES]] : vector<2x1x3x2xi8>
+ return %0 : vector<2x1x3x2xi8>
+}
+
+// -----
+
+// Test of insert_strided_slice -> shuffle.
+// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17]]]
+//                                         ^        ^
+//                                         |        |
+// The 2 source elements are inserted at flat positions 9 and 11 (marked ^) of the 3x3x2 vector.
+// CHECK-LABEL: insert_strided_slice_3D
+func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg2 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
+
+// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<2xi8>
+// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<18xi8>
+// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]]
+// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
+ %0 = vector.insert_strided_slice %arg0, %arg2 {offsets = [1, 1, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
+
+// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
+// CHECK: return %[[RES]] : vector<3x3x2xi8>
+ return %0 : vector<3x3x2xi8>
+}
----------------
newling wrote:
Makes sense, I prefer that. I can't test it, but that's unavoidable until strides != 1 is supported (or strides are removed entirely :-)).
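The shuffle masks in the two CHECK-SAME lines follow from row-major linearization. The source dims line up with the trailing destination dims. Each source element's destination coordinate is its own index plus the offset. The element at that flat position in the destination is then replaced by `numel(dst) + k`, where `k` is the element's flat index in the source. The following is a minimal Python sketch of that arithmetic, assuming unit strides (the only case these tests exercise). The function names are hypothetical; they are not part of MLIR or of this PR.

```python
from itertools import product
from math import prod


def row_major_strides(shape):
    """Element strides of a contiguous row-major (C-order) vector of `shape`."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def insert_strided_slice_mask(src_shape, dst_shape, offsets):
    """Hypothetical helper: mask for `vector.shuffle %dst_flat, %src_flat`.

    Entries in [0, numel(dst)) pick from the flattened destination and
    entries in [numel(dst), numel(dst) + numel(src)) pick from the
    flattened source. Assumes unit strides only.
    """
    assert len(offsets) == len(dst_shape)
    assert len(src_shape) <= len(dst_shape)
    dst_numel = prod(dst_shape)
    strides = row_major_strides(dst_shape)
    # Source dims are aligned with the trailing dims of the destination.
    lead = len(dst_shape) - len(src_shape)
    mask = list(range(dst_numel))  # identity: keep every destination element
    # product() walks source indices in row-major order, so k is the flat src index.
    for k, idx in enumerate(product(*(range(s) for s in src_shape))):
        coord = list(offsets[:lead]) + [o + i for o, i in zip(offsets[lead:], idx)]
        pos = sum(c * s for c, s in zip(coord, strides))
        mask[pos] = dst_numel + k
    return mask
```

For the first test, `insert_strided_slice_mask((2, 2), (2, 1, 3, 2), [1, 0, 0, 0])` gives `[0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]`. For the second, `insert_strided_slice_mask((1, 2, 1), (3, 3, 2), [1, 1, 1])` replaces positions 9 and 11 with 18 and 19. Both results match the CHECK-SAME lines. Supporting non-unit strides would mean scaling each source index by its stride before adding the offset, which this sketch deliberately leaves out.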
https://github.com/llvm/llvm-project/pull/138725