[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling support for bitcast, interleave, and deinterleave ops (PR #194513)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Apr 29 02:08:10 PDT 2026
================
@@ -1389,6 +1389,230 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
vector::UnrollVectorOptions options;
};
+/// Pattern to unroll vector.bitcast into smaller tile-sized operations.
+/// Tiles the result vector into target-shape chunks and bitcasts the
+/// corresponding source slices. Only the innermost dimension differs between
+/// a result tile and its source slice: it is scaled by the ratio of the result
+/// and source element bitwidths.
+///
+/// Example:
+///   Given a bitcast Op:
+///
+///     vector.bitcast %src : vector<8xf32> to vector<16xf16>
+///
+///   and a target unroll shape of <4>, the pattern produces:
+///
+///     %slice_0 = vector.extract_strided_slice %src[0] : vector<2xf32>
+///     %tile_0 = vector.bitcast %slice_0 : vector<2xf32> to vector<4xf16>
+///     %result = vector.insert_strided_slice %tile_0, %init[0]
+///     // ... repeat for remaining tiles
+///
+/// The pattern fails to match when no target shape is available, when the
+/// target shape rank differs from the result rank, when the result is 0-D, or
+/// when a result tile does not cover a whole number of source elements.
+struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
+  UnrollBitCastPattern(MLIRContext *context,
+                       const vector::UnrollVectorOptions &options,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BitCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, bitCastOp);
+    if (!targetShape)
+      return rewriter.notifyMatchFailure(bitCastOp,
+                                         "failed to get target shape");
+
+    VectorType sourceType = bitCastOp.getSourceVectorType();
+    VectorType resultType = bitCastOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = bitCastOp.getLoc();
+
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "target shape rank must match result rank");
+
+    // There is no innermost dimension to rescale for a 0-D vector; bail out
+    // instead of indexing the shape at position -1 below.
+    if (targetShape->empty())
+      return rewriter.notifyMatchFailure(bitCastOp,
+                                         "0-D vectors cannot be unrolled");
+
+    unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
+    unsigned resultElementBits = resultType.getElementTypeBitWidth();
+    int64_t lastDim = targetShape->size() - 1;
+
+    // A result tile must map onto a whole number of source elements.
+    // Otherwise the integer divisions below would truncate and yield a source
+    // slice that is too small (possibly empty) for the requested tile. All
+    // match-failure checks happen before any IR is created.
+    int64_t targetInnerBits = (*targetShape)[lastDim] * resultElementBits;
+    if (targetInnerBits % sourceElementBits != 0)
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "innermost target tile must span a whole number of "
+                     "source elements");
+
+    // The source slice has the same shape as the result tile, except for the
+    // innermost dimension, which is rescaled by the element bitwidth ratio.
+    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+                                          targetShape->end());
+    sourceSliceShape[lastDim] = targetInnerBits / sourceElementBits;
+
+    // Zero-initialized accumulator that each bitcast tile is inserted into.
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
+
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Translate the innermost result offset into source element units.
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] =
+          (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
+
+      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, bitCastOp.getSource(), sourceOffsets, sourceSliceShape,
+          sourceStrides);
+      Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
+          loc, targetType, sourceSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, bitcastSlice, result, resultOffsets, resultStrides);
+    }
+
+    rewriter.replaceOp(bitCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Pattern to unroll vector.interleave into smaller tile-sized operations.
+/// Decomposes a large interleave into tiles by extracting slices from both
+/// input vectors, interleaving them, and inserting back into the result.
+/// Each source slice has half the innermost extent of its result tile, so the
+/// innermost target dimension must be even.
+///
+/// Example:
+///   Given an interleave Op:
+///
+///     vector.interleave %lhs, %rhs : vector<4x8xf32>
+///
+///   and a target unroll shape of <2x4>, the pattern produces:
+///
+///     %slice_lhs_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x2xf32>
+///     %slice_rhs_0 = vector.extract_strided_slice %rhs[0, 0] : vector<2x2xf32>
+///     %tile_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
+///         : vector<2x4xf32>
+///     %result = vector.insert_strided_slice %tile_0, %init[0, 0]
+///     // ... repeat for remaining tiles
+///
+/// The pattern fails to match when no target shape is available, when the
+/// target shape rank differs from the result rank, when the target shape is
+/// empty, or when the innermost target dimension is odd.
+struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
+  UnrollInterleavePattern(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::InterleaveOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, interleaveOp);
+    if (!targetShape)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "failed to get target shape");
+
+    VectorType resultType = interleaveOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = interleaveOp.getLoc();
+
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          interleaveOp, "target shape rank must match result rank");
+
+    // Every result tile is built from two equally sized source slices, so an
+    // odd (or missing) innermost tile size cannot be produced. Halving it
+    // below would truncate and the interleaved slice would no longer match
+    // the tile type. All match-failure checks happen before any IR is created.
+    if (targetShape->empty() || targetShape->back() % 2 != 0)
+      return rewriter.notifyMatchFailure(
+          interleaveOp, "innermost target dimension must be even");
+
+    // Source slices match the result tile except for the innermost dimension,
+    // which is half as large.
+    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
+                                          targetShape->end());
+    int64_t lastDim = sourceSliceShape.size() - 1;
+    sourceSliceShape[lastDim] = (*targetShape)[lastDim] / 2;
+
+    // Zero-initialized accumulator that each interleaved tile is inserted
+    // into.
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
+
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Translate the innermost result offset into source element units.
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
+
+      Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, interleaveOp.getLhs(), sourceOffsets, sourceSliceShape,
+          sourceStrides);
+      Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, interleaveOp.getRhs(), sourceOffsets, sourceSliceShape,
+          sourceStrides);
+      Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
+          loc, targetType, lhsSlice, rhsSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, interleaveSlice, result, resultOffsets, resultStrides);
+    }
+
+    rewriter.replaceOp(interleaveOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Pattern to unroll vector.deinterleave into smaller tile-sized operations.
+/// Decomposes a large deinterleave (which splits a vector into even/odd halves)
+/// by extracting source slices, deinterleaving them, and inserting into two
+/// result vectors.
+///
+/// Example:
+/// %res1, %res2 = vector.deinterleave %src : vector<8xf32>
+/// // Result: %res1 = [src[0], src[2], src[4], src[6]]
+/// // %res2 = [src[1], src[3], src[5], src[7]]
+/// // Unrolled with target shape [2]:
+/// %slice_0 = vector.extract_strided_slice %src[0] : vector<4xf32>
+/// %tile1_0, %tile2_0 = vector.deinterleave %slice_0 : vector<2xf32>
+/// %result1 = vector.insert_strided_slice %tile1_0, %init1[0]
+/// %result2 = vector.insert_strided_slice %tile2_0, %init2[0]
+/// // ... repeat for remaining tiles
----------------
banach-space wrote:
[nit] Please reformat to match `UnrollInterleavePattern`
https://github.com/llvm/llvm-project/pull/194513
More information about the Mlir-commits
mailing list