[Mlir-commits] [mlir] [mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) (PR #138725)

James Newling llvmlistbot at llvm.org
Fri May 9 10:31:29 PDT 2025


================
@@ -129,88 +215,110 @@ struct LinearizeVectorExtractStridedSlice final
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
-  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
+                  OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType dstType =
-        getTypeConverter()->convertType<VectorType>(extractOp.getType());
-    assert(dstType && "vector type destination expected.");
-    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalable vectors are not supported.");
 
-    ArrayAttr offsets = extractOp.getOffsets();
-    ArrayAttr sizes = extractOp.getSizes();
-    ArrayAttr strides = extractOp.getStrides();
-    if (!isConstantIntValue(strides[0], 1))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Strided slice with stride != 1 is not supported.");
-    Value srcVector = adaptor.getVector();
-    // If kD offsets are specified for nD source vector (n > k), the granularity
-    // of the extraction is greater than 1. In this case last (n-k) dimensions
-    // form the extraction granularity.
-    // Example :
-    //  vector.extract_strided_slice %src {
-    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
-    //      vector<4x8x8xf32> to vector<2x2x8xf32>
-    // Here, extraction granularity is 8.
-    int64_t extractGranularitySize = 1;
-    int64_t nD = extractOp.getSourceVectorType().getRank();
-    int64_t kD = (int64_t)offsets.size();
-    int64_t k = kD;
-    while (k < nD) {
-      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
-      ++k;
+    VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
+        extractStridedSliceOp.getType());
+    assert(flatOutputType && "vector type expected");
+
+    if (!stridesAllOne(extractStridedSliceOp)) {
+      return rewriter.notifyMatchFailure(extractStridedSliceOp,
+                                         "strides other than 1 not supported");
     }
-    // Get total number of extracted slices.
-    int64_t nExtractedSlices = 1;
-    for (Attribute size : sizes) {
-      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+
+    FailureOr<SmallVector<int64_t>> offsets =
+        intsFromArrayAttr(extractStridedSliceOp.getOffsets());
+    if (failed(offsets)) {
+      return rewriter.notifyMatchFailure(extractStridedSliceOp,
+                                         "failed to get integer offsets");
     }
-    // Compute the strides of the source vector considering first k dimensions.
-    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
-    for (int i = kD - 2; i >= 0; --i) {
-      sourceStrides[i] = sourceStrides[i + 1] *
-                         extractOp.getSourceVectorType().getShape()[i + 1];
+
+    ArrayRef<int64_t> inputShape =
+        extractStridedSliceOp.getSourceVectorType().getShape();
+
+    ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
+
+    SmallVector<int64_t> indices = getFlattenedStridedSliceIndices(
+        outputShape, inputShape, offsets.value());
+
+    Value srcVector = adaptor.getVector();
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
+    return success();
+  }
+};
+
+/// This pattern converts a vector.insert_strided_slice operation into a
+/// vector.shuffle operation that has rank-1 (linearized) operands and result.
+///
+/// For example, the following:
+/// ```
+///  %0 = vector.insert_strided_slice %to_store, %into
+///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
+///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
+/// ```
+///
+/// is converted to
+/// ```
+///  %to_store_1d
+///           = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+///  %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+///  %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+///  %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+/// ```
+///
+/// where shuffle_indices_1d in this case is
+///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+///                        ^^^^^^^^^^^^^^
+///                          to_store_1d
+///
+struct LinearizeVectorInsertStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
+                                    MLIRContext *context,
+                                    PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!stridesAllOne(insertStridedSliceOp)) {
+      return rewriter.notifyMatchFailure(insertStridedSliceOp,
+                                         "strides other than 1 not supported");
     }
-    // Final shuffle indices has nExtractedSlices * extractGranularitySize
-    // elements.
-    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
-                                          extractGranularitySize);
-    // Compute the strides of the extracted kD vector.
-    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
-    // Compute extractedStrides.
-    for (int i = kD - 2; i >= 0; --i) {
-      extractedStrides[i] =
-          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+
+    VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+
+    VectorType outputType = insertStridedSliceOp.getType();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    int64_t nOutputElements = outputType.getNumElements();
+
+    FailureOr<SmallVector<int64_t>> offsets =
+        intsFromArrayAttr(insertStridedSliceOp.getOffsets());
+    if (failed(offsets)) {
+      return rewriter.notifyMatchFailure(insertStridedSliceOp,
+                                         "failed to get integer offsets");
     }
-    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
-    // and compute the multi-dimensional index and the corresponding linearized
-    // index within the source vector.
-    for (int64_t i = 0; i < nExtractedSlices; ++i) {
-      int64_t index = i;
-      // Compute the corresponding multi-dimensional index.
-      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
-      for (int64_t j = 0; j < kD; ++j) {
-        multiDimIndex[j] = (index / extractedStrides[j]);
-        index -= multiDimIndex[j] * extractedStrides[j];
-      }
-      // Compute the corresponding linearized index in the source vector
-      // i.e. shift the multiDimIndex by the offsets.
-      int64_t linearizedIndex = 0;
-      for (int64_t j = 0; j < kD; ++j) {
-        linearizedIndex +=
-            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
-            sourceStrides[j];
-      }
-      // Fill the indices array form linearizedIndex to linearizedIndex +
-      // extractGranularitySize.
-      for (int64_t j = 0; j < extractGranularitySize; ++j) {
-        indices[i * extractGranularitySize + j] = linearizedIndex + j;
-      }
+    SmallVector<int64_t> sliceIndices = getFlattenedStridedSliceIndices(
+        inputShape, outputShape, offsets.value());
+
+    SmallVector<int64_t> indices(nOutputElements, 0);
----------------
newling wrote:

Right, because it's not like std::vector. Forgot that, thanks! 

https://github.com/llvm/llvm-project/pull/138725


More information about the Mlir-commits mailing list