[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling support for bitcast, interleave, and deinterleave ops (PR #194513)

Jianhui Li llvmlistbot at llvm.org
Tue Apr 28 17:48:51 PDT 2026


================
@@ -1389,6 +1389,222 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls vector::BitCastOp into smaller tile-based bitcast operations.
+///
+/// The result vector is tiled by the target shape. For every result tile the
+/// matching slice of the source is extracted, bitcast to the tile type, and
+/// inserted into an accumulator that starts out as a zero constant. Only the
+/// innermost dimension of the source slice differs from the result tile: its
+/// extent and offset are rescaled by the result/source element bitwidth ratio.
+///
+/// Example: bitcast vector<8xf32> to vector<16xf16> with target shape [4]
+/// unrolls into four bitcasts, each from a vector<2xf32> source slice to a
+/// vector<4xf16> result tile.
+///
+/// The pattern does not apply when the target shape rank differs from the
+/// result rank, or when the innermost target tile does not span a whole number
+/// of source elements (e.g. f32 -> f16 with an innermost target size of 1).
+struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
+  UnrollBitCastPattern(MLIRContext *context,
+                       const vector::UnrollVectorOptions &options,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BitCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, bitCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = bitCastOp.getSourceVectorType();
+    VectorType resultType = bitCastOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    Location loc = bitCastOp.getLoc();
+
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "target shape rank must match result rank");
+
+    // The innermost dimension is indexed below; an empty shape has none.
+    if (targetShape->empty())
+      return rewriter.notifyMatchFailure(bitCastOp,
+                                         "cannot unroll a 0-D bitcast");
+
+    // Keep the bit arithmetic in int64_t to match the shape element type.
+    int64_t sourceElementBits = sourceType.getElementTypeBitWidth();
+    int64_t resultElementBits = resultType.getElementTypeBitWidth();
+
+    int64_t lastDim = targetShape->size() - 1;
+
+    // A result tile must map onto a whole number of source elements along the
+    // innermost dimension; otherwise the division below would truncate and the
+    // extracted slice would be too small (possibly empty) for the bitcast.
+    int64_t tileBits = (*targetShape)[lastDim] * resultElementBits;
+    if (tileBits % sourceElementBits != 0)
+      return rewriter.notifyMatchFailure(
+          bitCastOp, "innermost target tile does not cover a whole number of "
+                     "source elements");
+
+    // The source tile equals the result tile except for the innermost
+    // dimension, which is rescaled by the element bitwidth ratio.
+    SmallVector<int64_t> sourceTileShape(targetShape->begin(),
+                                         targetShape->end());
+    sourceTileShape[lastDim] = tileBits / sourceElementBits;
+
+    // Accumulator that every bitcast tile is inserted into.
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides(sourceTileShape.size(), 1);
+
+    VectorType targetType =
+        VectorType::get(*targetShape, resultType.getElementType());
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Result offsets along the innermost dimension are multiples of the
+      // target tile size, so this rescaling is exact given the check above.
+      SmallVector<int64_t> sourceOffsets = resultOffsets;
+      sourceOffsets[lastDim] =
+          (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
+
+      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, bitCastOp.getSource(), sourceOffsets, sourceTileShape,
+          sourceStrides);
+      Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
+          loc, targetType, sourceSlice);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, bitcastSlice, result, resultOffsets, resultStrides);
+    }
+
+    rewriter.replaceOp(bitCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Pattern to unroll vector.interleave into smaller tile-sized operations.
+/// Decomposes a large interleave into tiles by extracting slices from both
+/// input vectors, interleaving them, and inserting back into the result.
+///
+/// Example:
+///   vector.interleave %lhs, %rhs : vector<8xf32>
+///   // Unrolled with target shape [4]:
----------------
Jianhui-Li wrote:

modified as suggested

https://github.com/llvm/llvm-project/pull/194513


More information about the Mlir-commits mailing list