[Mlir-commits] [mlir] [mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) (PR #138725)

James Newling llvmlistbot at llvm.org
Fri May 9 11:22:26 PDT 2025


================
@@ -178,6 +178,45 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
 
 // -----
 
+// Test of insert_strided_slice -> shuffle.
+// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements. 
+// CHECK-LABEL: insert_strided_slice_2D_into_4D
+func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> {
+
+//   CHECK-DAG:    %[[ARG0:.*]] = vector.shape_cast {{.*}}  to vector<4xi8>
+//   CHECK-DAG:    %[[ARG1:.*]] = vector.shape_cast {{.*}}  to vector<12xi8>
+//       CHECK:    vector.shuffle %[[ARG1]], %[[ARG0]]
+//  CHECK-SAME:      [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8>
+
+//       CHECK:    %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8>
+//       CHECK:    return %[[RES]] : vector<2x1x3x2xi8>
+  return %0 : vector<2x1x3x2xi8>
+}
+
+// -----
+
+// Test of insert_strided_slice -> shuffle. 
+// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17]]]
+//                                         ^         ^
+//                                         |         |
+//                          where the 2 elements are inserted into the 3x3x2 vector
+// CHECK-LABEL: insert_strided_slice_3D
+func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg2 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
+
+//   CHECK-DAG:     %[[ARG0:.*]] = vector.shape_cast {{.*}}  to vector<2xi8>
+//   CHECK-DAG:     %[[ARG1:.*]] = vector.shape_cast {{.*}}  to vector<18xi8>
+//       CHECK:     vector.shuffle %[[ARG1]], %[[ARG0]]
+//  CHECK-SAME:       [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
+  %0 = vector.insert_strided_slice %arg0, %arg2 {offsets = [1, 1, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
+
+//       CHECK:     %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
+//       CHECK:     return %[[RES]] : vector<3x3x2xi8>
+  return %0 : vector<3x3x2xi8>
+}
----------------
newling wrote:

I've added a negative test for scalable vectors, but the strides != 1 case is not reachable because the strided slice op verifiers confirm that all strides are 1: https://github.com/llvm/llvm-project/blob/b3a6d434a7051d879718ef92a4fafd1697759aed/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L3293

I've changed these checks to asserts, to indicate that we currently expect it to be impossible to reach this point with strides > 1 (happy to change this back to emitOpError if that's preferred).
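
A small Python sketch can reproduce the shuffle masks checked in the two tests above. It rests on three assumptions: unit strides (per the verifier point above, the only case the op accepts), fixed-size (non-scalable) vectors, and row-major flattening as done by `vector.shape_cast`. The helper `insert_strided_slice_mask` is hypothetical and is not the C++ pattern from this PR. Each destination lane starts as an identity index into the flattened destination. Each source element then overwrites the lane it lands on with `len(dest) + k`, because `vector.shuffle` numbers the second operand's lanes after the first operand's.

```python
from itertools import product
from math import prod


def row_major_strides(shape):
    """Element strides of a contiguous row-major vector with this shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def insert_strided_slice_mask(src_shape, dest_shape, offsets, strides=None):
    """Hypothetical helper: mask for vector.shuffle(flat_dest, flat_src) that
    emulates vector.insert_strided_slice on fixed-size vectors."""
    if strides is None:
        strides = [1] * len(src_shape)
    # Assumption: the op verifier rejects non-unit strides, so treat them as a bug.
    assert all(s == 1 for s in strides), "only unit strides are supported"
    assert len(offsets) == len(dest_shape)
    rank_diff = len(dest_shape) - len(src_shape)
    dest_numel = prod(dest_shape)
    dest_strides = row_major_strides(dest_shape)
    mask = list(range(dest_numel))  # identity: keep every destination lane
    # product() walks the source in row-major order, matching vector.shape_cast.
    for src_flat, src_idx in enumerate(product(*(range(d) for d in src_shape))):
        # Source dims align with the trailing dest dims; leading dims use offsets only.
        dest_idx = list(offsets[:rank_diff]) + [
            o + i for o, i in zip(offsets[rank_diff:], src_idx)
        ]
        dest_flat = sum(i * s for i, s in zip(dest_idx, dest_strides))
        mask[dest_flat] = dest_numel + src_flat  # select lane from second operand
    return mask


print(insert_strided_slice_mask((2, 2), (2, 1, 3, 2), [1, 0, 0, 0]))
# [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]
```

For the 3-D test, the same call with `(1, 2, 1)`, `(3, 3, 2)` and offsets `[1, 1, 1]` places the two source lanes (18 and 19) at flat positions 9 and 11, which are the elements marked in the diagram.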

https://github.com/llvm/llvm-project/pull/138725

