[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling support for bitcast, interleave, and deinterleave ops (PR #194513)

Jianhui Li llvmlistbot at llvm.org
Thu Apr 30 15:59:23 PDT 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/194513

>From 264b0a9149baa11d135aea54bccdb8ad1508af47 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 23 Apr 2026 22:45:27 +0000
Subject: [PATCH 1/6] [mlir][Vector] Add unrolling support for bitcast,
 interleave, and deinterleave ops

This patch adds VectorUnrollOpInterface implementations and unrolling patterns
for vector bitcast, interleave, and deinterleave operations.

- UnrollBitCastPattern: Unrolls bitcast by adjusting tile shapes based on element
  type bitwidth ratios
- UnrollInterleavePattern: Unrolls interleave ops which double the trailing dimension
- UnrollDeinterleavePattern: Unrolls deinterleave ops which halve the trailing dimension

These patterns let bitcast (element type reinterpretation), interleave, and
deinterleave (data layout shuffles) be tiled to a target native shape, like the
other ops already handled by populateVectorUnrollPatterns.

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   9 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  20 ++
 .../Vector/Transforms/VectorUnroll.cpp        | 214 +++++++++++++++++-
 .../Dialect/Vector/vector-unroll-options.mlir |  72 ++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  18 ++
 5 files changed, 328 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 68ef49172e662..74b49db36c6cd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -533,7 +533,8 @@ def ResultIsDoubleSourceVectorType : TypesMatchWith<
 
 def Vector_InterleaveOp :
   Vector_Op<"interleave", [Pure, AllTypesMatch<["lhs", "rhs"]>,
-    ResultIsDoubleSourceVectorType]> {
+    ResultIsDoubleSourceVectorType,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>]> {
   let summary = "constructs a vector by interleaving two input vectors";
   let description = [{
     The interleave operation constructs a new vector by interleaving the
@@ -609,7 +610,8 @@ def Vector_DeinterleaveOp :
   Vector_Op<"deinterleave", [Pure,
     SourceVectorEvenElementCount,
     ResultIsHalfSourceVectorType<"res1">,
-    AllTypesMatch<["res1", "res2"]>
+    AllTypesMatch<["res1", "res2"]>,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]> {
       let summary = "constructs two vectors by deinterleaving an input vector";
       let description = [{
@@ -2464,7 +2466,8 @@ def Vector_ShapeCastOp :
 }
 
 def Vector_BitCastOp :
-  Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
+  Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>]>,
     Arguments<(ins AnyVectorOfNonI0Elem:$source)>,
     Results<(outs AnyVectorOfNonI0Elem:$result)>{
   let summary = "bitcast casts between vectors";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3d3e49134363f..a7bd498299727 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7073,6 +7073,10 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // TypeCastOp
 //===----------------------------------------------------------------------===//
@@ -8319,6 +8323,22 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                  mask, newValue, passthru);
 }
 
+//===----------------------------------------------------------------------===//
+// InterleaveOp
+//===----------------------------------------------------------------------===//
+
+std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
+//===----------------------------------------------------------------------===//
+// DeinterleaveOp
+//===----------------------------------------------------------------------===//
+
+std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index ec08f01d2a4b9..58eccc9301248 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1389,6 +1389,214 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
+  UnrollBitCastPattern(MLIRContext *context,
+                       const vector::UnrollVectorOptions &options,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BitCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, bitCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = bitCastOp.getSourceVectorType();
+    VectorType resultType = bitCastOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = bitCastOp.getLoc();
+
+    // Bail out if target shape rank doesn't match result rank
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "target shape rank must match result rank");
+
+    // BitCast changes element type, which may change the trailing dimension.
+    // For the source, deduce the tile shape from the result tile shape.
+    // The relationship: if result trailing dim is R and source is S,
+    // then resultBitWidth / R = sourceBitWidth / S (same bits per element).
+
+    unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
+    unsigned resultElementBits = resultType.getElementTypeBitWidth();
+
+    // Deduce source tile shape: same as target except the trailing dimension
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    int64_t lastDim = sourceTileShape.size() - 1;
+
+    // Scale the trailing dimension by the bitwidth ratio
+    sourceTileShape[lastDim] =
+        ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
+
+    // Prepare the result vector
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    // Unroll the bitcast
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Compute corresponding source offsets
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] =
+          (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
+
+      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, bitCastOp.getSource(), sourceOffsets, sourceTileShape,
+          sourceStrides);
+      Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
+          loc, targetType, sourceSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, bitcastSlice, result, resultOffsets, resultStrides);
+    }
+
+    rewriter.replaceOp(bitCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
+  UnrollInterleavePattern(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::InterleaveOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, interleaveOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = interleaveOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = interleaveOp.getLoc();
+
+    // Bail out if target shape rank doesn't match result rank
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          interleaveOp, "target shape rank must match result rank");
+
+    // Interleave doubles the trailing dimension: [N] -> [2*N]
+    // For source tile shape, halve the trailing dimension of target shape
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    int64_t lastDim = sourceTileShape.size() - 1;
+    sourceTileShape[lastDim] = (*targetShape)[lastDim] / 2;
+
+    // Prepare the result vector
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    // Unroll the interleave
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Compute corresponding source offsets
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
+
+      Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, interleaveOp.getLhs(), sourceOffsets, sourceTileShape,
+          sourceStrides);
+      Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, interleaveOp.getRhs(), sourceOffsets, sourceTileShape,
+          sourceStrides);
+      Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
+          loc, targetType, lhsSlice, rhsSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, interleaveSlice, result, resultOffsets, resultStrides);
+    }
+
+    rewriter.replaceOp(interleaveOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollDeinterleavePattern
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+  UnrollDeinterleavePattern(MLIRContext *context,
+                            const vector::UnrollVectorOptions &options,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::DeinterleaveOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
+                                PatternRewriter &rewriter) const override {
+    // Get target shape based on the result type (res1)
+    auto targetShape = getTargetShape(options, deinterleaveOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = deinterleaveOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = deinterleaveOp.getLoc();
+
+    // Bail out if target shape rank doesn't match result rank
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          deinterleaveOp, "target shape rank must match result rank");
+
+    // Deinterleave halves the trailing dimension: [2*N] -> [N]
+    // For source tile shape, double the trailing dimension of target shape
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    int64_t lastDim = sourceTileShape.size() - 1;
+    sourceTileShape[lastDim] = (*targetShape)[lastDim] * 2;
+
+    // Prepare the result vectors
+    Value result1 = arith::ConstantOp::create(rewriter, loc, resultType,
+                                              rewriter.getZeroAttr(resultType));
+    Value result2 = arith::ConstantOp::create(rewriter, loc, resultType,
+                                              rewriter.getZeroAttr(resultType));
+    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+    // Unroll the deinterleave
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Compute corresponding source offsets
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
+
+      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, deinterleaveOp.getSource(), sourceOffsets, sourceTileShape,
+          sourceStrides);
+
+      auto deinterleaveSlice =
+          vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);
+
+      result1 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, deinterleaveSlice.getRes1(), result1, resultOffsets,
+          resultStrides);
+      result2 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, deinterleaveSlice.getRes2(), result2, resultOffsets,
+          resultStrides);
+    }
+
+    rewriter.replaceOp(deinterleaveOp, ValueRange{result1, result2});
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -1400,8 +1608,10 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
                UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
