[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #164010)
Nishant Patel
llvmlistbot at llvm.org
Sat Oct 18 09:13:33 PDT 2025
================
@@ -1003,6 +1003,153 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It decomposes a large shape_cast operation
+/// into smaller tiles and reconstructs each tile by extracting individual
+/// elements from the source vector and placing them at the correct positions.
+///
+/// Since shape_cast performs linear element reindexing, the pattern uses
+/// linear indexing as a bridge to map between source and result coordinates.
+/// For each element in a result tile, it calculates the corresponding source
+/// position and extracts that element.
+///
+/// Example:
+/// Given a shape_cast operation:
+/// %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x2>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %tile_zero = arith.constant dense<0.0> : vector<2x2xf32>
+///
+/// // First tile [0,0]: elements at result positions
+/// // (0,0),(0,1),(1,0),(1,1)
+/// %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32>
+/// %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32>
+/// %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32>
+/// %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32>
+/// %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32>
+/// %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32>
+/// %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32>
+/// %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32>
+/// %r0 = vector.insert_strided_slice %t3, %zero
+/// {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into
+/// vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+ UnrollShapeCastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, shapeCastOp);
+ if (!targetShape)
+ return failure();
+
+ Location loc = shapeCastOp.getLoc();
+ VectorType sourceType = shapeCastOp.getSourceVectorType();
+ VectorType resultType = shapeCastOp.getResultVectorType();
+
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ ArrayRef<int64_t> sourceShape = sourceType.getShape();
+
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ Value result = rewriter.create<arith::ConstantOp>(
----------------
nbpatel wrote:
Sorry, my bad. This keeps slipping my mind.
https://github.com/llvm/llvm-project/pull/164010
More information about the Mlir-commits mailing list