[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling support for bitcast, interleave, and deinterleave ops (PR #194513)
Jianhui Li
llvmlistbot at llvm.org
Tue Apr 28 17:48:51 PDT 2026
================
@@ -1389,6 +1389,222 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
vector::UnrollVectorOptions options;
};
+/// Unroll vector::BitCastOp into smaller tile-based bitcast operations.
+/// Tiles the result vector into target-shape chunks and bitcasts the
+/// corresponding source slices. Only the innermost dimension is rescaled, by
+/// the ratio of the result and source element bitwidths; all outer dimensions
+/// and their offsets are shared between source and result.
+///
+/// Example: bitcast vector<8xf32> to vector<16xf16> with target shape [4]
+/// unrolls into four bitcasts of vector<2xf32> to vector<4xf16>.
+///
+/// The pattern fails to match when no target shape is available, when the
+/// target rank differs from the result rank, for 0-D vectors, or when a result
+/// tile does not cover a whole number of source elements.
+struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
+  UnrollBitCastPattern(MLIRContext *context,
+                       const vector::UnrollVectorOptions &options,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BitCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, bitCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = bitCastOp.getSourceVectorType();
+    VectorType resultType = bitCastOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = bitCastOp.getLoc();
+
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "target shape rank must match result rank");
+
+    // The innermost dimension is rescaled below; a 0-D vector has none, and
+    // indexing "rank - 1" would read out of bounds.
+    if (resultShape.empty())
+      return rewriter.notifyMatchFailure(bitCastOp,
+                                         "cannot unroll 0-D bitcast");
+
+    int64_t sourceElementBits = sourceType.getElementTypeBitWidth();
+    int64_t resultElementBits = resultType.getElementTypeBitWidth();
+    size_t lastDim = resultShape.size() - 1;
+
+    // Each result tile must map onto a whole number of source elements.
+    // Otherwise the division below truncates, possibly to a 0-sized source
+    // slice, and the per-tile bitcast would be invalid.
+    int64_t tileBits = (*targetShape)[lastDim] * resultElementBits;
+    if (tileBits % sourceElementBits != 0)
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "target tile size is not a multiple of the source "
+                     "element bitwidth");
+
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    sourceTileShape[lastDim] = tileBits / sourceElementBits;
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Result offsets along the last dim are multiples of the tile size, so
+      // their bit offset is a multiple of tileBits and divides exactly.
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] =
+          (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
+
+      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, bitCastOp.getSource(), sourceOffsets, sourceTileShape,
+          sourceStrides);
+      Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
+          loc, targetType, sourceSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, bitcastSlice, result, resultOffsets, resultStrides);
+    }
+
+    rewriter.replaceOp(bitCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Pattern to unroll vector.interleave into smaller tile-sized operations.
+/// Decomposes a large interleave into tiles by extracting slices from both
+/// input vectors, interleaving them, and inserting back into the result.
+///
+/// Example:
+/// vector.interleave %lhs, %rhs : vector<8xf32>
+/// // Unrolled with target shape [4]:
----------------
Jianhui-Li wrote:
modified as suggested
https://github.com/llvm/llvm-project/pull/194513
More information about the Mlir-commits mailing list