[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling support for bitcast, interleave, and deinterleave ops (PR #194513)

Jianhui Li llvmlistbot at llvm.org
Tue Apr 28 17:49:02 PDT 2026


================
@@ -1389,6 +1389,222 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unroll vector::BitCastOp into smaller tile-based bitcast operations.
+/// Tiles the result vector into target-shape chunks and bitcasts the
+/// corresponding source slices. Only the innermost dimension changes size
+/// across a bitcast, so only that dimension of the source tile (and of each
+/// source offset) is rescaled by the ratio of result to source element
+/// bitwidths; all outer dimensions are copied from the result tile.
+///
+/// Example: bitcast vector<8xf32> to vector<16xf16> with target shape [4]
+/// unrolls into four bitcasts of vector<2xf32> to vector<4xf16>.
+///
+/// The pattern fails to match when the innermost target tile does not cover a
+/// whole number of source elements (e.g. vector<4xi32> to vector<16xi8> with
+/// target shape [2]), because such a tile is not a bitcast of any source
+/// slice.
+struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
+  UnrollBitCastPattern(MLIRContext *context,
+                       const vector::UnrollVectorOptions &options,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BitCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, bitCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = bitCastOp.getSourceVectorType();
+    VectorType resultType = bitCastOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = bitCastOp.getLoc();
+
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "target shape rank must match result rank");
+    // The innermost dimension is rescaled below; a 0-D shape has none.
+    if (targetShape->empty())
+      return rewriter.notifyMatchFailure(bitCastOp,
+                                         "cannot unroll a 0-D bitcast");
+
+    int64_t sourceElementBits = sourceType.getElementTypeBitWidth();
+    int64_t resultElementBits = resultType.getElementTypeBitWidth();
+    int64_t lastDim = targetShape->size() - 1;
+
+    // Bits covered by one result tile along the innermost dimension. They must
+    // map onto a whole number of source elements; otherwise the integer
+    // division below would silently truncate the source tile (possibly to
+    // zero elements) and misplace the source offsets.
+    int64_t tileBits = (*targetShape)[lastDim] * resultElementBits;
+    if (tileBits % sourceElementBits != 0)
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "innermost target tile must cover a whole number of "
+                     "source elements");
+
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    sourceTileShape[lastDim] = tileBits / sourceElementBits;
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    // All slices are contiguous; source and result tiles have the same rank.
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Rescale the innermost offset from result elements to source elements.
+      // Result offsets step by the innermost tile size, whose bit count was
+      // checked above to be divisible by the source element width, so this
+      // division is exact.
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] =
+          (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
+
+      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, bitCastOp.getSource(), sourceOffsets, sourceTileShape, strides);
+      Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
+          loc, targetType, sourceSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, bitcastSlice, result, resultOffsets, strides);
+    }
+
+    rewriter.replaceOp(bitCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Pattern to unroll vector.interleave into smaller tile-sized operations.
+/// Decomposes a large interleave into tiles by extracting matching slices from
+/// both input vectors, interleaving them, and inserting each interleaved tile
+/// back into the result. A result tile of N elements along the innermost
+/// dimension consumes N/2 elements from each operand, starting at half the
+/// tile's result offset.
+///
+/// Example:
+///   %r = vector.interleave %lhs, %rhs : vector<8xf32> -> vector<16xf32>
+///   // Unrolled with target shape [4]:
+///   %lhs_0 = vector.extract_strided_slice %lhs
+///       {offsets = [0], sizes = [2], strides = [1]}
+///       : vector<8xf32> to vector<2xf32>
+///   %rhs_0 = vector.extract_strided_slice %rhs
+///       {offsets = [0], sizes = [2], strides = [1]}
+///       : vector<8xf32> to vector<2xf32>
+///   %tile_0 = vector.interleave %lhs_0, %rhs_0
+///       : vector<2xf32> -> vector<4xf32>
+///   %acc_0 = vector.insert_strided_slice %tile_0, %init
+///       {offsets = [0], strides = [1]} : vector<4xf32> into vector<16xf32>
+///   // ... repeat for the remaining tiles
+///
+/// The innermost target dimension must be even; otherwise the pattern fails
+/// to match.
+struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
+  UnrollInterleavePattern(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::InterleaveOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, interleaveOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = interleaveOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = interleaveOp.getLoc();
+
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          interleaveOp, "target shape rank must match result rank");
+
+    int64_t lastDim = targetShape->size() - 1;
+    // Every result tile is built from two equally sized operand tiles, so the
+    // innermost tile size must be even. An odd size would be silently
+    // truncated by the halving below (down to an empty slice for size 1).
+    if ((*targetShape)[lastDim] % 2 != 0)
+      return rewriter.notifyMatchFailure(
+          interleaveOp, "innermost target dimension must be even");
+
+    // Operand tiles match the result tile except along the innermost
+    // dimension, where each operand contributes half of the elements.
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    sourceTileShape[lastDim] /= 2;
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    // All slices are contiguous; operand and result tiles have the same rank.
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Result element i along the innermost dimension comes from operand
+      // element i / 2, so each operand tile starts at half the result offset.
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] /= 2;
+
+      Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, interleaveOp.getLhs(), sourceOffsets, sourceTileShape, strides);
+      Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, interleaveOp.getRhs(), sourceOffsets, sourceTileShape, strides);
+      Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
+          loc, targetType, lhsSlice, rhsSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, interleaveSlice, result, resultOffsets, strides);
+    }
+
+    rewriter.replaceOp(interleaveOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Pattern to unroll vector.deinterleave into smaller tile-sized operations.
+/// Decomposes a large deinterleave (which splits a vector into even/odd halves)
+/// by extracting source slices, deinterleaving them, and inserting into two
+/// result vectors.
+///
+/// Example:
+///   %res1, %res2 = vector.deinterleave %src : vector<8xf32>
+///   // Result: %res1 = [src[0], src[2], src[4], src[6]]
+///   //         %res2 = [src[1], src[3], src[5], src[7]]
+///   // Unrolled with target shape [2]:
+///   %slice_0 = vector.extract_strided_slice %src[0] : vector<4xf32>
+///   %tile1_0, %tile2_0 = vector.deinterleave %slice_0 : vector<2xf32>
+///   %result1 = vector.insert_strided_slice %tile1_0, %init1[0]
+///   %result2 = vector.insert_strided_slice %tile2_0, %init2[0]
+///   // ... repeat for remaining tiles
+struct UnrollDeinterleavePattern
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+  UnrollDeinterleavePattern(MLIRContext *context,
+                            const vector::UnrollVectorOptions &options,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::DeinterleaveOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, deinterleaveOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = deinterleaveOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = deinterleaveOp.getLoc();
+
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          deinterleaveOp, "target shape rank must match result rank");
+
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    int64_t lastDim = sourceTileShape.size() - 1;
+    sourceTileShape[lastDim] = (*targetShape)[lastDim] * 2;
+
+    Value result1 = arith::ConstantOp::create(rewriter, loc, resultType,
+                                              rewriter.getZeroAttr(resultType));
+    Value result2 = arith::ConstantOp::create(rewriter, loc, resultType,
+                                              rewriter.getZeroAttr(resultType));
----------------
Jianhui-Li wrote:

Modified. Thanks!

https://github.com/llvm/llvm-project/pull/194513


More information about the Mlir-commits mailing list