[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #164010)
    Jakub Kuderski 
    llvmlistbot at llvm.org
       
    Sat Oct 18 00:52:09 PDT 2025
    
    
  
================
@@ -1003,6 +1003,153 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It decomposes a large shape_cast operation
+/// into smaller tiles and reconstructs each tile by extracting individual
+/// elements from the source vector and placing them at the correct positions.
+///
+/// Since shape_cast preserves the row-major (linearized) order of elements,
+/// the pattern uses the linear index as a bridge between result and source
+/// coordinates: for each element in a result tile, it computes that element's
+/// linear index, maps it back to the corresponding source position, and
+/// extracts the element from there.
+///
+/// Example:
+///   Given a shape_cast operation:
+///     %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32>
+///
+///   and a target unroll shape of <2x2>, the pattern produces:
+///
+///     %zero = arith.constant dense<0.0> : vector<4x4xf32>
+///     %tile_zero = arith.constant dense<0.0> : vector<2x2xf32>
+///
+///     // First tile [0,0]: elements at result positions
+///     // (0,0),(0,1),(1,0),(1,1), i.e. linear indices 0, 1, 4, 5
+///     %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32>
+///     %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32>
+///     %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32>
+///     %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32>
+///     %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32>
+///     %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32>
+///     %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32>
+///     %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32>
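+///     %r0 = vector.insert_strided_slice %t3, %zero
+///       {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into
+///       vector<4x4xf32>
+///
+///     // ... the remaining tiles at offsets [0, 2], [2, 0] and [2, 2] are
+///     // built the same way and inserted into the running result.
+///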
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = shapeCastOp.getLoc();
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+
+    // For each unrolled tile in the result
+    for (SmallVector<int64_t> tileOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+
+      // Create the target tile type
+      VectorType tileType =
+          VectorType::get(*targetShape, resultType.getElementType());
----------------
kuhar wrote:
```suggestion
      auto tileType =
          VectorType::get(*targetShape, resultType.getElementType());
```
See https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
https://github.com/llvm/llvm-project/pull/164010
    
    
More information about the Mlir-commits mailing list