-               UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
-      patterns.getContext(), options, benefit);
+               UnrollCreateMaskPattern, UnrollConstantMaskPattern,
+               UnrollBitCastPattern, UnrollInterleavePattern,
+               UnrollDeinterleavePattern>(patterns.getContext(), options,
+                                          benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 036d09053552d..358d39d4ff2dd 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -667,3 +667,75 @@ func.func @shape_cast_with_all_unit_target_shape(%v: vector<2xf32>) -> vector<2x
 // CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<1xf32> to vector<1x1xf32>
 // CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x1xf32>
 // CHECK:   return %[[I1]] : vector<2x1xf32>
+
+// -----
+
+// Test BitCastOp unrolling - target shape [4, 4]
+func.func @bitcast_unroll(%arg0: vector<8x4xf32>) -> vector<8x8xi16> {
+  %0 = vector.bitcast %arg0 : vector<8x4xf32> to vector<8x8xi16>
+  return %0 : vector<8x8xi16>
+}
+// CHECK-LABEL: func @bitcast_unroll
+// CHECK-SAME: (%[[ARG:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
+// CHECK:   %[[INIT:.*]] = arith.constant dense<0> : vector<8x8xi16>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK:   %[[BC0:.*]] = vector.bitcast %[[S0]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[BC0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK:   %[[BC1:.*]] = vector.bitcast %[[S1]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[BC1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK:   %[[S2:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK:   %[[BC2:.*]] = vector.bitcast %[[S2]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[BC2]], %[[I1]] {offsets = [4, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK:   %[[S3:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+// CHECK:   %[[BC3:.*]] = vector.bitcast %[[S3]] : vector<4x2xf32> to vector<4x4xi16>
+// CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[BC3]], %[[I2]] {offsets = [4, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
+// CHECK:   return %[[I3]] : vector<8x8xi16>
+
+// -----
+
+// Test InterleaveOp unrolling - target shape [8]
+func.func @interleave_unroll(%arg0: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
+  %0 = vector.interleave %arg0, %arg1 : vector<16xi32> -> vector<32xi32>
+  return %0 : vector<32xi32>
+}
+// CHECK-LABEL: func @interleave_unroll
+// CHECK-SAME: (%[[LHS:.*]]: vector<16xi32>, %[[RHS:.*]]: vector<16xi32>) -> vector<32xi32>
+// CHECK:   %[[INIT:.*]] = arith.constant dense<0> : vector<32xi32>
+// CHECK:   %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<4xi32> -> vector<8xi32>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK:   %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<4xi32> -> vector<8xi32>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [8], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK:   %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<4xi32> -> vector<8xi32>
+// CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [16], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK:   %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
+// CHECK:   %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<4xi32> -> vector<8xi32>
+// CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[INT3]], %[[I2]] {offsets = [24], strides = [1]} : vector<8xi32> into vector<32xi32>
+// CHECK:   return %[[I3]] : vector<32xi32>
+
+// -----
+
+// Test DeinterleaveOp unrolling - target shape [4]
+func.func @deinterleave_unroll(%arg0: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
+  %0, %1 = vector.deinterleave %arg0 : vector<16xi32> -> vector<8xi32>
+  return %0, %1 : vector<8xi32>, vector<8xi32>
+}
+// CHECK-LABEL: func @deinterleave_unroll
+// CHECK-SAME: (%[[ARG:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
+// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+// CHECK:   {{.*}} = vector.deinterleave %[[S0]] : vector<8xi32> -> vector<4xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+// CHECK:   {{.*}} = vector.deinterleave %[[S1]] : vector<8xi32> -> vector<4xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK:   return {{.*}}, {{.*}} : vector<8xi32>, vector<8xi32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ff3520a286cc8..fe31d6b3e9639 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -218,6 +218,24 @@ struct TestVectorUnrollingPatterns
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::TransposeOp>(op));
                       }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{4, 4})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::BitCastOp>(op));
