[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #164010)
Nishant Patel
llvmlistbot at llvm.org
Fri Oct 17 13:43:59 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/164010
>From 813a9da0f144f251d2e13dde4c8a5275e378e8d0 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 15 Oct 2025 17:33:06 +0000
Subject: [PATCH 1/2] Add unroll pattern for vector.shape_cast
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 +
.../Vector/Transforms/VectorUnroll.cpp | 151 +++++++++++++++++-
.../Dialect/Vector/vector-unroll-options.mlir | 92 +++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 4 +-
5 files changed, 248 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e79085afac9f..39097368b1e71 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2408,6 +2408,7 @@ def Vector_CompressStoreOp :
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58256b0ade9f6..dff66a6e829a9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6233,6 +6233,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+/// Returns the shape of the result vector type as the shape to unroll.
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
+  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
+}
+
LogicalResult ShapeCastOp::verify() {
VectorType sourceType = getSourceVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..8a969b6c6be6b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,153 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It decomposes a large shape_cast operation
+/// into smaller tiles and reconstructs each tile by extracting individual
+/// elements from the source vector and placing them at the correct positions.
+///
+/// Since shape_cast performs linear element reindexing, the pattern uses
+/// linear indexing as a bridge to map between source and result coordinates.
+/// For each element in a result tile, it calculates the corresponding source
+/// position and extracts that element.
+///
+/// Example:
+/// Given a shape_cast operation:
+/// %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x2>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %tile_zero = arith.constant dense<0.0> : vector<2x2xf32>
+///
+/// // First tile, at offsets [0, 0]: covers result positions
+/// // (0,0), (0,1), (1,0), (1,1).
+/// %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32>
+/// %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32>
+/// %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32>
+/// %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32>
+/// %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32>
+/// %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32>
+/// %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32>
+/// %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32>
+/// %r0 = vector.insert_strided_slice %t3, %zero
+/// {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into
+/// vector<4x4xf32>
+/// The remaining tiles at offsets [0, 2], [2, 0] and [2, 2] are built likewise.
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+        options(options) {}
+
+  /// Replaces `shapeCastOp` with one element-wise assembled tile per offset
+  /// of the target unroll shape, inserted into a zero-initialized result.
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = shapeCastOp.getLoc();
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+
+    // Tile positions are paired dimension-by-dimension with tile offsets
+    // below, so a target shape of a different rank cannot be handled here.
+    if (targetShape->size() != resultShape.size())
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "target shape rank must match the result rank");
+
+    // The tile type and the unit strides are the same for every tile.
+    VectorType tileType =
+        VectorType::get(*targetShape, resultType.getElementType());
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+
+    // Visit each tile offset of the result and fill it in.
+    for (SmallVector<int64_t> tileOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      // Assemble the tile from individual source elements.
+      Value tile = createTileFromElements(
+          rewriter, loc, shapeCastOp.getSource(), sourceShape, resultShape,
+          tileOffsets, *targetShape, tileType);
+
+      // Place the tile at its offsets in the accumulated result.
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, tile, result, tileOffsets, strides);
+    }
+
+    rewriter.replaceOp(shapeCastOp, result);
+    return success();
+  }
+
+private:
+  /// Builds the result tile located at `tileOffsets` by extracting individual
+  /// elements from `source` and inserting them at the matching positions of a
+  /// zero-initialized vector of type `tileType`.
+  Value createTileFromElements(PatternRewriter &rewriter, Location loc,
+                               Value source, ArrayRef<int64_t> sourceShape,
+                               ArrayRef<int64_t> resultShape,
+                               ArrayRef<int64_t> tileOffsets,
+                               ArrayRef<int64_t> tileShape,
+                               VectorType tileType) const {
+    // Start from an all-zero tile; every element is overwritten below.
+    Value tile = rewriter.create<arith::ConstantOp>(
+        loc, tileType, rewriter.getZeroAttr(tileType));
+
+    // Loop invariants: result strides and the tile's element count.
+    SmallVector<int64_t> resultStrides = computeStrides(resultShape);
+    const int64_t numTileElements = computeProduct(tileShape);
+
+    for (int64_t linearTileIdx = 0; linearTileIdx < numTileElements;
+         ++linearTileIdx) {
+      // Position of this element inside the tile.
+      SmallVector<int64_t> tilePosition =
+          linearIndexToMultiDim(linearTileIdx, tileShape);
+
+      // Same element expressed as a position in the full result.
+      SmallVector<int64_t> globalResultPos;
+      globalResultPos.reserve(tileOffsets.size());
+      for (auto [offset, pos] : llvm::zip_equal(tileOffsets, tilePosition))
+        globalResultPos.push_back(offset + pos);
+
+      // shape_cast keeps the row-major linear index of every element, so the
+      // linearized result position identifies the source element.
+      int64_t linearIndex = linearize(globalResultPos, resultStrides);
+      SmallVector<int64_t> sourcePos =
+          linearIndexToMultiDim(linearIndex, sourceShape);
+
+      // Move the element from the source into the tile.
+      Value element =
+          rewriter.create<vector::ExtractOp>(loc, source, sourcePos);
+      tile =
+          rewriter.create<vector::InsertOp>(loc, element, tile, tilePosition);
+    }
+
+    return tile;
+  }
+
+  /// Converts a linear index to multi-dimensional position within a given
+  /// shape. Used for both tile iteration and source coordinate computation.
+  SmallVector<int64_t> linearIndexToMultiDim(int64_t linearIndex,
+                                             ArrayRef<int64_t> shape) const {
+    SmallVector<int64_t> position(shape.size());
+    // Peel off one dimension at a time, innermost (fastest-varying) first.
+    for (int64_t i = shape.size() - 1; i >= 0; --i) {
+      position[i] = linearIndex % shape[i];
+      linearIndex /= shape[i];
+    }
+
+    return position;
+  }
+
+  vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1160,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern>(patterns.getContext(),
- options, benefit);
+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
+ patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..7a7129e9027a0 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,95 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return
+
+// CHECK-LABEL: func @shape_cast_1D_to_2D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<4x4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<16xf32>
+// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<16xf32>
+// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E2:.*]] = vector.extract %[[ARG0]][4] : f32 from vector<16xf32>
+// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E3:.*]] = vector.extract %[[ARG0]][5] : f32 from vector<16xf32>
+// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: %[[E4:.*]] = vector.extract %[[ARG0]][2] : f32 from vector<16xf32>
+// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E5:.*]] = vector.extract %[[ARG0]][3] : f32 from vector<16xf32>
+// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E6:.*]] = vector.extract %[[ARG0]][6] : f32 from vector<16xf32>
+// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E7:.*]] = vector.extract %[[ARG0]][7] : f32 from vector<16xf32>
+// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: %[[E8:.*]] = vector.extract %[[ARG0]][8] : f32 from vector<16xf32>
+// CHECK: %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E9:.*]] = vector.extract %[[ARG0]][9] : f32 from vector<16xf32>
+// CHECK: %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E10:.*]] = vector.extract %[[ARG0]][12] : f32 from vector<16xf32>
+// CHECK: %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E11:.*]] = vector.extract %[[ARG0]][13] : f32 from vector<16xf32>
+// CHECK: %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: %[[E12:.*]] = vector.extract %[[ARG0]][10] : f32 from vector<16xf32>
+// CHECK: %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E13:.*]] = vector.extract %[[ARG0]][11] : f32 from vector<16xf32>
+// CHECK: %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E14:.*]] = vector.extract %[[ARG0]][14] : f32 from vector<16xf32>
+// CHECK: %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E15:.*]] = vector.extract %[[ARG0]][15] : f32 from vector<16xf32>
+// CHECK: %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return %[[I3]] : vector<4x4xf32>
+// Test input: casts a flat 16-element vector to 4x4. The CHECK lines above
+// expect the result to be rebuilt from four 2x2 tiles of extracted elements.
+func.func @shape_cast_1D_to_2D(%v: vector<16xf32>) -> vector<4x4xf32> {
+ %0 = vector.shape_cast %v : vector<16xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_2D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>) -> vector<4x4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<2x8xf32>
+// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][0, 1] : f32 from vector<2x8xf32>
+// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E2:.*]] = vector.extract %[[ARG0]][0, 4] : f32 from vector<2x8xf32>
+// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E3:.*]] = vector.extract %[[ARG0]][0, 5] : f32 from vector<2x8xf32>
+// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: %[[E4:.*]] = vector.extract %[[ARG0]][0, 2] : f32 from vector<2x8xf32>
+// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E5:.*]] = vector.extract %[[ARG0]][0, 3] : f32 from vector<2x8xf32>
+// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E6:.*]] = vector.extract %[[ARG0]][0, 6] : f32 from vector<2x8xf32>
+// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E7:.*]] = vector.extract %[[ARG0]][0, 7] : f32 from vector<2x8xf32>
+// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: %[[E8:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<2x8xf32>
+// CHECK: %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E9:.*]] = vector.extract %[[ARG0]][1, 1] : f32 from vector<2x8xf32>
+// CHECK: %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E10:.*]] = vector.extract %[[ARG0]][1, 4] : f32 from vector<2x8xf32>
+// CHECK: %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E11:.*]] = vector.extract %[[ARG0]][1, 5] : f32 from vector<2x8xf32>
+// CHECK: %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: %[[E12:.*]] = vector.extract %[[ARG0]][1, 2] : f32 from vector<2x8xf32>
+// CHECK: %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E13:.*]] = vector.extract %[[ARG0]][1, 3] : f32 from vector<2x8xf32>
+// CHECK: %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[E14:.*]] = vector.extract %[[ARG0]][1, 6] : f32 from vector<2x8xf32>
+// CHECK: %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[E15:.*]] = vector.extract %[[ARG0]][1, 7] : f32 from vector<2x8xf32>
+// CHECK: %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return %[[I3]] : vector<4x4xf32>
+// Test input: casts a 2x8 vector to 4x4. The CHECK lines above expect the
+// result to be rebuilt from four 2x2 tiles of extracted elements.
+func.func @shape_cast_2D(%v: vector<2x8xf32>) -> vector<4x4xf32> {
+ %0 = vector.shape_cast %v : vector<2x8xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda71..0a54f06f5d6b6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -163,8 +163,8 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(
isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
- vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(
- op));
+ vector::BroadcastOp, vector::LoadOp, vector::StoreOp,
+ vector::ShapeCastOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
>From 37ea270bbbcaf27090402ff1be0a0698b89488f4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 17 Oct 2025 20:42:28 +0000
Subject: [PATCH 2/2] Remove comment
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 8a969b6c6be6b..6cac6a30aa59b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1134,7 +1134,7 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
}
/// Converts a linear index to multi-dimensional position within a given
- /// shape. Used for both tile iteration and source coordinate computation.
+ /// shape.
SmallVector<int64_t> linearIndexToMultiDim(int64_t linearIndex,
ArrayRef<int64_t> shape) const {
SmallVector<int64_t> position(shape.size());
More information about the Mlir-commits
mailing list