[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling support for bitcast, interleave, and deinterleave ops (PR #194513)
Jianhui Li
llvmlistbot at llvm.org
Thu Apr 30 15:59:23 PDT 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/194513
>From 264b0a9149baa11d135aea54bccdb8ad1508af47 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 23 Apr 2026 22:45:27 +0000
Subject: [PATCH 1/6] [mlir][Vector] Add unrolling support for bitcast,
interleave, and deinterleave ops
This patch adds VectorUnrollOpInterface implementations and unrolling patterns
for vector bitcast, interleave, and deinterleave operations.
- UnrollBitCastPattern: Unrolls bitcast by scaling the trailing dimension of each
source slice by the result/source element bitwidth ratio
- UnrollInterleavePattern: Unrolls interleave ops, which double the trailing dimension
- UnrollDeinterleavePattern: Unrolls deinterleave ops, which halve the trailing dimension
These patterns allow bitcast, interleave, and deinterleave ops to be tiled to a
target vector shape even though they change the element type or trailing dimension.
Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 9 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 ++
.../Vector/Transforms/VectorUnroll.cpp | 214 +++++++++++++++++-
.../Dialect/Vector/vector-unroll-options.mlir | 72 ++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 18 ++
5 files changed, 328 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 68ef49172e662..74b49db36c6cd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -533,7 +533,8 @@ def ResultIsDoubleSourceVectorType : TypesMatchWith<
def Vector_InterleaveOp :
Vector_Op<"interleave", [Pure, AllTypesMatch<["lhs", "rhs"]>,
- ResultIsDoubleSourceVectorType]> {
+ ResultIsDoubleSourceVectorType,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>]> {
let summary = "constructs a vector by interleaving two input vectors";
let description = [{
The interleave operation constructs a new vector by interleaving the
@@ -609,7 +610,8 @@ def Vector_DeinterleaveOp :
Vector_Op<"deinterleave", [Pure,
SourceVectorEvenElementCount,
ResultIsHalfSourceVectorType<"res1">,
- AllTypesMatch<["res1", "res2"]>
+ AllTypesMatch<["res1", "res2"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]> {
let summary = "constructs two vectors by deinterleaving an input vector";
let description = [{
@@ -2464,7 +2466,8 @@ def Vector_ShapeCastOp :
}
def Vector_BitCastOp :
- Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
+ Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>]>,
Arguments<(ins AnyVectorOfNonI0Elem:$source)>,
Results<(outs AnyVectorOfNonI0Elem:$result)>{
let summary = "bitcast casts between vectors";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3d3e49134363f..a7bd498299727 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7073,6 +7073,10 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
return {};
}
+std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//
@@ -8319,6 +8323,22 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
mask, newValue, passthru);
}
+//===----------------------------------------------------------------------===//
+// InterleaveOp
+//===----------------------------------------------------------------------===//
+
+std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
+//===----------------------------------------------------------------------===//
+// DeinterleaveOp
+//===----------------------------------------------------------------------===//
+
+std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index ec08f01d2a4b9..58eccc9301248 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1389,6 +1389,214 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
+ UnrollBitCastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::BitCastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, bitCastOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType sourceType = bitCastOp.getSourceVectorType();
+ VectorType resultType = bitCastOp.getResultVectorType();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ Location loc = bitCastOp.getLoc();
+
+ // Bail out if target shape rank doesn't match result rank
+ if (targetShape->size() != resultShape.size())
+ return rewriter.notifyMatchFailure(
+ bitCastOp, "target shape rank must match result rank");
+
+ // BitCast changes element type, which may change the trailing dimension.
+ // For the source, deduce the tile shape from the result tile shape.
+ // The relationship: if result trailing dim is R and source is S,
+ // then R * resultBitWidth = S * sourceBitWidth (same bits per row).
+
+ unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
+ unsigned resultElementBits = resultType.getElementTypeBitWidth();
+
+ // Deduce source tile shape: same as target except the trailing dimension
+ SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+ targetShape->end());
+ int64_t lastDim = sourceTileShape.size() - 1;
+
+ // Scale the trailing dimension by the bitwidth ratio
+ sourceTileShape[lastDim] =
+ ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
+
+ // Prepare the result vector
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+ SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+ VectorType targetType =
+ VectorType::get(*targetShape, resultType.getElementType());
+
+ // Unroll the bitcast
+ for (SmallVector<int64_t> resultOffsets :
+ StaticTileOffsetRange(resultShape, *targetShape)) {
+ // Compute corresponding source offsets
+ SmallVector<int64_t> sourceOffsets = resultOffsets;
+ sourceOffsets[lastDim] =
+ (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
+
+ Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, bitCastOp.getSource(), sourceOffsets, sourceTileShape,
+ sourceStrides);
+ Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
+ loc, targetType, sourceSlice);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, bitcastSlice, result, resultOffsets, resultStrides);
+ }
+
+ rewriter.replaceOp(bitCastOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
+struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
+ UnrollInterleavePattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::InterleaveOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, interleaveOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType resultType = interleaveOp.getResultVectorType();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ Location loc = interleaveOp.getLoc();
+
+ // Bail out if target shape rank doesn't match result rank
+ if (targetShape->size() != resultShape.size())
+ return rewriter.notifyMatchFailure(
+ interleaveOp, "target shape rank must match result rank");
+
+ // Interleave doubles the trailing dimension: [N] -> [2*N]
+ // For source tile shape, halve the trailing dimension of target shape
+ SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+ targetShape->end());
+ int64_t lastDim = sourceTileShape.size() - 1;
+ sourceTileShape[lastDim] = (*targetShape)[lastDim] / 2;
+
+ // Prepare the result vector
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+ SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+ VectorType targetType =
+ VectorType::get(*targetShape, resultType.getElementType());
+
+ // Unroll the interleave
+ for (SmallVector<int64_t> resultOffsets :
+ StaticTileOffsetRange(resultShape, *targetShape)) {
+ // Compute corresponding source offsets
+ SmallVector<int64_t> sourceOffsets = resultOffsets;
+ sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
+
+ Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, interleaveOp.getLhs(), sourceOffsets, sourceTileShape,
+ sourceStrides);
+ Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, interleaveOp.getRhs(), sourceOffsets, sourceTileShape,
+ sourceStrides);
+ Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
+ loc, targetType, lhsSlice, rhsSlice);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, interleaveSlice, result, resultOffsets, resultStrides);
+ }
+
+ rewriter.replaceOp(interleaveOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
+struct UnrollDeinterleavePattern
+ : public OpRewritePattern<vector::DeinterleaveOp> {
+ UnrollDeinterleavePattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::DeinterleaveOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
+ PatternRewriter &rewriter) const override {
+ // Get target shape based on the result type (res1)
+ auto targetShape = getTargetShape(options, deinterleaveOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType resultType = deinterleaveOp.getResultVectorType();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ Location loc = deinterleaveOp.getLoc();
+
+ // Bail out if target shape rank doesn't match result rank
+ if (targetShape->size() != resultShape.size())
+ return rewriter.notifyMatchFailure(
+ deinterleaveOp, "target shape rank must match result rank");
+
+ // Deinterleave halves the trailing dimension: [2*N] -> [N]
+ // For source tile shape, double the trailing dimension of target shape
+ SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+ targetShape->end());
+ int64_t lastDim = sourceTileShape.size() - 1;
+ sourceTileShape[lastDim] = (*targetShape)[lastDim] * 2;
+
+ // Prepare the result vectors
+ Value result1 = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ Value result2 = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+ SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+ // Unroll the deinterleave
+ for (SmallVector<int64_t> resultOffsets :
+ StaticTileOffsetRange(resultShape, *targetShape)) {
+ // Compute corresponding source offsets
+ SmallVector<int64_t> sourceOffsets = resultOffsets;
+ sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
+
+ Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, deinterleaveOp.getSource(), sourceOffsets, sourceTileShape,
+ sourceStrides);
+
+ auto deinterleaveSlice =
+ vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);
+
+ result1 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, deinterleaveSlice.getRes1(), result1, resultOffsets,
+ resultStrides);
+ result2 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, deinterleaveSlice.getRes2(), result2, resultOffsets,
+ resultStrides);
+ }
+
+ rewriter.replaceOp(deinterleaveOp, ValueRange{result1, result2});
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -1400,8 +1608,10 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
- UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
- patterns.getContext(), options, benefit);
+ UnrollCreateMaskPattern, UnrollConstantMaskPattern,
+ UnrollBitCastPattern, UnrollInterleavePattern,
+ UnrollDeinterleavePattern>(patterns.getContext(), options,
+ benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 036d09053552d..358d39d4ff2dd 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -667,3 +667,75 @@ func.func @shape_cast_with_all_unit_target_shape(%v: vector<2xf32>) -> vector<2x
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<1xf32> to vector<1x1xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x1xf32>
// CHECK: return %[[I1]] : vector<2x1xf32>
+
+// -----
+
+// Test BitCastOp unrolling - target shape [4, 4]
+func.func @bitcast_unroll(%arg0: vector<8x4xf32>) -> vector<8x8xi16> {
+ %0 = vector.bitcast %arg0 : vector<8x4xf32> to vector<8x8xi16>
+ return %0 : vector<8x8xi16>
+}
+// CHECK-LABEL: func @bitcast_unroll
+// CHECK-SAME: (%[[ARG:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<8x8xi16>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK: %[[BC0:.*]] = vector.bitcast %[[S0]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[BC0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK: %[[BC1:.*]] = vector.bitcast %[[S1]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[BC1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK: %[[BC2:.*]] = vector.bitcast %[[S2]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[BC2]], %[[I1]] {offsets = [4, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK: %[[BC3:.*]] = vector.bitcast %[[S3]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[BC3]], %[[I2]] {offsets = [4, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK: return %[[I3]] : vector<8x8xi16>
+
+// -----
+
+// Test InterleaveOp unrolling - target shape [8]
+func.func @interleave_unroll(%arg0: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
+ %0 = vector.interleave %arg0, %arg1 : vector<16xi32> -> vector<32xi32>
+ return %0 : vector<32xi32>
+}
+// CHECK-LABEL: func @interleave_unroll
+// CHECK-SAME: (%[[LHS:.*]]: vector<16xi32>, %[[RHS:.*]]: vector<16xi32>) -> vector<32xi32>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<32xi32>
+// CHECK: %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<4xi32> -> vector<8xi32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK: %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<4xi32> -> vector<8xi32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [8], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK: %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<4xi32> -> vector<8xi32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [16], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK: %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK: %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<4xi32> -> vector<8xi32>
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[INT3]], %[[I2]] {offsets = [24], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK: return %[[I3]] : vector<32xi32>
+
+// -----
+
+// Test DeinterleaveOp unrolling - target shape [4]
+func.func @deinterleave_unroll(%arg0: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
+ %0, %1 = vector.deinterleave %arg0 : vector<16xi32> -> vector<8xi32>
+ return %0, %1 : vector<8xi32>, vector<8xi32>
+}
+// CHECK-LABEL: func @deinterleave_unroll
+// CHECK-SAME: (%[[ARG:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+// CHECK: {{.*}} = vector.deinterleave %[[S0]] : vector<8xi32> -> vector<4xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+// CHECK: {{.*}} = vector.deinterleave %[[S1]] : vector<8xi32> -> vector<4xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK: return {{.*}}, {{.*}} : vector<8xi32>, vector<8xi32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ff3520a286cc8..fe31d6b3e9639 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -218,6 +218,24 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::TransposeOp>(op));
}));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{4, 4})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::BitCastOp>(op));
+ }));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::InterleaveOp>(op));
+ }));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{4})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::DeinterleaveOp>(op));
+ }));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
>From 25bd00ac3d82d5d1970ec1394f3656b18fbee7ff Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 28 Apr 2026 01:58:33 +0000
Subject: [PATCH 2/6] add comments
---
.../Vector/Transforms/VectorUnroll.cpp | 56 +++++++++++--------
1 file changed, 32 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 58eccc9301248..b189f163d0660 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1389,6 +1389,11 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
vector::UnrollVectorOptions options;
};
+// Unroll vector::BitCastOp into smaller tile-based bitcast operations.
+// Tiles the result vector into target shape chunks and bitcasts corresponding
+// source slices, accounting for element bitwidth ratios.
+// Example: bitcasting vector<8xf32> to vector<16xf16> with target shape [4]
+// yields four bitcasts of vector<2xf32> slices to vector<4xf16> tiles.
struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
UnrollBitCastPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -1407,29 +1412,20 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
ArrayRef<int64_t> resultShape = resultType.getShape();
Location loc = bitCastOp.getLoc();
- // Bail out if target shape rank doesn't match result rank
if (targetShape->size() != resultShape.size())
return rewriter.notifyMatchFailure(
bitCastOp, "target shape rank must match result rank");
- // BitCast changes element type, which may change the trailing dimension.
- // For the source, deduce the tile shape from the result tile shape.
- // The relationship: if result trailing dim is R and source is S,
- // then R * resultBitWidth = S * sourceBitWidth (same bits per row).
-
unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
unsigned resultElementBits = resultType.getElementTypeBitWidth();
- // Deduce source tile shape: same as target except the trailing dimension
SmallVector<int64_t> sourceTileShape(targetShape->begin(),
targetShape->end());
int64_t lastDim = sourceTileShape.size() - 1;
- // Scale the trailing dimension by the bitwidth ratio
sourceTileShape[lastDim] =
((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
- // Prepare the result vector
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));
SmallVector<int64_t> resultStrides(targetShape->size(), 1);
@@ -1438,10 +1434,8 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
VectorType targetType =
VectorType::get(*targetShape, resultType.getElementType());
- // Unroll the bitcast
for (SmallVector<int64_t> resultOffsets :
StaticTileOffsetRange(resultShape, *targetShape)) {
- // Compute corresponding source offsets
SmallVector<int64_t> sourceOffsets = resultOffsets;
sourceOffsets[lastDim] =
(resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
@@ -1463,6 +1457,18 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
vector::UnrollVectorOptions options;
};
+/// Pattern to unroll vector.interleave into smaller tile-sized operations.
+/// Decomposes a large interleave into tiles by extracting slices from both
+/// input vectors, interleaving them, and inserting back into the result.
+///
+/// Example:
+/// vector.interleave %lhs, %rhs : vector<8xf32>
+/// // Unrolled with target shape [4]:
+/// %slice_lhs_0 = vector.extract_strided_slice %lhs[0] : vector<2xf32>
+/// %slice_rhs_0 = vector.extract_strided_slice %rhs[0] : vector<2xf32>
+/// %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0 : vector<4xf32>
+/// %result = vector.insert_strided_slice %tile_0, %init[0]
+/// // ... repeat for remaining tiles
struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
UnrollInterleavePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -1480,19 +1486,15 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
ArrayRef<int64_t> resultShape = resultType.getShape();
Location loc = interleaveOp.getLoc();
- // Bail out if target shape rank doesn't match result rank
if (targetShape->size() != resultShape.size())
return rewriter.notifyMatchFailure(
interleaveOp, "target shape rank must match result rank");
- // Interleave doubles the trailing dimension: [N] -> [2*N]
- // For source tile shape, halve the trailing dimension of target shape
SmallVector<int64_t> sourceTileShape(targetShape->begin(),
targetShape->end());
int64_t lastDim = sourceTileShape.size() - 1;
sourceTileShape[lastDim] = (*targetShape)[lastDim] / 2;
- // Prepare the result vector
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));
SmallVector<int64_t> resultStrides(targetShape->size(), 1);
@@ -1501,10 +1503,8 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
VectorType targetType =
VectorType::get(*targetShape, resultType.getElementType());
- // Unroll the interleave
for (SmallVector<int64_t> resultOffsets :
StaticTileOffsetRange(resultShape, *targetShape)) {
- // Compute corresponding source offsets
SmallVector<int64_t> sourceOffsets = resultOffsets;
sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
@@ -1528,6 +1528,21 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
vector::UnrollVectorOptions options;
};
+/// Pattern to unroll vector.deinterleave into smaller tile-sized operations.
+/// Decomposes a large deinterleave (which splits a vector into even/odd halves)
+/// by extracting source slices, deinterleaving them, and inserting into two
+/// result vectors.
+///
+/// Example:
+/// %res1, %res2 = vector.deinterleave %src : vector<8xf32>
+/// // Result: %res1 = [src[0], src[2], src[4], src[6]]
+/// // %res2 = [src[1], src[3], src[5], src[7]]
+/// // Unrolled with target shape [2]:
+/// %slice_0 = vector.extract_strided_slice %src[0] : vector<4xf32>
+/// %tile1_0, %tile2_0 = vector.deinterleave %slice_0 : vector<2xf32>
+/// %result1 = vector.insert_strided_slice %tile1_0, %init1[0]
+/// %result2 = vector.insert_strided_slice %tile2_0, %init2[0]
+/// // ... repeat for remaining tiles
struct UnrollDeinterleavePattern
: public OpRewritePattern<vector::DeinterleaveOp> {
UnrollDeinterleavePattern(MLIRContext *context,
@@ -1538,7 +1553,6 @@ struct UnrollDeinterleavePattern
LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
PatternRewriter &rewriter) const override {
- // Get target shape based on the result type (res1)
auto targetShape = getTargetShape(options, deinterleaveOp);
if (!targetShape)
return failure();
@@ -1547,19 +1561,15 @@ struct UnrollDeinterleavePattern
ArrayRef<int64_t> resultShape = resultType.getShape();
Location loc = deinterleaveOp.getLoc();
- // Bail out if target shape rank doesn't match result rank
if (targetShape->size() != resultShape.size())
return rewriter.notifyMatchFailure(
deinterleaveOp, "target shape rank must match result rank");
- // Deinterleave halves the trailing dimension: [2*N] -> [N]
- // For source tile shape, double the trailing dimension of target shape
SmallVector<int64_t> sourceTileShape(targetShape->begin(),
targetShape->end());
int64_t lastDim = sourceTileShape.size() - 1;
sourceTileShape[lastDim] = (*targetShape)[lastDim] * 2;
- // Prepare the result vectors
Value result1 = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));
Value result2 = arith::ConstantOp::create(rewriter, loc, resultType,
@@ -1567,10 +1577,8 @@ struct UnrollDeinterleavePattern
SmallVector<int64_t> resultStrides(targetShape->size(), 1);
SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
- // Unroll the deinterleave
for (SmallVector<int64_t> resultOffsets :
StaticTileOffsetRange(resultShape, *targetShape)) {
- // Compute corresponding source offsets
SmallVector<int64_t> sourceOffsets = resultOffsets;
sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
>From 360400d3ac2786bfa82e814921e253a3130ed93d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 00:47:39 +0000
Subject: [PATCH 3/6] address feedback
---
.../Vector/Transforms/VectorUnroll.cpp | 84 ++++++++++---------
.../Dialect/Vector/vector-unroll-options.mlir | 54 ++++++++----
2 files changed, 83 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b189f163d0660..4d1d39cd9d61d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1405,7 +1405,8 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
PatternRewriter &rewriter) const override {
auto targetShape = getTargetShape(options, bitCastOp);
if (!targetShape)
- return failure();
+ return rewriter.notifyMatchFailure(bitCastOp,
+ "failed to get target shape");
VectorType sourceType = bitCastOp.getSourceVectorType();
VectorType resultType = bitCastOp.getResultVectorType();
@@ -1419,17 +1420,17 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
unsigned resultElementBits = resultType.getElementTypeBitWidth();
- SmallVector<int64_t> sourceTileShape(targetShape->begin(),
- targetShape->end());
- int64_t lastDim = sourceTileShape.size() - 1;
+ SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+ targetShape->end());
+ int64_t lastDim = sourceSliceShape.size() - 1;
- sourceTileShape[lastDim] =
+ sourceSliceShape[lastDim] =
((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));
SmallVector<int64_t> resultStrides(targetShape->size(), 1);
- SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+ SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
VectorType targetType =
VectorType::get(*targetShape, resultType.getElementType());
@@ -1441,7 +1442,7 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
(resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, bitCastOp.getSource(), sourceOffsets, sourceTileShape,
+ loc, bitCastOp.getSource(), sourceOffsets, sourceSliceShape,
sourceStrides);
Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
loc, targetType, sourceSlice);
@@ -1462,13 +1463,18 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
/// input vectors, interleaving them, and inserting back into the result.
///
/// Example:
-/// vector.interleave %lhs, %rhs : vector<8xf32>
-/// // Unrolled with target shape [4]:
-/// %slice_lhs_0 = vector.extract_strided_slice %lhs[0] : vector<2xf32>
-/// %slice_rhs_0 = vector.extract_strided_slice %rhs[0] : vector<2xf32>
-/// %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0 : vector<4xf32>
-/// %result = vector.insert_strided_slice %tile_0, %init[0]
-/// // ... repeat for remaining tiles
+/// Given an interleave Op:
+///
+/// vector.interleave %lhs, %rhs : vector<4x8xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %slice_lhs_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x2xf32>
+/// %slice_rhs_0 = vector.extract_strided_slice %rhs[0, 0] : vector<2x2xf32>
+/// %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
+/// : vector<2x4xf32>
+/// %result = vector.insert_strided_slice %tile_0, %init[0, 0]
+/// // ... repeat for remaining tiles
struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
UnrollInterleavePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -1480,7 +1486,8 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
PatternRewriter &rewriter) const override {
auto targetShape = getTargetShape(options, interleaveOp);
if (!targetShape)
- return failure();
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "failed to get target shape");
VectorType resultType = interleaveOp.getResultVectorType();
ArrayRef<int64_t> resultShape = resultType.getShape();
@@ -1490,15 +1497,15 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
return rewriter.notifyMatchFailure(
interleaveOp, "target shape rank must match result rank");
- SmallVector<int64_t> sourceTileShape(targetShape->begin(),
- targetShape->end());
- int64_t lastDim = sourceTileShape.size() - 1;
- sourceTileShape[lastDim] = (*targetShape)[lastDim] / 2;
+ SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+ targetShape->end());
+ int64_t lastDim = sourceSliceShape.size() - 1;
+ sourceSliceShape[lastDim] = (*targetShape)[lastDim] / 2;
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));
SmallVector<int64_t> resultStrides(targetShape->size(), 1);
- SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+ SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
VectorType targetType =
VectorType::get(*targetShape, resultType.getElementType());
@@ -1509,10 +1516,10 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, interleaveOp.getLhs(), sourceOffsets, sourceTileShape,
+ loc, interleaveOp.getLhs(), sourceOffsets, sourceSliceShape,
sourceStrides);
Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, interleaveOp.getRhs(), sourceOffsets, sourceTileShape,
+ loc, interleaveOp.getRhs(), sourceOffsets, sourceSliceShape,
sourceStrides);
Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
loc, targetType, lhsSlice, rhsSlice);
@@ -1555,7 +1562,8 @@ struct UnrollDeinterleavePattern
PatternRewriter &rewriter) const override {
auto targetShape = getTargetShape(options, deinterleaveOp);
if (!targetShape)
- return failure();
+ return rewriter.notifyMatchFailure(deinterleaveOp,
+ "failed to get target shape");
VectorType resultType = deinterleaveOp.getResultVectorType();
ArrayRef<int64_t> resultShape = resultType.getShape();
@@ -1565,17 +1573,17 @@ struct UnrollDeinterleavePattern
return rewriter.notifyMatchFailure(
deinterleaveOp, "target shape rank must match result rank");
- SmallVector<int64_t> sourceTileShape(targetShape->begin(),
- targetShape->end());
- int64_t lastDim = sourceTileShape.size() - 1;
- sourceTileShape[lastDim] = (*targetShape)[lastDim] * 2;
+ SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+ targetShape->end());
+ int64_t lastDim = sourceSliceShape.size() - 1;
+ sourceSliceShape[lastDim] = (*targetShape)[lastDim] * 2;
- Value result1 = arith::ConstantOp::create(rewriter, loc, resultType,
- rewriter.getZeroAttr(resultType));
- Value result2 = arith::ConstantOp::create(rewriter, loc, resultType,
- rewriter.getZeroAttr(resultType));
+ Value resultOdd = arith::ConstantOp::create(
+ rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
+ Value resultEven = arith::ConstantOp::create(
+ rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
SmallVector<int64_t> resultStrides(targetShape->size(), 1);
- SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+ SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
for (SmallVector<int64_t> resultOffsets :
StaticTileOffsetRange(resultShape, *targetShape)) {
@@ -1583,21 +1591,21 @@ struct UnrollDeinterleavePattern
sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, deinterleaveOp.getSource(), sourceOffsets, sourceTileShape,
+ loc, deinterleaveOp.getSource(), sourceOffsets, sourceSliceShape,
sourceStrides);
auto deinterleaveSlice =
vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);
- result1 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, deinterleaveSlice.getRes1(), result1, resultOffsets,
+ resultOdd = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, deinterleaveSlice.getRes1(), resultOdd, resultOffsets,
resultStrides);
- result2 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, deinterleaveSlice.getRes2(), result2, resultOffsets,
+ resultEven = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, deinterleaveSlice.getRes2(), resultEven, resultOffsets,
resultStrides);
}
- rewriter.replaceOp(deinterleaveOp, ValueRange{result1, result2});
+ rewriter.replaceOp(deinterleaveOp, ValueRange{resultOdd, resultEven});
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 358d39d4ff2dd..16637eacd5b95 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -671,23 +671,31 @@ func.func @shape_cast_with_all_unit_target_shape(%v: vector<2xf32>) -> vector<2x
// -----
// Test BitCastOp unrolling - target shape [4, 4]
-func.func @bitcast_unroll(%arg0: vector<8x4xf32>) -> vector<8x8xi16> {
- %0 = vector.bitcast %arg0 : vector<8x4xf32> to vector<8x8xi16>
+func.func @bitcast_2d(%v: vector<8x4xf32>) -> vector<8x8xi16> {
+ %0 = vector.bitcast %v : vector<8x4xf32> to vector<8x8xi16>
return %0 : vector<8x8xi16>
}
-// CHECK-LABEL: func @bitcast_unroll
-// CHECK-SAME: (%[[ARG:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
+// CHECK-LABEL: func @bitcast_2d
+// CHECK-SAME: (%[[V:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<8x8xi16>
-// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+/// SLICE 0:
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC0:.*]] = vector.bitcast %[[S0]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[BC0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
-// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+/// SLICE 1:
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC1:.*]] = vector.bitcast %[[S1]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[BC1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
-// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+/// SLICE 2:
+// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC2:.*]] = vector.bitcast %[[S2]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[BC2]], %[[I1]] {offsets = [4, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
-// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+// SLICE 3:
+// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC3:.*]] = vector.bitcast %[[S3]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[BC3]], %[[I2]] {offsets = [4, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
// CHECK: return %[[I3]] : vector<8x8xi16>
@@ -695,25 +703,33 @@ func.func @bitcast_unroll(%arg0: vector<8x4xf32>) -> vector<8x8xi16> {
// -----
// Test InterleaveOp unrolling - target shape [8]
-func.func @interleave_unroll(%arg0: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
- %0 = vector.interleave %arg0, %arg1 : vector<16xi32> -> vector<32xi32>
+func.func @interleave_1d(%V: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
+ %0 = vector.interleave %V, %arg1 : vector<16xi32> -> vector<32xi32>
return %0 : vector<32xi32>
}
-// CHECK-LABEL: func @interleave_unroll
+// CHECK-LABEL: func @interleave_1d
// CHECK-SAME: (%[[LHS:.*]]: vector<16xi32>, %[[RHS:.*]]: vector<16xi32>) -> vector<32xi32>
// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<32xi32>
+//
+/// SLICE 0:
// CHECK: %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<4xi32> -> vector<8xi32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<8xi32> into vector<32xi32>
+//
+/// SLICE 1:
// CHECK: %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<4xi32> -> vector<8xi32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [8], strides = [1]} : vector<8xi32> into vector<32xi32>
+//
+/// SLICE 2:
// CHECK: %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<4xi32> -> vector<8xi32>
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [16], strides = [1]} : vector<8xi32> into vector<32xi32>
+//
+/// SLICE 3:
// CHECK: %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
// CHECK: %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<4xi32> -> vector<8xi32>
@@ -723,18 +739,22 @@ func.func @interleave_unroll(%arg0: vector<16xi32>, %arg1: vector<16xi32>) -> ve
// -----
// Test DeinterleaveOp unrolling - target shape [4]
-func.func @deinterleave_unroll(%arg0: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
- %0, %1 = vector.deinterleave %arg0 : vector<16xi32> -> vector<8xi32>
+func.func @deinterleave_1d(%V: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
+ %0, %1 = vector.deinterleave %v : vector<16xi32> -> vector<8xi32>
return %0, %1 : vector<8xi32>, vector<8xi32>
}
-// CHECK-LABEL: func @deinterleave_unroll
-// CHECK-SAME: (%[[ARG:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
+// CHECK-LABEL: func @deinterleave_1d
+// CHECK-SAME: (%[[V:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+//
+/// SLICE 0:
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
// CHECK: {{.*}} = vector.deinterleave %[[S0]] : vector<8xi32> -> vector<4xi32>
// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+//
+/// SLICE 1:
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
// CHECK: {{.*}} = vector.deinterleave %[[S1]] : vector<8xi32> -> vector<4xi32>
// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
>From 7c8150e313525b64d44bf3a09de48abdd40d19c7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 00:51:51 +0000
Subject: [PATCH 4/6] fix test
---
mlir/test/Dialect/Vector/vector-unroll-options.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 16637eacd5b95..b1a0c4211f5d1 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -739,7 +739,7 @@ func.func @interleave_1d(%V: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32
// -----
// Test DeinterleaveOp unrolling - target shape [4]
-func.func @deinterleave_1d(%V: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
+func.func @deinterleave_1d(%v: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
%0, %1 = vector.deinterleave %v : vector<16xi32> -> vector<8xi32>
return %0, %1 : vector<8xi32>, vector<8xi32>
}
>From d1a859adf7f55506f6c766bb1ba8ba4af2c04976 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 18:25:19 +0000
Subject: [PATCH 5/6] address feedback and improve tests
---
.../Vector/Transforms/VectorUnroll.cpp | 48 ++++++----
.../Dialect/Vector/vector-unroll-options.mlir | 96 +++++++++----------
.../Dialect/Vector/TestVectorTransforms.cpp | 4 +-
3 files changed, 80 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 4d1d39cd9d61d..acf05a00872d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1389,11 +1389,20 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
vector::UnrollVectorOptions options;
};
-// Unroll vector::BitCastOp into smaller tile-based bitcast operations.
+// Unroll vector::BitCastOp into smaller slice-based bitcast operations.
// Tiles the result vector into target shape chunks and bitcasts corresponding
// source slices, accounting for element bitwidth ratios.
-// Example: bitcast v8f32 to v16f16 with target shape [4] unrolls into
-// multiple bitcast operations on 4-element tiles.
+/// Example:
+/// Given a deinterleave Op:
+///
+/// vector.bitcast %src : vector<4x8xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %slice_0 = vector.extract_strided_slice %src[0, 0] : vector<2x4xf32>
+/// %cast_0 = vector.bitcast %slice_0 : vector<2x4xf32>
+/// %result = vector.insert_strided_slice %cast_0, %init[0, 0]
+/// // ... repeat for remaining slices
struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
UnrollBitCastPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -1458,8 +1467,8 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
vector::UnrollVectorOptions options;
};
-/// Pattern to unroll vector.interleave into smaller tile-sized operations.
-/// Decomposes a large interleave into tiles by extracting slices from both
+/// Pattern to unroll vector.interleave into smaller slice-sized operations.
+/// Decomposes a large interleave into slices by extracting slices from both
/// input vectors, interleaving them, and inserting back into the result.
///
/// Example:
@@ -1471,10 +1480,10 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
///
/// %slice_lhs_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x2xf32>
/// %slice_rhs_0 = vector.extract_strided_slice %rhs[0, 0] : vector<2x2xf32>
-/// %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
+/// %slice_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
/// : vector<2x4xf32>
-/// %result = vector.insert_strided_slice %tile_0, %init[0, 0]
-/// // ... repeat for remaining tiles
+/// %result = vector.insert_strided_slice %slice_0, %init[0, 0]
+/// // ... repeat for remaining slices
struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
UnrollInterleavePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -1535,21 +1544,24 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
vector::UnrollVectorOptions options;
};
-/// Pattern to unroll vector.deinterleave into smaller tile-sized operations.
+/// Pattern to unroll vector.deinterleave into smaller slice-sized operations.
/// Decomposes a large deinterleave (which splits a vector into even/odd halves)
/// by extracting source slices, deinterleaving them, and inserting into two
/// result vectors.
///
/// Example:
-/// %res1, %res2 = vector.deinterleave %src : vector<8xf32>
-/// // Result: %res1 = [src[0], src[2], src[4], src[6]]
-/// // %res2 = [src[1], src[3], src[5], src[7]]
-/// // Unrolled with target shape [2]:
-/// %slice_0 = vector.extract_strided_slice %src[0] : vector<4xf32>
-/// %tile1_0, %tile2_0 = vector.deinterleave %slice_0 : vector<2xf32>
-/// %result1 = vector.insert_strided_slice %tile1_0, %init1[0]
-/// %result2 = vector.insert_strided_slice %tile2_0, %init2[0]
-/// // ... repeat for remaining tiles
+/// Given a deinterleave Op:
+///
+/// vector.deinterleave %src : vector<4x8xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %slice_0 = vector.extract_strided_slice %src[0, 0] : vector<2x8xf32>
+/// %slice_lhs_0, %slice_rhs_0 = vector.deinterleave %slice_0
+/// : vector<2x8xf32> -> vector<2x4xf32>
+/// %result1 = vector.insert_strided_slice %slice_lhs_0, %init1[0, 0]
+/// %result2 = vector.insert_strided_slice %slice_rhs_0, %init2[0, 0]
+/// // ... repeat for remaining slices
struct UnrollDeinterleavePattern
: public OpRewritePattern<vector::DeinterleaveOp> {
UnrollDeinterleavePattern(MLIRContext *context,
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index b1a0c4211f5d1..bb6fc4e38813d 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -679,22 +679,22 @@ func.func @bitcast_2d(%v: vector<8x4xf32>) -> vector<8x8xi16> {
// CHECK-SAME: (%[[V:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<8x8xi16>
//
-/// SLICE 0:
+/// SLICE 0,0:
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC0:.*]] = vector.bitcast %[[S0]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[BC0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
//
-/// SLICE 1:
+/// SLICE 0,1:
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC1:.*]] = vector.bitcast %[[S1]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[BC1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
//
-/// SLICE 2:
+/// SLICE 1,0:
// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC2:.*]] = vector.bitcast %[[S2]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[BC2]], %[[I1]] {offsets = [4, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
//
-// SLICE 3:
+/// SLICE 1,1:
// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
// CHECK: %[[BC3:.*]] = vector.bitcast %[[S3]] : vector<4x2xf32> to vector<4x4xi16>
// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[BC3]], %[[I2]] {offsets = [4, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
@@ -702,60 +702,60 @@ func.func @bitcast_2d(%v: vector<8x4xf32>) -> vector<8x8xi16> {
// -----
-// Test InterleaveOp unrolling - target shape [8]
-func.func @interleave_1d(%V: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
- %0 = vector.interleave %V, %arg1 : vector<16xi32> -> vector<32xi32>
- return %0 : vector<32xi32>
+// Test InterleaveOp unrolling - target shape [2x4]
+func.func @interleave_2d(%V: vector<4x4xi32>, %arg1: vector<4x4xi32>) -> vector<4x8xi32> {
+ %0 = vector.interleave %V, %arg1 : vector<4x4xi32> -> vector<4x8xi32>
+ return %0 : vector<4x8xi32>
}
-// CHECK-LABEL: func @interleave_1d
-// CHECK-SAME: (%[[LHS:.*]]: vector<16xi32>, %[[RHS:.*]]: vector<16xi32>) -> vector<32xi32>
-// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<32xi32>
+// CHECK-LABEL: func @interleave_2d
+// CHECK-SAME: (%[[LHS:.*]]: vector<4x4xi32>, %[[RHS:.*]]: vector<4x4xi32>) -> vector<4x8xi32>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4x8xi32>
//
-/// SLICE 0:
-// CHECK: %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<4xi32> -> vector<8xi32>
-// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<8xi32> into vector<32xi32>
+/// SLICE 0,0:
+// CHECK: %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
//
-/// SLICE 1:
-// CHECK: %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<4xi32> -> vector<8xi32>
-// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [8], strides = [1]} : vector<8xi32> into vector<32xi32>
+/// SLICE 0,1:
+// CHECK: %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
//
-/// SLICE 2:
-// CHECK: %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<4xi32> -> vector<8xi32>
-// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [16], strides = [1]} : vector<8xi32> into vector<32xi32>
+/// SLICE 1,0:
+// CHECK: %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
//
-/// SLICE 3:
-// CHECK: %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK: %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<4xi32> -> vector<8xi32>
-// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[INT3]], %[[I2]] {offsets = [24], strides = [1]} : vector<8xi32> into vector<32xi32>
-// CHECK: return %[[I3]] : vector<32xi32>
+/// SLICE 1,1:
+// CHECK: %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[INT3]], %[[I2]] {offsets = [2, 4], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
+// CHECK: return %[[I3]] : vector<4x8xi32>
// -----
-// Test DeinterleaveOp unrolling - target shape [4]
-func.func @deinterleave_1d(%v: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
- %0, %1 = vector.deinterleave %v : vector<16xi32> -> vector<8xi32>
- return %0, %1 : vector<8xi32>, vector<8xi32>
+// Test DeinterleaveOp unrolling - target shape [2x4]
+func.func @deinterleave_2d(%v: vector<4x8xi32>) -> (vector<4x4xi32>, vector<4x4xi32>) {
+ %0, %1 = vector.deinterleave %v : vector<4x8xi32> -> vector<4x4xi32>
+ return %0, %1 : vector<4x4xi32>, vector<4x4xi32>
}
-// CHECK-LABEL: func @deinterleave_1d
-// CHECK-SAME: (%[[V:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
-// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK-LABEL: func @deinterleave_2d
+// CHECK-SAME: (%[[V:.*]]: vector<4x8xi32>) -> (vector<4x4xi32>, vector<4x4xi32>)
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
//
/// SLICE 0:
-// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
-// CHECK: {{.*}} = vector.deinterleave %[[S0]] : vector<8xi32> -> vector<4xi32>
-// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi32> to vector<2x8xi32>
+// CHECK: {{.*}} = vector.deinterleave %[[S0]] : vector<2x8xi32> -> vector<2x4xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
//
/// SLICE 1:
-// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
-// CHECK: {{.*}} = vector.deinterleave %[[S1]] : vector<8xi32> -> vector<4xi32>
-// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK: return {{.*}}, {{.*}} : vector<8xi32>, vector<8xi32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi32> to vector<2x8xi32>
+// CHECK: {{.*}} = vector.deinterleave %[[S1]] : vector<2x8xi32> -> vector<2x4xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
+// CHECK: {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
+// CHECK: return {{.*}}, {{.*}} : vector<4x4xi32>, vector<4x4xi32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index fe31d6b3e9639..043181c16c759 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -226,13 +226,13 @@ struct TestVectorUnrollingPatterns
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{8})
+ .setNativeShape(ArrayRef<int64_t>{2, 4})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::InterleaveOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{4})
+ .setNativeShape(ArrayRef<int64_t>{2, 4})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::DeinterleaveOp>(op));
}));
>From 2d52306205ab8fa5a1e3976eb4321812a5c8e6e7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 30 Apr 2026 22:51:05 +0000
Subject: [PATCH 6/6] fix comments
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index acf05a00872d7..25d2e2c578441 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1390,10 +1390,10 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
};
// Unroll vector::BitCastOp into smaller slice-based bitcast operations.
-// Tiles the result vector into target shape chunks and bitcasts corresponding
-// source slices, accounting for element bitwidth ratios.
+// Decomposes the result vector into target shape chunks and bitcasts
+// corresponding source slices, accounting for element bitwidth ratios.
/// Example:
-/// Given a deinterleave Op:
+/// Given a bitcast Op:
///
/// vector.bitcast %src : vector<4x8xf32>
///
More information about the Mlir-commits
mailing list