+                      }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{8})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::InterleaveOp>(op));
+                      }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{4})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::DeinterleaveOp>(op));
+                      }));
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =

>From 25bd00ac3d82d5d1970ec1394f3656b18fbee7ff Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 28 Apr 2026 01:58:33 +0000
Subject: [PATCH 2/6] add comments

---
 .../Vector/Transforms/VectorUnroll.cpp        | 56 +++++++++++--------
 1 file changed, 32 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 58eccc9301248..b189f163d0660 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1389,6 +1389,11 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unroll vector::BitCastOp by tiling the result to the target shape and
+/// bitcasting matching source slices whose trailing dim is the target's scaled
+/// by resultBits / sourceBits. NOTE(review): this is integer division - confirm
+/// it is exact (e.g. i32 -> i8 with a trailing target dim of 2 yields 0).
+/// Example: v8f32 -> v16f16 with target shape [4] becomes 4 v2f32->v4f16 casts.
 struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
   UnrollBitCastPattern(MLIRContext *context,
                        const vector::UnrollVectorOptions &options,
@@ -1407,29 +1412,20 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
     ArrayRef<int64_t> resultShape = resultType.getShape();
     Location loc = bitCastOp.getLoc();
 
-    // Bail out if target shape rank doesn't match result rank
     if (targetShape->size() != resultShape.size())
       return rewriter.notifyMatchFailure(
           bitCastOp, "target shape rank must match result rank");
 
-    // BitCast changes element type, which may change the trailing dimension.
-    // For the source, deduce the tile shape from the result tile shape.
-    // The relationship: if result trailing dim is R and source is S,
-    // then resultBitWidth / R = sourceBitWidth / S (same bits per element).
-
     unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
     unsigned resultElementBits = resultType.getElementTypeBitWidth();
 
-    // Deduce source tile shape: same as target except the trailing dimension
     SmallVector<int64_t> sourceTileShape(targetShape->begin(),
                                          targetShape->end());
     int64_t lastDim = sourceTileShape.size() - 1;
 
-    // Scale the trailing dimension by the bitwidth ratio
     sourceTileShape[lastDim] =
         ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
 
-    // Prepare the result vector
     Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                              rewriter.getZeroAttr(resultType));
     SmallVector<int64_t> resultStrides(targetShape->size(), 1);
@@ -1438,10 +1434,8 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
     VectorType targetType =
         VectorType::get(*targetShape, resultType.getElementType());
 
