[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #164010)

Nishant Patel llvmlistbot at llvm.org
Fri Oct 17 13:43:59 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/164010

>From 813a9da0f144f251d2e13dde4c8a5275e378e8d0 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 15 Oct 2025 17:33:06 +0000
Subject: [PATCH 1/2] Add unroll pattern for vector.shape_cast

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |   4 +
 .../Vector/Transforms/VectorUnroll.cpp        | 151 +++++++++++++++++-
 .../Dialect/Vector/vector-unroll-options.mlir |  92 +++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |   4 +-
 5 files changed, 248 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e79085afac9f..39097368b1e71 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2408,6 +2408,7 @@ def Vector_CompressStoreOp :
 
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]>,
     Arguments<(ins AnyVectorOfAnyRank:$source)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58256b0ade9f6..dff66a6e829a9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6233,6 +6233,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+/// The unroll shape of a shape_cast is the shape of its result vector.
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
+  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
+}
+
 LogicalResult ShapeCastOp::verify() {
 
   VectorType sourceType = getSourceVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..8a969b6c6be6b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,153 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It decomposes a large shape_cast operation
+/// into result tiles of the target shape and builds each tile by extracting
+/// individual elements from the source vector and inserting them at the
+/// corresponding tile positions.
+///
+/// shape_cast keeps the elements in the same row-major linear order, so the
+/// linear index is used as a bridge between the two coordinate systems: for
+/// each element of a result tile, its global result position is linearized
+/// with the result strides and then delinearized with the source strides to
+/// find the source position to extract from.
+///
+/// If the target shape has a lower rank than the result, it is aligned with
+/// the trailing result dimensions and the leading dimensions are unrolled
+/// with size 1. The tiles keep the target rank and are inserted with
+/// `vector.insert_strided_slice`, which accepts a lower-rank source.
+/// Scalable vectors are rejected: their element count is not static, so the
+/// elements cannot be enumerated with constant indices.
+///
+/// Example:
+///   Given a shape_cast operation:
+///     %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32>
+///
+///   and a target unroll shape of <2x2>, the pattern produces:
+///
+///     %zero = arith.constant dense<0.0> : vector<4x4xf32>
+///     %tile_zero = arith.constant dense<0.0> : vector<2x2xf32>
+///
+///     // First tile, offsets [0, 0]: result positions (0,0), (0,1), (1,0)
+///     // and (1,1), i.e. linear indices 0, 1, 4 and 5.
+///     %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32>
+///     %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32>
+///     %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32>
+///     %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32>
+///     %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32>
+///     %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32>
+///     %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32>
+///     %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32>
+///     %r0 = vector.insert_strided_slice %t3, %zero
+///       {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into
+///       vector<4x4xf32>
+///     // ... followed by the tiles at offsets [0, 2], [2, 0] and [2, 2].
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+    // Element-wise extraction with constant indices needs a statically known
+    // number of elements.
+    if (sourceType.isScalable() || resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "scalable vectors are not supported");
+
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    ArrayRef<int64_t> tileShape = *targetShape;
+    if (tileShape.size() > resultShape.size())
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "target shape rank exceeds result rank");
+
+    // StaticTileOffsetRange requires a tile shape of the result rank, so the
+    // target shape is aligned with the trailing result dims by prepending
+    // unit dims. The yielded offsets therefore always have the result rank.
+    size_t rankDiff = resultShape.size() - tileShape.size();
+    SmallVector<int64_t> iterTileShape(rankDiff, 1);
+    llvm::append_range(iterTileShape, tileShape);
+
+    Location loc = shapeCastOp.getLoc();
+    VectorType tileType =
+        VectorType::get(tileShape, resultType.getElementType());
+    // insert_strided_slice takes one stride per dimension of the inserted
+    // (tile) vector and one offset per dimension of the destination.
+    SmallVector<int64_t> strides(tileShape.size(), 1);
+    // Loop-invariant strides shared by every tile.
+    SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
+    SmallVector<int64_t> resultStrides = computeStrides(resultShape);
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    // Every element of a tile is overwritten, so a single zero seed can be
+    // shared by all tiles.
+    Value zeroTile = rewriter.create<arith::ConstantOp>(
+        loc, tileType, rewriter.getZeroAttr(tileType));
+
+    for (SmallVector<int64_t> tileOffsets :
+         StaticTileOffsetRange(resultShape, iterTileShape)) {
+      Value tile = createTileFromElements(
+          rewriter, loc, shapeCastOp.getSource(), tileOffsets, tileShape,
+          sourceStrides, resultStrides, zeroTile);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, tile, result, tileOffsets, strides);
+    }
+
+    rewriter.replaceOp(shapeCastOp, result);
+    return success();
+  }
+
+private:
+  /// Builds one result tile of shape `tileShape` whose origin in the result
+  /// vector is `tileOffsets` (result rank). `tileShape` is aligned with the
+  /// trailing result dimensions. Starting from `zeroTile`, each tile element
+  /// is extracted from `source` at the position with the same row-major
+  /// linear index and inserted at its position in the tile.
+  Value createTileFromElements(PatternRewriter &rewriter, Location loc,
+                               Value source, ArrayRef<int64_t> tileOffsets,
+                               ArrayRef<int64_t> tileShape,
+                               ArrayRef<int64_t> sourceStrides,
+                               ArrayRef<int64_t> resultStrides,
+                               Value zeroTile) const {
+    size_t rankDiff = tileOffsets.size() - tileShape.size();
+    SmallVector<int64_t> tileStrides = computeStrides(tileShape);
+    int64_t numTileElements = computeProduct(tileShape);
+
+    Value tile = zeroTile;
+    for (int64_t linearTileIdx = 0; linearTileIdx < numTileElements;
+         ++linearTileIdx) {
+      // Row-major position of this element inside the tile.
+      SmallVector<int64_t> tilePosition =
+          delinearize(linearTileIdx, tileStrides);
+
+      // Global position in the result: tile dims map onto the trailing
+      // result dims; leading dims are fixed by the tile offsets.
+      SmallVector<int64_t> resultPos = llvm::to_vector(tileOffsets);
+      for (auto [dim, pos] : llvm::enumerate(tilePosition))
+        resultPos[rankDiff + dim] += pos;
+
+      // shape_cast maps equal linear indices onto each other.
+      int64_t linearIndex = linearize(resultPos, resultStrides);
+      SmallVector<int64_t> sourcePos = delinearize(linearIndex, sourceStrides);
+
+      Value element =
+          rewriter.create<vector::ExtractOp>(loc, source, sourcePos);
+      tile =
+          rewriter.create<vector::InsertOp>(loc, element, tile, tilePosition);
+    }
+    return tile;
+  }
+
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1160,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
-               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
-                                                    options, benefit);
+               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
+      patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..7a7129e9027a0 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,95 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
 // CHECK-NOT: arith.addf
 // CHECK: return
