[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #164010)
Jakub Kuderski
llvmlistbot at llvm.org
Sat Oct 18 00:52:09 PDT 2025
================
@@ -1003,6 +1003,153 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It decomposes a large shape_cast operation
+/// into smaller tiles and reconstructs each tile by extracting individual
+/// elements from the source vector and placing them at the correct positions.
+///
+/// Since shape_cast performs linear element reindexing, the pattern uses
+/// linear indexing as a bridge to map between source and result coordinates.
+/// For each element in a result tile, it calculates the corresponding source
+/// position and extracts that element.
+///
+/// Example:
+/// Given a shape_cast operation:
+/// %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x2>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %tile_zero = arith.constant dense<0.0> : vector<2x2xf32>
+///
+/// // First tile at offset [0, 0]: elements at result positions
+/// // (0,0), (0,1), (1,0), (1,1)
+/// %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32>
+/// %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32>
+/// %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32>
+/// %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32>
+/// %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32>
+/// %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32>
+/// %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32>
+/// %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32>
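+/// %r0 = vector.insert_strided_slice %t3, %zero
+/// {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into
+/// vector<4x4xf32>
+/// // ... the tiles at offsets [0, 2], [2, 0] and [2, 2] are built the same
+/// // way and inserted successively, starting from %r0.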
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  /// Builds the pattern. `options` is copied into the pattern and later
+  /// passed to getTargetShape() to choose the tile shape for each op.
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the unroll options yield a target (tile) shape for
+    // this op.
+    auto targetShape = getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = shapeCastOp.getLoc();
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+
+    // Every tile has the same type, so build it once rather than once per
+    // loop iteration.
+    VectorType tileType =
+        VectorType::get(*targetShape, resultType.getElementType());
+    // Tiles are inserted densely: unit stride in every dimension.
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    // Start from an all-zero result and insert the tiles into it one by one.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+
+    // Visit the offset of each target-shaped tile of the result vector.
+    for (SmallVector<int64_t> tileOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Gather this tile's elements from the source vector.
+      Value tile = createTileFromElements(
+          rewriter, loc, shapeCastOp.getSource(), sourceShape, resultShape,
+          tileOffsets, *targetShape, tileType);
+
+      // Place the tile at its offset in the partially built result.
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, tile, result, tileOffsets, strides);
+    }
+
+    rewriter.replaceOp(shapeCastOp, result);
+    return success();
+  }
+
+private:
+ /// Creates a result tile by extracting individual elements from the source
+ /// and inserting them at the correct positions in the tile.
+ Value createTileFromElements(PatternRewriter &rewriter, Location loc,
----------------
kuhar wrote:
Can we make it a static function?
https://github.com/llvm/llvm-project/pull/164010
More information about the Mlir-commits mailing list