-    // Unroll the bitcast
     for (SmallVector<int64_t> resultOffsets :
          StaticTileOffsetRange(resultShape, *targetShape)) {
-      // Compute corresponding source offsets
       SmallVector<int64_t> sourceOffsets = resultOffsets;
       sourceOffsets[lastDim] =
           (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
@@ -1463,6 +1457,18 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Pattern to unroll vector.interleave into smaller tile-sized operations.
+/// Each result tile of the target shape is produced by interleaving lhs/rhs
+/// slices whose trailing dim is half the target's (NOTE(review): assumes an
+/// even trailing target dim - an odd one would truncate; confirm).
+/// Example:
+///   vector.interleave %lhs, %rhs : vector<4xf32> -> vector<8xf32>
+///   // Unrolled with target shape [4]:
+///   %l0 = vector.extract_strided_slice %lhs[0] : vector<2xf32>
+///   %r0 = vector.extract_strided_slice %rhs[0] : vector<2xf32>
+///   %t0 = vector.interleave %l0, %r0 : vector<2xf32> -> vector<4xf32>
+///   %result = vector.insert_strided_slice %t0, %init[0]
+///   // ... repeat for remaining tiles
 struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
   UnrollInterleavePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
@@ -1480,19 +1486,15 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
     ArrayRef<int64_t> resultShape = resultType.getShape();
     Location loc = interleaveOp.getLoc();
 
-    // Bail out if target shape rank doesn't match result rank
     if (targetShape->size() != resultShape.size())
       return rewriter.notifyMatchFailure(
           interleaveOp, "target shape rank must match result rank");
 
-    // Interleave doubles the trailing dimension: [N] -> [2*N]
-    // For source tile shape, halve the trailing dimension of target shape
     SmallVector<int64_t> sourceTileShape(targetShape->begin(),
                                          targetShape->end());
     int64_t lastDim = sourceTileShape.size() - 1;
     sourceTileShape[lastDim] = (*targetShape)[lastDim] / 2;
 
-    // Prepare the result vector
     Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                              rewriter.getZeroAttr(resultType));
     SmallVector<int64_t> resultStrides(targetShape->size(), 1);
@@ -1501,10 +1503,8 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
     VectorType targetType =
         VectorType::get(*targetShape, resultType.getElementType());
 
-    // Unroll the interleave
     for (SmallVector<int64_t> resultOffsets :
          StaticTileOffsetRange(resultShape, *targetShape)) {
-      // Compute corresponding source offsets
       SmallVector<int64_t> sourceOffsets = resultOffsets;
       sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
 
@@ -1528,6 +1528,21 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Pattern to unroll vector.deinterleave into smaller tile-sized operations.
+/// Decomposes a large deinterleave (which splits a vector into its even- and
+/// odd-indexed elements) by extracting source slices of twice the target
+/// trailing dim, deinterleaving them, and inserting into two result vectors.
+/// The target shape applies to each result (res1/res2), not to the source.
+/// Example:
+///   %res1, %res2 = vector.deinterleave %src : vector<8xf32> -> vector<4xf32>
+///   // Result: %res1 = [src[0], src[2], src[4], src[6]]
+///   //         %res2 = [src[1], src[3], src[5], src[7]]
+///   // Unrolled with target shape [2]:
+///   %s0 = vector.extract_strided_slice %src[0] : vector<4xf32>
+///   %a0, %b0 = vector.deinterleave %s0 : vector<4xf32> -> vector<2xf32>
+///   %result1 = vector.insert_strided_slice %a0, %init1[0]
+///   %result2 = vector.insert_strided_slice %b0, %init2[0]
+///   // ... repeat for remaining tiles
 struct UnrollDeinterleavePattern
     : public OpRewritePattern<vector::DeinterleaveOp> {
   UnrollDeinterleavePattern(MLIRContext *context,
@@ -1538,7 +1553,6 @@ struct UnrollDeinterleavePattern
 
   LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
                                 PatternRewriter &rewriter) const override {
-    // Get target shape based on the result type (res1)
     auto targetShape = getTargetShape(options, deinterleaveOp);
     if (!targetShape)
       return failure();
@@ -1547,19 +1561,15 @@ struct UnrollDeinterleavePattern
     ArrayRef<int64_t> resultShape = resultType.getShape();
     Location loc = deinterleaveOp.getLoc();
 
-    // Bail out if target shape rank doesn't match result rank
     if (targetShape->size() != resultShape.size())
       return rewriter.notifyMatchFailure(
           deinterleaveOp, "target shape rank must match result rank");
 
-    // Deinterleave halves the trailing dimension: [2*N] -> [N]
-    // For source tile shape, double the trailing dimension of target shape
     SmallVector<int64_t> sourceTileShape(targetShape->begin(),
                                          targetShape->end());
     int64_t lastDim = sourceTileShape.size() - 1;
     sourceTileShape[lastDim] = (*targetShape)[lastDim] * 2;
 
-    // Prepare the result vectors
     Value result1 = arith::ConstantOp::create(rewriter, loc, resultType,
                                               rewriter.getZeroAttr(resultType));
     Value result2 = arith::ConstantOp::create(rewriter, loc, resultType,
@@ -1567,10 +1577,8 @@ struct UnrollDeinterleavePattern
     SmallVector<int64_t> resultStrides(targetShape->size(), 1);
     SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
 
-    // Unroll the deinterleave
     for (SmallVector<int64_t> resultOffsets :
          StaticTileOffsetRange(resultShape, *targetShape)) {
-      // Compute corresponding source offsets
       SmallVector<int64_t> sourceOffsets = resultOffsets;
       sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
 

>From 360400d3ac2786bfa82e814921e253a3130ed93d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 00:47:39 +0000
Subject: [PATCH 3/6] address feedback

---
 .../Vector/Transforms/VectorUnroll.cpp        | 84 ++++++++++---------
 .../Dialect/Vector/vector-unroll-options.mlir | 54 ++++++++----
 2 files changed, 83 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b189f163d0660..4d1d39cd9d61d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1405,7 +1405,8 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
                                 PatternRewriter &rewriter) const override {
     auto targetShape = getTargetShape(options, bitCastOp);
     if (!targetShape)
-      return failure();
+      return rewriter.notifyMatchFailure(bitCastOp,
+                                         "failed to get target shape");
 
     VectorType sourceType = bitCastOp.getSourceVectorType();
     VectorType resultType = bitCastOp.getResultVectorType();
@@ -1419,17 +1420,17 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
     unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
     unsigned resultElementBits = resultType.getElementTypeBitWidth();
 
-    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
-                                         targetShape->end());
-    int64_t lastDim = sourceTileShape.size() - 1;
+    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+                                          targetShape->end());
+    int64_t lastDim = sourceSliceShape.size() - 1;
 
-    sourceTileShape[lastDim] =
+    sourceSliceShape[lastDim] =
         ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
 
     Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                              rewriter.getZeroAttr(resultType));
     SmallVector<int64_t> resultStrides(targetShape->size(), 1);
-    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
 
     VectorType targetType =
         VectorType::get(*targetShape, resultType.getElementType());
@@ -1441,7 +1442,7 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
           (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
 
       Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-          loc, bitCastOp.getSource(), sourceOffsets, sourceTileShape,
+          loc, bitCastOp.getSource(), sourceOffsets, sourceSliceShape,
           sourceStrides);
       Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
           loc, targetType, sourceSlice);
@@ -1462,13 +1463,18 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
 /// input vectors, interleaving them, and inserting back into the result.
 ///
 /// Example:
-///   vector.interleave %lhs, %rhs : vector<8xf32>
-///   // Unrolled with target shape [4]:
-///   %slice_lhs_0 = vector.extract_strided_slice %lhs[0] : vector<2xf32>
-///   %slice_rhs_0 = vector.extract_strided_slice %rhs[0] : vector<2xf32>
-///   %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0 : vector<4xf32>
-///   %result = vector.insert_strided_slice %tile_0, %init[0]
-///   // ... repeat for remaining tiles
+///   Given an interleave Op:
+///
+///     vector.interleave %lhs, %rhs : vector<4x8xf32>
+///
+///   and a target unroll shape of <2x4>, the pattern produces:
+///
+///     %slice_lhs_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x2xf32>
+///     %slice_rhs_0 = vector.extract_strided_slice %rhs[0, 0] : vector<2x2xf32>
+///     %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
+///       : vector<2x4xf32>
+///     %result = vector.insert_strided_slice %tile_0, %init[0, 0]
+///     // ... repeat for remaining tiles
 struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
   UnrollInterleavePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
@@ -1480,7 +1486,8 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
                                 PatternRewriter &rewriter) const override {
     auto targetShape = getTargetShape(options, interleaveOp);
     if (!targetShape)
-      return failure();
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "failed to get target shape");
 
     VectorType resultType = interleaveOp.getResultVectorType();
     ArrayRef<int64_t> resultShape = resultType.getShape();
@@ -1490,15 +1497,15 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
       return rewriter.notifyMatchFailure(
           interleaveOp, "target shape rank must match result rank");
 
-    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
-                                         targetShape->end());
-    int64_t lastDim = sourceTileShape.size() - 1;
-    sourceTileShape[lastDim] = (*targetShape)[lastDim] / 2;
+    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+                                          targetShape->end());
+    int64_t lastDim = sourceSliceShape.size() - 1;
+    sourceSliceShape[lastDim] = (*targetShape)[lastDim] / 2;
 
     Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                              rewriter.getZeroAttr(resultType));
     SmallVector<int64_t> resultStrides(targetShape->size(), 1);