+
+//CHECK-LABEL: func @shape_cast_1D_to_2D
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<4x4xf32>
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+//       CHECK:   %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<16xf32>
+//       CHECK:   %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<16xf32>
+//       CHECK:   %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract %[[ARG0]][4] : f32 from vector<16xf32>
+//       CHECK:   %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract %[[ARG0]][5] : f32 from vector<16xf32>
+//       CHECK:   %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract %[[ARG0]][2] : f32 from vector<16xf32>
+//       CHECK:   %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract %[[ARG0]][3] : f32 from vector<16xf32>
+//       CHECK:   %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E6:.*]] = vector.extract %[[ARG0]][6] : f32 from vector<16xf32>
+//       CHECK:   %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E7:.*]] = vector.extract %[[ARG0]][7] : f32 from vector<16xf32>
+//       CHECK:   %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   %[[E8:.*]] = vector.extract %[[ARG0]][8] : f32 from vector<16xf32>
+//       CHECK:   %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E9:.*]] = vector.extract %[[ARG0]][9] : f32 from vector<16xf32>
+//       CHECK:   %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E10:.*]] = vector.extract %[[ARG0]][12] : f32 from vector<16xf32>
+//       CHECK:   %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E11:.*]] = vector.extract %[[ARG0]][13] : f32 from vector<16xf32>
+//       CHECK:   %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   %[[E12:.*]] = vector.extract %[[ARG0]][10] : f32 from vector<16xf32>
+//       CHECK:   %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E13:.*]] = vector.extract %[[ARG0]][11] : f32 from vector<16xf32>
+//       CHECK:   %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E14:.*]] = vector.extract %[[ARG0]][14] : f32 from vector<16xf32>
+//       CHECK:   %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E15:.*]] = vector.extract %[[ARG0]][15] : f32 from vector<16xf32>
+//       CHECK:   %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   return %[[I3]] : vector<4x4xf32>
+// 1-D -> 2-D shape_cast (16 -> 4x4) unrolled into 2x2 result tiles; each tile
+// is assembled from four scalar extracts of the 1-D source, taken at the
+// row-major linear indices of the tile's result positions.
+func.func @shape_cast_1D_to_2D(%v: vector<16xf32>) -> vector<4x4xf32> {
+  %0 = vector.shape_cast %v : vector<16xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+//CHECK-LABEL: func @shape_cast_2D
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>) -> vector<4x4xf32>
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+//       CHECK:   %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract %[[ARG0]][0, 1] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract %[[ARG0]][0, 4] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract %[[ARG0]][0, 5] : f32 from vector<2x8xf32>
+//       CHECK:   %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract %[[ARG0]][0, 2] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract %[[ARG0]][0, 3] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E6:.*]] = vector.extract %[[ARG0]][0, 6] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E7:.*]] = vector.extract %[[ARG0]][0, 7] : f32 from vector<2x8xf32>
+//       CHECK:   %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   %[[E8:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E9:.*]] = vector.extract %[[ARG0]][1, 1] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E10:.*]] = vector.extract %[[ARG0]][1, 4] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E11:.*]] = vector.extract %[[ARG0]][1, 5] : f32 from vector<2x8xf32>
+//       CHECK:   %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   %[[E12:.*]] = vector.extract %[[ARG0]][1, 2] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E13:.*]] = vector.extract %[[ARG0]][1, 3] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E14:.*]] = vector.extract %[[ARG0]][1, 6] : f32 from vector<2x8xf32>
+//       CHECK:   %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32>
+//       CHECK:   %[[E15:.*]] = vector.extract %[[ARG0]][1, 7] : f32 from vector<2x8xf32>
+//       CHECK:   %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32>
+//       CHECK:   %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK:   return %[[I3]] : vector<4x4xf32>
+// 2-D -> 2-D shape_cast (2x8 -> 4x4) unrolled into 2x2 result tiles; each
+// extract position is the result element's row-major linear index
+// delinearized against the 2x8 source shape.
+func.func @shape_cast_2D(%v: vector<2x8xf32>) -> vector<4x4xf32> {
+  %0 = vector.shape_cast %v : vector<2x8xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda71..0a54f06f5d6b6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -163,8 +163,8 @@ struct TestVectorUnrollingPatterns
             .setFilterConstraint([](Operation *op) {
               return success(
                   isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
-                      vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(
-                      op));
+                      vector::BroadcastOp, vector::LoadOp, vector::StoreOp,
+                      vector::ShapeCastOp>(op));
             }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()

>From 37ea270bbbcaf27090402ff1be0a0698b89488f4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 17 Oct 2025 20:42:28 +0000
Subject: [PATCH 2/2] Remove comment

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 8a969b6c6be6b..6cac6a30aa59b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1134,7 +1134,7 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
   }
 
   /// Converts a linear index to multi-dimensional position within a given
-  /// shape. Used for both tile iteration and source coordinate computation.
+  /// shape.
   SmallVector<int64_t> linearIndexToMultiDim(int64_t linearIndex,
                                              ArrayRef<int64_t> shape) const {
     SmallVector<int64_t> position(shape.size());



More information about the Mlir-commits mailing list