[Mlir-commits] [mlir] 6ed05ed - [mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) (#138725)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 14 12:13:57 PDT 2025
Author: James Newling
Date: 2025-05-14T12:13:53-07:00
New Revision: 6ed05ed773f9ab1aa31c86d55254884ce48f0915
URL: https://github.com/llvm/llvm-project/commit/6ed05ed773f9ab1aa31c86d55254884ce48f0915
DIFF: https://github.com/llvm/llvm-project/commit/6ed05ed773f9ab1aa31c86d55254884ce48f0915.diff
LOG: [mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) (#138725)
Extends the set of vector operations that we can linearize to include
vector.insert_strided_slice. The new pattern reuses the ideas from
vector.extract_strided_slice linearization.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..c7169c5297d9a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,17 +109,110 @@ struct LinearizeVectorizable final
}
};
-/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
-/// on a linearized vector.
-/// Following,
+template <typename TOp>
+static bool stridesAllOne(TOp op) {
+ static_assert(
+ std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
+ std::is_same_v<TOp, vector::InsertStridedSliceOp>,
+ "expected vector.extract_strided_slice or vector.insert_strided_slice");
+ ArrayAttr strides = op.getStrides();
+ return llvm::all_of(
+ strides, [](auto stride) { return isConstantIntValue(stride, 1); });
+}
+
+/// Convert an array of attributes into a vector of integers, if possible.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+ if (!attrs)
+ return failure();
+ SmallVector<int64_t> ints;
+ ints.reserve(attrs.size());
+ for (auto attr : attrs) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ ints.push_back(intAttr.getInt());
+ } else {
+ return failure();
+ }
+ }
+ return ints;
+}
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> static getStridedSliceInsertionIndices(
+ ArrayRef<int64_t> small, ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets) {
+
+ // Example of alignment between `large`, `small` and `offsets`:
+ // large = 4, 5, 6, 7, 8
+ // small = 1, 6, 7, 8
+ // offsets = 2, 3, 0
+ //
+ // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+ assert((large.size() >= small.size()) &&
+ "rank of 'large' cannot be lower than rank of 'small'");
+ assert((large.size() >= offsets.size()) &&
+ "rank of 'large' cannot be lower than the number of offsets");
+ unsigned delta = large.size() - small.size();
+ unsigned nOffsets = offsets.size();
+ auto getSmall = [&](int64_t i) -> int64_t {
+ return i >= delta ? small[i - delta] : 1;
+ };
+ auto getOffset = [&](int64_t i) -> int64_t {
+ return i < nOffsets ? offsets[i] : 0;
+ };
+
+ // Using 2 vectors of indices, at each iteration populate the updated set of
+ // indices based on the old set of indices, and the size of the small vector
+ // in the current iteration.
+ SmallVector<int64_t> indices{0};
+ int64_t stride = 1;
+ for (int i = large.size() - 1; i >= 0; --i) {
+ int64_t currentSize = indices.size();
+ int64_t smallSize = getSmall(i);
+ int64_t nextSize = currentSize * smallSize;
+ SmallVector<int64_t> nextIndices(nextSize);
+ int64_t *base = nextIndices.begin();
+ int64_t offset = getOffset(i) * stride;
+ for (int j = 0; j < smallSize; ++j) {
+ for (int k = 0; k < currentSize; ++k) {
+ base[k] = indices[k] + offset;
+ }
+ offset += stride;
+ base += currentSize;
+ }
+ stride *= large[i];
+ indices = std::move(nextIndices);
+ }
+ return indices;
+}
+
+/// This pattern converts a vector.extract_strided_slice operation into a
+/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
+///
+/// For example, the following:
+///
+/// ```
/// vector.extract_strided_slice %source
/// { offsets = [..], strides = [..], sizes = [..] }
+/// ```
+///
/// is converted to :
+/// ```
/// %source_1d = vector.shape_cast %source
-/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the offsets and sizes of the
-/// extraction.
+/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+/// %out_nd = vector.shape_cast %out_1d
+/// ```
+///
+/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
+/// vector.extract_strided_slice operation.
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -129,88 +222,116 @@ struct LinearizeVectorExtractStridedSlice final
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
- matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType dstType =
- getTypeConverter()->convertType<VectorType>(extractOp.getType());
- assert(dstType && "vector type destination expected.");
- if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
- return rewriter.notifyMatchFailure(extractOp,
- "scalable vectors are not supported.");
- ArrayAttr offsets = extractOp.getOffsets();
- ArrayAttr sizes = extractOp.getSizes();
- ArrayAttr strides = extractOp.getStrides();
- if (!isConstantIntValue(strides[0], 1))
+ VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
+ extractStridedSliceOp.getType());
+ assert(flatOutputType && "vector type expected");
+
+ // Expect a legalization failure if the strides are not all 1 (if ever the
+ // verifier for extract_strided_slice allows non-1 strides).
+ if (!stridesAllOne(extractStridedSliceOp)) {
return rewriter.notifyMatchFailure(
- extractOp, "Strided slice with stride != 1 is not supported.");
- Value srcVector = adaptor.getVector();
- // If kD offsets are specified for nD source vector (n > k), the granularity
- // of the extraction is greater than 1. In this case last (n-k) dimensions
- // form the extraction granularity.
- // Example :
- // vector.extract_strided_slice %src {
- // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
- // vector<4x8x8xf32> to vector<2x2x8xf32>
- // Here, extraction granularity is 8.
- int64_t extractGranularitySize = 1;
- int64_t nD = extractOp.getSourceVectorType().getRank();
- int64_t kD = (int64_t)offsets.size();
- int64_t k = kD;
- while (k < nD) {
- extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
- ++k;
+ extractStridedSliceOp,
+ "extract_strided_slice with strides != 1 not supported");
}
- // Get total number of extracted slices.
- int64_t nExtractedSlices = 1;
- for (Attribute size : sizes) {
- nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+
+ FailureOr<SmallVector<int64_t>> offsets =
+ intsFromArrayAttr(extractStridedSliceOp.getOffsets());
+ if (failed(offsets)) {
+ return rewriter.notifyMatchFailure(extractStridedSliceOp,
+ "failed to get integer offsets");
}
- // Compute the strides of the source vector considering first k dimensions.
- llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
- for (int i = kD - 2; i >= 0; --i) {
- sourceStrides[i] = sourceStrides[i + 1] *
- extractOp.getSourceVectorType().getShape()[i + 1];
+
+ ArrayRef<int64_t> inputShape =
+ extractStridedSliceOp.getSourceVectorType().getShape();
+
+ ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
+
+ SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
+ outputShape, inputShape, offsets.value());
+
+ Value srcVector = adaptor.getVector();
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+ extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
+ return success();
+ }
+};
+
+/// This pattern converts a vector.insert_strided_slice operation into a
+/// vector.shuffle operation that has rank-1 (linearized) operands and result.
+///
+/// For example, the following:
+/// ```
+/// %0 = vector.insert_strided_slice %to_store, %into
+/// {offsets = [1, 0, 0, 0], strides = [1, 1]}
+/// : vector<2x2xi8> into vector<2x1x3x2xi8>
+/// ```
+///
+/// is converted to
+/// ```
+/// %to_store_1d
+/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+/// ```
+///
+/// where shuffle_indices_1d in this case is
+/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+/// ^^^^^^^^^^^^^^
+/// to_store_1d
+///
+struct LinearizeVectorInsertStridedSlice final
+ : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Expect a legalization failure if the strides are not all 1 (if ever the
+ // verifier for insert_strided_slice allows non-1 strides).
+ if (!stridesAllOne(insertStridedSliceOp)) {
+ return rewriter.notifyMatchFailure(
+ insertStridedSliceOp,
+ "insert_strided_slice with strides != 1 not supported");
}
- // Final shuffle indices has nExtractedSlices * extractGranularitySize
- // elements.
- llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
- extractGranularitySize);
- // Compute the strides of the extracted kD vector.
- llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
- // Compute extractedStrides.
- for (int i = kD - 2; i >= 0; --i) {
- extractedStrides[i] =
- extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+
+ VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+
+ VectorType outputType = insertStridedSliceOp.getType();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t nOutputElements = outputType.getNumElements();
+
+ FailureOr<SmallVector<int64_t>> offsets =
+ intsFromArrayAttr(insertStridedSliceOp.getOffsets());
+ if (failed(offsets)) {
+ return rewriter.notifyMatchFailure(insertStridedSliceOp,
+ "failed to get integer offsets");
}
- // Iterate over all extracted slices from 0 to nExtractedSlices - 1
- // and compute the multi-dimensional index and the corresponding linearized
- // index within the source vector.
- for (int64_t i = 0; i < nExtractedSlices; ++i) {
- int64_t index = i;
- // Compute the corresponding multi-dimensional index.
- llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
- for (int64_t j = 0; j < kD; ++j) {
- multiDimIndex[j] = (index / extractedStrides[j]);
- index -= multiDimIndex[j] * extractedStrides[j];
- }
- // Compute the corresponding linearized index in the source vector
- // i.e. shift the multiDimIndex by the offsets.
- int64_t linearizedIndex = 0;
- for (int64_t j = 0; j < kD; ++j) {
- linearizedIndex +=
- (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
- sourceStrides[j];
- }
- // Fill the indices array form linearizedIndex to linearizedIndex +
- // extractGranularitySize.
- for (int64_t j = 0; j < extractGranularitySize; ++j) {
- indices[i * extractGranularitySize + j] = linearizedIndex + j;
- }
+ SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
+ inputShape, outputShape, offsets.value());
+
+ SmallVector<int64_t> indices(nOutputElements);
+ std::iota(indices.begin(), indices.end(), 0);
+ for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
+ indices[sliceIndex] = index + nOutputElements;
}
- // Perform a shuffle to extract the kD vector.
- rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- extractOp, dstType, srcVector, srcVector, indices);
+
+ Value flatToStore = adaptor.getValueToStore();
+ Value flatDest = adaptor.getDest();
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
+ flatDest.getType(), flatDest,
+ flatToStore, indices);
return success();
}
};
@@ -296,7 +417,7 @@ struct LinearizeVectorExtract final
// Skip if result is not a vector type
if (!isa<VectorType>(extractOp.getType()))
return rewriter.notifyMatchFailure(extractOp,
- "scalar extract is not supported.");
+ "scalar extract not supported");
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
assert(dstTy && "expected 1-D vector type");
@@ -453,8 +574,8 @@ struct LinearizeVectorSplat final
static bool isNotLinearizableBecauseScalable(Operation *op) {
bool unsupported =
- isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
- op);
+ isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
+ vector::ExtractOp, vector::InsertOp>(op);
if (!unsupported)
return false;
@@ -539,6 +660,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
- typeConverter, patterns.getContext());
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+ LinearizeVectorInsertStridedSlice>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..3cdbef8db604b 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
// CHECK-LABEL: test_linearize
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -131,9 +131,9 @@ func.func @test_0d_vector() -> vector<f32> {
// -----
-// CHECK-LABEL: test_extract_strided_slice_1
+// CHECK-LABEL: test_extract_strided_slice_2D
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
-func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @test_extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
@@ -147,13 +147,13 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
// -----
-// CHECK-LABEL: func.func @test_extract_strided_slice_1_scalable(
+// CHECK-LABEL: func.func @test_extract_strided_slice_2D_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
// CHECK-NOT: vector.shuffle
// CHECK-NOT: vector.shape_cast
- // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+ // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]]
%0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
// CHECK: return %[[RES]] : vector<2x[8]xf32>
@@ -162,9 +162,9 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
// -----
-// CHECK-LABEL: test_extract_strided_slice_2
+// CHECK-LABEL: test_extract_strided_slice_3D
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
-func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+func.func @test_extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
@@ -178,6 +178,76 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
// -----
+// Test of insert_strided_slice -> shuffle.
+// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements.
+// CHECK-LABEL: insert_strided_slice_2D_into_4D
+func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> {
+
+// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<4xi8>
+// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<12xi8>
+// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]]
+// CHECK-SAME: [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8>
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8>
+
+// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8>
+// CHECK: return %[[RES]] : vector<2x1x3x2xi8>
+ return %0 : vector<2x1x3x2xi8>
+}
+
+// -----
+
+// Test of insert_strided_slice -> shuffle.
+// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17]]]
+// ^ ^
+// | |
+// where the 2 elements are inserted into the 3x3x2 vector
+// CHECK-LABEL: insert_strided_slice_3D
+func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg1 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
+
+// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<2xi8>
+// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<18xi8>
+// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]]
+// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
+
+// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
+// CHECK: return %[[RES]] : vector<3x3x2xi8>
+ return %0 : vector<3x3x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: insert_strided_slice_2D_higher_offsets
+func.func @insert_strided_slice_2D_higher_offsets(%arg0 : vector<2x1xi8>, %arg1 : vector<2x2xi8>, %arg2 : vector<5x2xi8>) -> vector<5x2xi8> {
+
+ // CHECK: [0, 1, 2, 3, 10, 11, 12, 13, 8, 9]
+ // ^^^ ^^^ ^^^ ^^^
+ // insertion indices
+ %0 = vector.insert_strided_slice %arg1, %arg2 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x2xi8> into vector<5x2xi8>
+
+ // CHECK: [0, 1, 2, 3, 10, 5, 11, 7, 8, 9]
+ // ^^^ ^^^
+ %1 = vector.insert_strided_slice %arg0, %0 {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8>
+
+ // CHECK: [0, 1, 2, 3, 4, 5, 6, 10, 8, 11]
+ // ^^^ ^^^
+ %2 = vector.insert_strided_slice %arg0, %1 {offsets = [3, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8>
+
+ return %2 : vector<5x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: negative_insert_strided_slice_scalable
+// CHECK-NOT: vector.shuffle
+// CHECK: return
+func.func @negative_insert_strided_slice_scalable(%arg0 : vector<1x[2]xi8>, %arg1 : vector<2x[2]xi8>) -> vector<2x[2]xi8> {
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0], strides = [1,1]} : vector<1x[2]xi8> into vector<2x[2]xi8>
+ return %0 : vector<2x[2]xi8>
+}
+
+// -----
+
// CHECK-LABEL: test_vector_shuffle
// CHECK-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
@@ -345,3 +415,4 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>
}
+
More information about the Mlir-commits
mailing list