-    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
 
     VectorType targetType =
         VectorType::get(*targetShape, resultType.getElementType());
@@ -1509,10 +1516,10 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
       sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
 
       Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-          loc, interleaveOp.getLhs(), sourceOffsets, sourceTileShape,
+          loc, interleaveOp.getLhs(), sourceOffsets, sourceSliceShape,
           sourceStrides);
       Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-          loc, interleaveOp.getRhs(), sourceOffsets, sourceTileShape,
+          loc, interleaveOp.getRhs(), sourceOffsets, sourceSliceShape,
           sourceStrides);
       Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
           loc, targetType, lhsSlice, rhsSlice);
@@ -1555,7 +1562,8 @@ struct UnrollDeinterleavePattern
                                 PatternRewriter &rewriter) const override {
     auto targetShape = getTargetShape(options, deinterleaveOp);
     if (!targetShape)
-      return failure();
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "failed to get target shape");
 
     VectorType resultType = deinterleaveOp.getResultVectorType();
     ArrayRef<int64_t> resultShape = resultType.getShape();
@@ -1565,17 +1573,17 @@ struct UnrollDeinterleavePattern
       return rewriter.notifyMatchFailure(
           deinterleaveOp, "target shape rank must match result rank");
 
-    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
-                                         targetShape->end());
-    int64_t lastDim = sourceTileShape.size() - 1;
-    sourceTileShape[lastDim] = (*targetShape)[lastDim] * 2;
+    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+                                          targetShape->end());
+    int64_t lastDim = sourceSliceShape.size() - 1;
+    sourceSliceShape[lastDim] = (*targetShape)[lastDim] * 2;
 
-    Value result1 = arith::ConstantOp::create(rewriter, loc, resultType,
-                                              rewriter.getZeroAttr(resultType));
-    Value result2 = arith::ConstantOp::create(rewriter, loc, resultType,
-                                              rewriter.getZeroAttr(resultType));
+    Value resultOdd = arith::ConstantOp::create(
+        rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
+    Value resultEven = arith::ConstantOp::create(
+        rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
     SmallVector<int64_t> resultStrides(targetShape->size(), 1);
-    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
 
     for (SmallVector<int64_t> resultOffsets :
          StaticTileOffsetRange(resultShape, *targetShape)) {
@@ -1583,21 +1591,21 @@ struct UnrollDeinterleavePattern
       sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
 
       Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-          loc, deinterleaveOp.getSource(), sourceOffsets, sourceTileShape,
+          loc, deinterleaveOp.getSource(), sourceOffsets, sourceSliceShape,
           sourceStrides);
 
       auto deinterleaveSlice =
           vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);
 
-      result1 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, deinterleaveSlice.getRes1(), result1, resultOffsets,
+      resultOdd = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, deinterleaveSlice.getRes1(), resultOdd, resultOffsets,
           resultStrides);
-      result2 = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, deinterleaveSlice.getRes2(), result2, resultOffsets,
+      resultEven = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, deinterleaveSlice.getRes2(), resultEven, resultOffsets,
           resultStrides);
     }
 
-    rewriter.replaceOp(deinterleaveOp, ValueRange{result1, result2});
+    rewriter.replaceOp(deinterleaveOp, ValueRange{resultOdd, resultEven});
     return success();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 358d39d4ff2dd..16637eacd5b95 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -671,23 +671,31 @@ func.func @shape_cast_with_all_unit_target_shape(%v: vector<2xf32>) -> vector<2x
 // -----
 
 // Test BitCastOp unrolling - target shape [4, 4]
-func.func @bitcast_unroll(%arg0: vector<8x4xf32>) -> vector<8x8xi16> {
-  %0 = vector.bitcast %arg0 : vector<8x4xf32> to vector<8x8xi16>
+func.func @bitcast_2d(%v: vector<8x4xf32>) -> vector<8x8xi16> {
+  %0 = vector.bitcast %v : vector<8x4xf32> to vector<8x8xi16>
   return %0 : vector<8x8xi16>
 }
-// CHECK-LABEL: func @bitcast_unroll
-// CHECK-SAME: (%[[ARG:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
+// CHECK-LABEL: func @bitcast_2d
+// CHECK-SAME: (%[[V:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
 // CHECK:   %[[INIT:.*]] = arith.constant dense<0> : vector<8x8xi16>
-// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+/// SLICE 0:
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC0:.*]] = vector.bitcast %[[S0]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[BC0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
-// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+/// SLICE 1:
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC1:.*]] = vector.bitcast %[[S1]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[BC1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
-// CHECK:   %[[S2:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+/// SLICE 2:
+// CHECK:   %[[S2:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC2:.*]] = vector.bitcast %[[S2]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[BC2]], %[[I1]] {offsets = [4, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
-// CHECK:   %[[S3:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
+//
+// SLICE 3:
+// CHECK:   %[[S3:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC3:.*]] = vector.bitcast %[[S3]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[BC3]], %[[I2]] {offsets = [4, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
 // CHECK:   return %[[I3]] : vector<8x8xi16>
@@ -695,25 +703,33 @@ func.func @bitcast_unroll(%arg0: vector<8x4xf32>) -> vector<8x8xi16> {
 // -----
 
 // Test InterleaveOp unrolling - target shape [8]
-func.func @interleave_unroll(%arg0: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
-  %0 = vector.interleave %arg0, %arg1 : vector<16xi32> -> vector<32xi32>
+func.func @interleave_1d(%V: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
+  %0 = vector.interleave %V, %arg1 : vector<16xi32> -> vector<32xi32>
   return %0 : vector<32xi32>
 }
-// CHECK-LABEL: func @interleave_unroll
+// CHECK-LABEL: func @interleave_1d
 // CHECK-SAME: (%[[LHS:.*]]: vector<16xi32>, %[[RHS:.*]]: vector<16xi32>) -> vector<32xi32>
 // CHECK:   %[[INIT:.*]] = arith.constant dense<0> : vector<32xi32>
+//
+/// SLICE 0:
 // CHECK:   %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<4xi32> -> vector<8xi32>
 // CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<8xi32> into vector<32xi32>
+//
+/// SLICE 1:
 // CHECK:   %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<4xi32> -> vector<8xi32>
 // CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [8], strides = [1]} : vector<8xi32> into vector<32xi32>
+//
+/// SLICE 2:
 // CHECK:   %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<4xi32> -> vector<8xi32>
 // CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [16], strides = [1]} : vector<8xi32> into vector<32xi32>
+//
+/// SLICE 3:
 // CHECK:   %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
 // CHECK:   %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<4xi32> -> vector<8xi32>
@@ -723,18 +739,22 @@ func.func @interleave_unroll(%arg0: vector<16xi32>, %arg1: vector<16xi32>) -> ve
 // -----
 
 // Test DeinterleaveOp unrolling - target shape [4]
-func.func @deinterleave_unroll(%arg0: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
-  %0, %1 = vector.deinterleave %arg0 : vector<16xi32> -> vector<8xi32>
+func.func @deinterleave_1d(%V: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
+  %0, %1 = vector.deinterleave %v : vector<16xi32> -> vector<8xi32>
   return %0, %1 : vector<8xi32>, vector<8xi32>
 }
-// CHECK-LABEL: func @deinterleave_unroll
-// CHECK-SAME: (%[[ARG:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
+// CHECK-LABEL: func @deinterleave_1d
+// CHECK-SAME: (%[[V:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
 // CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+//
+/// SLICE 0:
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
 // CHECK:   {{.*}} = vector.deinterleave %[[S0]] : vector<8xi32> -> vector<4xi32>
 // CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
 // CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
+//
+/// SLICE 1:
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
 // CHECK:   {{.*}} = vector.deinterleave %[[S1]] : vector<8xi32> -> vector<4xi32>
 // CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
 // CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>

>From 7c8150e313525b64d44bf3a09de48abdd40d19c7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 00:51:51 +0000
Subject: [PATCH 4/6] fix test

---
 mlir/test/Dialect/Vector/vector-unroll-options.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 16637eacd5b95..b1a0c4211f5d1 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -739,7 +739,7 @@ func.func @interleave_1d(%V: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32
 // -----
 
 // Test DeinterleaveOp unrolling - target shape [4]
-func.func @deinterleave_1d(%V: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
+func.func @deinterleave_1d(%v: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
   %0, %1 = vector.deinterleave %v : vector<16xi32> -> vector<8xi32>
   return %0, %1 : vector<8xi32>, vector<8xi32>
 }

>From d1a859adf7f55506f6c766bb1ba8ba4af2c04976 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 18:25:19 +0000
Subject: [PATCH 5/6] address feedback and improve tests

---
 .../Vector/Transforms/VectorUnroll.cpp        | 48 ++++++----
 .../Dialect/Vector/vector-unroll-options.mlir | 96 +++++++++----------
 .../Dialect/Vector/TestVectorTransforms.cpp   |  4 +-
 3 files changed, 80 insertions(+), 68 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 4d1d39cd9d61d..acf05a00872d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1389,11 +1389,20 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
   vector::UnrollVectorOptions options;
 };
 
-// Unroll vector::BitCastOp into smaller tile-based bitcast operations.
+// Unroll vector::BitCastOp into smaller slice-based bitcast operations.
 // Tiles the result vector into target shape chunks and bitcasts corresponding
 // source slices, accounting for element bitwidth ratios.
-// Example: bitcast v8f32 to v16f16 with target shape [4] unrolls into
-// multiple bitcast operations on 4-element tiles.
+/// Example:
+///   Given a deinterleave Op:
+///
+///     vector.bitcast %src : vector<4x8xf32>
+///
+///   and a target unroll shape of <2x4>, the pattern produces:
+///
+///     %slice_0 = vector.extract_strided_slice %src[0, 0] : vector<2x4xf32>
+///     %cast_0 = vector.bitcast %slice_0 : vector<2x4xf32> to vector<2x4xi32>
+///     %result = vector.insert_strided_slice %cast_0, %init[0, 0]
+///     // ... repeat for remaining slices
 struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
   UnrollBitCastPattern(MLIRContext *context,
                        const vector::UnrollVectorOptions &options,
@@ -1458,8 +1467,8 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
   vector::UnrollVectorOptions options;
 };
 
-/// Pattern to unroll vector.interleave into smaller tile-sized operations.
-/// Decomposes a large interleave into tiles by extracting slices from both
+/// Pattern to unroll vector.interleave into smaller slice-sized operations.
+/// Decomposes a large interleave into slices by extracting slices from both
 /// input vectors, interleaving them, and inserting back into the result.
 ///
 /// Example:
@@ -1471,10 +1480,10 @@ struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
 ///
 ///     %slice_lhs_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x2xf32>
 ///     %slice_rhs_0 = vector.extract_strided_slice %rhs[0, 0] : vector<2x2xf32>
-///     %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
+///     %slice_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
 ///       : vector<2x4xf32>
-///     %result = vector.insert_strided_slice %tile_0, %init[0, 0]
-///     // ... repeat for remaining tiles
+///     %result = vector.insert_strided_slice %slice_0, %init[0, 0]
+///     // ... repeat for remaining slices
 struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
   UnrollInterleavePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
@@ -1535,21 +1544,24 @@ struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
   vector::UnrollVectorOptions options;
 };
 
-/// Pattern to unroll vector.deinterleave into smaller tile-sized operations.
+/// Pattern to unroll vector.deinterleave into smaller slice-sized operations.
 /// Decomposes a large deinterleave (which splits a vector into even/odd halves)
 /// by extracting source slices, deinterleaving them, and inserting into two
 /// result vectors.
 ///
 /// Example:
-///   %res1, %res2 = vector.deinterleave %src : vector<8xf32>
-///   // Result: %res1 = [src[0], src[2], src[4], src[6]]
-///   //         %res2 = [src[1], src[3], src[5], src[7]]
-///   // Unrolled with target shape [2]:
-///   %slice_0 = vector.extract_strided_slice %src[0] : vector<4xf32>
-///   %tile1_0, %tile2_0 = vector.deinterleave %slice_0 : vector<2xf32>
-///   %result1 = vector.insert_strided_slice %tile1_0, %init1[0]
-///   %result2 = vector.insert_strided_slice %tile2_0, %init2[0]
-///   // ... repeat for remaining tiles
+///   Given a deinterleave Op:
+///
+///     vector.deinterleave %src : vector<4x8xf32>
+///
+///   and a target unroll shape of <2x4>, the pattern produces:
+///
+///     %slice_0 = vector.extract_strided_slice %src[0, 0] : vector<2x8xf32>
+///     %res1_0, %res2_0 = vector.deinterleave %slice_0
+///       : vector<2x8xf32> -> vector<2x4xf32>
+///     %result1 = vector.insert_strided_slice %res1_0, %init1[0, 0]
+///     %result2 = vector.insert_strided_slice %res2_0, %init2[0, 0]
+///     // ... repeat for remaining slices
 struct UnrollDeinterleavePattern
     : public OpRewritePattern<vector::DeinterleaveOp> {
   UnrollDeinterleavePattern(MLIRContext *context,
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index b1a0c4211f5d1..bb6fc4e38813d 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -679,22 +679,22 @@ func.func @bitcast_2d(%v: vector<8x4xf32>) -> vector<8x8xi16> {
 // CHECK-SAME: (%[[V:.*]]: vector<8x4xf32>) -> vector<8x8xi16>
 // CHECK:   %[[INIT:.*]] = arith.constant dense<0> : vector<8x8xi16>
 //
-/// SLICE 0:
+/// SLICE 0,0:
 // CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC0:.*]] = vector.bitcast %[[S0]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[BC0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
 //
-/// SLICE 1:
+/// SLICE 0,1:
 // CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC1:.*]] = vector.bitcast %[[S1]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[BC1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
 //
-/// SLICE 2:
+/// SLICE 1,0:
 // CHECK:   %[[S2:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC2:.*]] = vector.bitcast %[[S2]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[BC2]], %[[I1]] {offsets = [4, 0], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
 //
-// SLICE 3:
+/// SLICE 1,1:
 // CHECK:   %[[S3:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 2], sizes = [4, 2], strides = [1, 1]} : vector<8x4xf32> to vector<4x2xf32>
 // CHECK:   %[[BC3:.*]] = vector.bitcast %[[S3]] : vector<4x2xf32> to vector<4x4xi16>
 // CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[BC3]], %[[I2]] {offsets = [4, 4], strides = [1, 1]} : vector<4x4xi16> into vector<8x8xi16>
@@ -702,60 +702,60 @@ func.func @bitcast_2d(%v: vector<8x4xf32>) -> vector<8x8xi16> {
 
 // -----
 
-// Test InterleaveOp unrolling - target shape [8]
-func.func @interleave_1d(%V: vector<16xi32>, %arg1: vector<16xi32>) -> vector<32xi32> {
-  %0 = vector.interleave %V, %arg1 : vector<16xi32> -> vector<32xi32>
-  return %0 : vector<32xi32>
+// Test InterleaveOp unrolling - target shape [2x4]
+func.func @interleave_2d(%V: vector<4x4xi32>, %arg1: vector<4x4xi32>) -> vector<4x8xi32> {
+  %0 = vector.interleave %V, %arg1 : vector<4x4xi32> -> vector<4x8xi32>
+  return %0 : vector<4x8xi32>
 }
-// CHECK-LABEL: func @interleave_1d
-// CHECK-SAME: (%[[LHS:.*]]: vector<16xi32>, %[[RHS:.*]]: vector<16xi32>) -> vector<32xi32>
-// CHECK:   %[[INIT:.*]] = arith.constant dense<0> : vector<32xi32>
+// CHECK-LABEL: func @interleave_2d
+// CHECK-SAME: (%[[LHS:.*]]: vector<4x4xi32>, %[[RHS:.*]]: vector<4x4xi32>) -> vector<4x8xi32>
+// CHECK:   %[[INIT:.*]] = arith.constant dense<0> : vector<4x8xi32>
 //
-/// SLICE 0:
-// CHECK:   %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<4xi32> -> vector<8xi32>
-// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<8xi32> into vector<32xi32>
+/// SLICE 0,0:
+// CHECK:   %[[L0:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[R0:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[INT0:.*]] = vector.interleave %[[L0]], %[[R0]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[INT0]], %[[INIT]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
 //
-/// SLICE 1:
-// CHECK:   %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<4xi32> -> vector<8xi32>
-// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [8], strides = [1]} : vector<8xi32> into vector<32xi32>
+/// SLICE 0,1:
+// CHECK:   %[[L1:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[R1:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[INT1:.*]] = vector.interleave %[[L1]], %[[R1]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[INT1]], %[[I0]] {offsets = [0, 4], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
 //
-/// SLICE 2:
-// CHECK:   %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<4xi32> -> vector<8xi32>
-// CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [16], strides = [1]} : vector<8xi32> into vector<32xi32>
+/// SLICE 1,0:
+// CHECK:   %[[L2:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[R2:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[INT2:.*]] = vector.interleave %[[L2]], %[[R2]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[INT2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
 //
-/// SLICE 3:
-// CHECK:   %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi32> to vector<4xi32>
-// CHECK:   %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<4xi32> -> vector<8xi32>
-// CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[INT3]], %[[I2]] {offsets = [24], strides = [1]} : vector<8xi32> into vector<32xi32>
-// CHECK:   return %[[I3]] : vector<32xi32>
+/// SLICE 1,1:
+// CHECK:   %[[L3:.*]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[R3:.*]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:   %[[INT3:.*]] = vector.interleave %[[L3]], %[[R3]] : vector<2x2xi32> -> vector<2x4xi32>
+// CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[INT3]], %[[I2]] {offsets = [2, 4], strides = [1, 1]} : vector<2x4xi32> into vector<4x8xi32>
+// CHECK:   return %[[I3]] : vector<4x8xi32>
 
 // -----
 
-// Test DeinterleaveOp unrolling - target shape [4]
-func.func @deinterleave_1d(%v: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
-  %0, %1 = vector.deinterleave %v : vector<16xi32> -> vector<8xi32>
-  return %0, %1 : vector<8xi32>, vector<8xi32>
+// Test DeinterleaveOp unrolling - target shape [2x4]
+func.func @deinterleave_2d(%v: vector<4x8xi32>) -> (vector<4x4xi32>, vector<4x4xi32>) {
+  %0, %1 = vector.deinterleave %v : vector<4x8xi32> -> vector<4x4xi32>
+  return %0, %1 : vector<4x4xi32>, vector<4x4xi32>
 }
-// CHECK-LABEL: func @deinterleave_1d
-// CHECK-SAME: (%[[V:.*]]: vector<16xi32>) -> (vector<8xi32>, vector<8xi32>)
-// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK-LABEL: func @deinterleave_2d
+// CHECK-SAME: (%[[V:.*]]: vector<4x8xi32>) -> (vector<4x4xi32>, vector<4x4xi32>)
+// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
 //
 /// SLICE 0:
-// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
-// CHECK:   {{.*}} = vector.deinterleave %[[S0]] : vector<8xi32> -> vector<4xi32>
-// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi32> to vector<2x8xi32>
+// CHECK:   {{.*}} = vector.deinterleave %[[S0]] : vector<2x8xi32> -> vector<2x4xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
 //
 /// SLICE 1:
-// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xi32> to vector<8xi32>
-// CHECK:   {{.*}} = vector.deinterleave %[[S1]] : vector<8xi32> -> vector<4xi32>
-// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
-// CHECK:   return {{.*}}, {{.*}} : vector<8xi32>, vector<8xi32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi32> to vector<2x8xi32>
+// CHECK:   {{.*}} = vector.deinterleave %[[S1]] : vector<2x8xi32> -> vector<2x4xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
+// CHECK:   {{.*}} = vector.insert_strided_slice {{.*}}, {{.*}} {offsets = [2, 0], strides = [1, 1]} : vector<2x4xi32> into vector<4x4xi32>
+// CHECK:   return {{.*}}, {{.*}} : vector<4x4xi32>, vector<4x4xi32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index fe31d6b3e9639..043181c16c759 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -226,13 +226,13 @@ struct TestVectorUnrollingPatterns
                       }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{8})
+                      .setNativeShape(ArrayRef<int64_t>{2, 4})
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::InterleaveOp>(op));
                       }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{4})
+                      .setNativeShape(ArrayRef<int64_t>{2, 4})
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::DeinterleaveOp>(op));
                       }));

>From 2d52306205ab8fa5a1e3976eb4321812a5c8e6e7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 30 Apr 2026 22:51:05 +0000
Subject: [PATCH 6/6] fix comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index acf05a00872d7..25d2e2c578441 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1390,10 +1390,10 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
 };
 
 // Unroll vector::BitCastOp into smaller slice-based bitcast operations.
-// Tiles the result vector into target shape chunks and bitcasts corresponding
-// source slices, accounting for element bitwidth ratios.
+// Decomposes the result vector into target shape chunks and bitcasts
+// corresponding source slices, accounting for element bitwidth ratios.
 /// Example:
-///   Given a deinterleave Op:
+///   Given a bitcast Op:
 ///
 ///     vector.bitcast %src : vector<4x8xf32>
 ///



More information about the Mlir-commits mailing list