[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)
Nishant Patel
llvmlistbot at llvm.org
Tue Nov 18 10:26:10 PST 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/167738
From cd8b818297287afbed0c675d9bf491bfb296f385 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 11 Nov 2025 00:14:40 +0000
Subject: [PATCH 1/4] Add unroll pattern for vector.shape_cast
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 +
.../Vector/Transforms/VectorUnroll.cpp | 170 +++++++++++++++++-
.../Dialect/Vector/vector-unroll-options.mlir | 34 ++++
.../Dialect/Vector/TestVectorTransforms.cpp | 6 +
5 files changed, 213 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43172ff2082df..6ad179349f90f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2427,6 +2427,7 @@ def Vector_CompressStoreOp :
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daef0ba02100a..4cac137478fab 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6241,6 +6241,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
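+/// Unrolling of a shape_cast is driven by the shape of its result vector.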
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
LogicalResult ShapeCastOp::verify() {
VectorType sourceType = getSourceVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..a4830809aaac8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,172 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+ ArrayRef<int64_t> resultShape) {
+ if (targetShape.size() > resultShape.size()) {
+ return false;
+ }
+
+ size_t rankDiff = resultShape.size() - targetShape.size();
+  // Inner dimensions must match exactly, and the total resultElements must
+  // be evenly divisible by targetElements.
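+  // For example, targetShape [2, 4] is contiguous in resultShape [4, 4]
+  // (inner dims match and 16 % 8 == 0), while [2, 2] in [8, 4] is not.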
+ for (size_t i = 1; i < targetShape.size(); ++i) {
+ if (targetShape[i] != resultShape[rankDiff + i]) {
+ return false;
+ }
+ }
+
+ int64_t targetElements = ShapedType::getNumElements(targetShape);
+ int64_t resultElements = ShapedType::getNumElements(resultShape);
+ if (resultElements % targetElements != 0) {
+ return false;
+ }
+ return true;
+}
+
+// Calculate the shape to extract from source
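+// e.g. for sourceShape = [8, 2] and targetElements = 8, this returns [4, 2]:
+// the innermost dim contributes 2 and the next dim the remaining factor 4.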
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+ int64_t targetElements) {
+ SmallVector<int64_t> extractShape;
+ int64_t remainingElements = targetElements;
+
+ // Build extract shape from innermost dimension outward to ensure contiguity
+ for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+ int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+ extractShape.insert(extractShape.begin(), takeFromDim);
+
+ if (remainingElements % takeFromDim != 0) {
+ return std::nullopt; // Not evenly divisible
+ }
+ remainingElements /= takeFromDim;
+ }
+
+ // Fill remaining dimensions with 1
+ while (extractShape.size() < sourceShape.size()) {
+ extractShape.insert(extractShape.begin(), 1);
+ }
+
+ if (ShapedType::getNumElements(extractShape) != targetElements) {
+ return std::nullopt;
+ }
+
+ return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position
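+// e.g. resultOffsets [2, 0] in resultShape [4, 4] linearize to index 8,
+// which delinearizes in sourceShape [8, 2] to sourceOffsets [4, 0].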
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+ ArrayRef<int64_t> sourceStrides,
+ ArrayRef<int64_t> resultStrides) {
+ // Convert result offsets to linear position
+ int64_t linearIndex = linearize(resultOffsets, resultStrides);
+ // Convert linear position to source offsets
+ SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
+ return sourceOffsets;
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (rather than decomposing to scalars). In these
+/// cases, the unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice
+///
+/// Example:
+/// Given a shape cast operation:
+/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+ UnrollShapeCastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, shapeCastOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType sourceType = shapeCastOp.getSourceVectorType();
+ VectorType resultType = shapeCastOp.getResultVectorType();
+ ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+
+ if (!isContiguousExtract(*targetShape, resultShape)) {
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ "Only supports cases where contiguous "
+ "extraction is possible");
+ }
+
+ int64_t targetElements = ShapedType::getNumElements(*targetShape);
+
+ // Calculate the shape to extract from source
+ auto extractShape =
+ calculateSourceExtractShape(sourceShape, targetElements);
+ if (!extractShape) {
+ return rewriter.notifyMatchFailure(
+ shapeCastOp,
+ "cannot extract target number of elements contiguously from source");
+ }
+
+ Location loc = shapeCastOp.getLoc();
+
+ // Create result vector initialized to zero
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
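+    // (vector.insert_strided_slice needs an existing destination vector to
+    // insert each computed tile into.)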
+
+ VectorType targetType =
+ VectorType::get(*targetShape, sourceType.getElementType());
+
+ SmallVector<int64_t> extractStrides(extractShape->size(), 1);
+ SmallVector<int64_t> insertStrides(targetShape->size(), 1);
+ SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
+ SmallVector<int64_t> resultStrides = computeStrides(resultShape);
+
+ for (SmallVector<int64_t> resultOffsets :
+ StaticTileOffsetRange(resultShape, *targetShape)) {
+ SmallVector<int64_t> sourceOffsets =
+ calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
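+      // Extract a contiguous chunk of the source, shape_cast it to the
+      // target shape, and insert it at the corresponding result offsets.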
+ Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
+ extractStrides);
+ Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, targetType, sourceChunk);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, targetChunk, result, resultOffsets, insertStrides);
+ }
+
+ rewriter.replaceOp(shapeCastOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1179,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern>(patterns.getContext(),
- options, benefit);
+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
+ patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..c94a502fa3654 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,37 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return
+
+
+func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
+ %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
+ return %0 : vector<2x2x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK: return %[[I1]] : vector<2x2x4xf32>
+
+
+func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
+ %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_2D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK: return %[[I1]] : vector<4x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda71..0ab4e451d544d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{2, 4})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::ShapeCastOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
From 73512fd722ea836ea96ec31d55f55e893c6f9b14 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 12 Nov 2025 19:38:26 +0000
Subject: [PATCH 2/4] Address feedback
---
.../Vector/Transforms/VectorUnroll.cpp | 59 ++++++++-----------
1 file changed, 24 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index a4830809aaac8..7afc83bb8a876 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1005,67 +1005,57 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
ArrayRef<int64_t> resultShape) {
- if (targetShape.size() > resultShape.size()) {
+ if (targetShape.size() > resultShape.size())
return false;
- }
size_t rankDiff = resultShape.size() - targetShape.size();
// Inner dimensions must match exactly, and the total resultElements must
// be evenly divisible by targetElements.
- for (size_t i = 1; i < targetShape.size(); ++i) {
- if (targetShape[i] != resultShape[rankDiff + i]) {
- return false;
- }
- }
+ if (!llvm::equal(targetShape.drop_front(),
+ resultShape.drop_front(rankDiff + 1)))
+ return false;
int64_t targetElements = ShapedType::getNumElements(targetShape);
int64_t resultElements = ShapedType::getNumElements(resultShape);
- if (resultElements % targetElements != 0) {
- return false;
- }
- return true;
+ return resultElements % targetElements == 0;
}
-// Calculate the shape to extract from source
+// Calculate the shape to extract from source.
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
int64_t targetElements) {
SmallVector<int64_t> extractShape;
int64_t remainingElements = targetElements;
- // Build extract shape from innermost dimension outward to ensure contiguity
+ // Build extract shape from innermost dimension outward to ensure contiguity.
for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
extractShape.insert(extractShape.begin(), takeFromDim);
- if (remainingElements % takeFromDim != 0) {
- return std::nullopt; // Not evenly divisible
- }
+ if (remainingElements % takeFromDim != 0)
+ return std::nullopt; // Not evenly divisible.
remainingElements /= takeFromDim;
}
- // Fill remaining dimensions with 1
- while (extractShape.size() < sourceShape.size()) {
+ // Fill remaining dimensions with 1.
+ while (extractShape.size() < sourceShape.size())
extractShape.insert(extractShape.begin(), 1);
- }
- if (ShapedType::getNumElements(extractShape) != targetElements) {
+ if (ShapedType::getNumElements(extractShape) != targetElements)
return std::nullopt;
- }
return extractShape;
}
-// Convert result offsets to source offsets via linear position
+// Convert result offsets to source offsets via linear position.
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
ArrayRef<int64_t> sourceStrides,
ArrayRef<int64_t> resultStrides) {
- // Convert result offsets to linear position
+ // Convert result offsets to linear position.
int64_t linearIndex = linearize(resultOffsets, resultStrides);
- // Convert linear position to source offsets
- SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
- return sourceOffsets;
+ // Convert linear position to source offsets.
+ return delinearize(linearIndex, sourceStrides);
}
/// This pattern unrolls `vector.shape_cast` operations according to the
@@ -1079,7 +1069,7 @@ calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
/// remains a valid vector (rather than decomposing to scalars). In these
/// cases, the unrolling proceeds as:
/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
-/// vector.insert_strided_slice
+/// vector.insert_strided_slice.
///
/// Example:
/// Given a shape cast operation:
@@ -1108,7 +1098,8 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
- auto targetShape = getTargetShape(options, shapeCastOp);
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, shapeCastOp);
if (!targetShape)
return failure();
@@ -1117,26 +1108,24 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
ArrayRef<int64_t> sourceShape = sourceType.getShape();
ArrayRef<int64_t> resultShape = resultType.getShape();
- if (!isContiguousExtract(*targetShape, resultShape)) {
+ if (!isContiguousExtract(*targetShape, resultShape))
return rewriter.notifyMatchFailure(shapeCastOp,
"Only supports cases where contiguous "
"extraction is possible");
- }
int64_t targetElements = ShapedType::getNumElements(*targetShape);
- // Calculate the shape to extract from source
- auto extractShape =
+ // Calculate the shape to extract from source.
+ std::optional<SmallVector<int64_t>> extractShape =
calculateSourceExtractShape(sourceShape, targetElements);
- if (!extractShape) {
+ if (!extractShape)
return rewriter.notifyMatchFailure(
shapeCastOp,
"cannot extract target number of elements contiguously from source");
- }
Location loc = shapeCastOp.getLoc();
- // Create result vector initialized to zero
+ // Create result vector initialized to zero.
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));
From 9b4191a1c63c033fbf8f88dc9b227c1db1a936db Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 13 Nov 2025 00:40:23 +0000
Subject: [PATCH 3/4] Fix isContiguousExtract
---
.../Vector/Transforms/VectorUnroll.cpp | 50 ++++++++++++++++---
1 file changed, 42 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 7afc83bb8a876..885fcf835c1a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1008,16 +1008,50 @@ static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
if (targetShape.size() > resultShape.size())
return false;
- size_t rankDiff = resultShape.size() - targetShape.size();
-  // Inner dimensions must match exactly, and the total resultElements must
-  // be evenly divisible by targetElements.
- if (!llvm::equal(targetShape.drop_front(),
- resultShape.drop_front(rankDiff + 1)))
- return false;
-
int64_t targetElements = ShapedType::getNumElements(targetShape);
int64_t resultElements = ShapedType::getNumElements(resultShape);
- return resultElements % targetElements == 0;
+
+ // Result must be evenly divisible by target.
+ if (resultElements % targetElements != 0)
+ return false;
+
+ // For contiguous extraction, we need to be able to
+ // extract targetElements contiguously from the result shape.
+ // This means we can "consume" dimensions from the innermost outward
+ // until we have exactly targetElements.
+
+ int64_t remainingElements = targetElements;
+ int targetDimIdx = targetShape.size() - 1;
+
+ // Work backwards through result dimensions.
+ for (int resultDimIdx = resultShape.size() - 1;
+ resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
+ --resultDimIdx) {
+
+ int64_t resultDimSize = resultShape[resultDimIdx];
+ int64_t targetDimSize = targetShape[targetDimIdx];
+
+ if (targetDimSize > resultDimSize)
+ return false;
+
+ if (targetDimSize == resultDimSize) {
+ if (remainingElements % targetDimSize != 0)
+ return false;
+ remainingElements /= targetDimSize;
+ --targetDimIdx;
+ } else {
+ if (remainingElements != targetDimSize)
+ return false;
+ remainingElements = 1;
+ --targetDimIdx;
+ }
+ }
+
+  // Check that the remaining target dimensions are all 1 and that all
+  // elements were consumed.
+ return remainingElements == 1 &&
+ (targetDimIdx < 0 || llvm::all_of(
+ targetShape.take_front(targetDimIdx + 1),
+ [](int64_t d) { return d == 1; }));
}
// Calculate the shape to extract from source.
From d4ea820d64c74a829225de31715be50a96045fa7 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 18 Nov 2025 16:52:35 +0000
Subject: [PATCH 4/4] Address feedback
---
.../Vector/Transforms/VectorUnroll.cpp | 110 +++++++++---------
.../Dialect/Vector/vector-unroll-options.mlir | 39 ++++++-
2 files changed, 88 insertions(+), 61 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 885fcf835c1a3..0a1d86109beea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,58 +1003,60 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
-static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
- ArrayRef<int64_t> resultShape) {
- if (targetShape.size() > resultShape.size())
- return false;
-
- int64_t targetElements = ShapedType::getNumElements(targetShape);
- int64_t resultElements = ShapedType::getNumElements(resultShape);
+/// Checks whether targetShape is contiguous in resultShape.
+/// For targetShape to be contiguous in resultShape:
+/// 1) The inner dimensions of targetShape and resultShape must match exactly.
+/// 2) The total number of elements in resultShape must be evenly divisible by
+/// the total number of elements in targetShape.
+/// Examples:
+/// isContiguous([4, 4], [8, 4]) == true
+/// isContiguous([2, 4], [8, 4]) == true
+/// isContiguous([2, 2], [8, 4]) == false
+/// Removes leading unit dimensions to handle cases like:
+/// isContiguous([1, 16], [1, 32]) == true
+static bool isContiguous(ArrayRef<int64_t> targetShape,
+ ArrayRef<int64_t> resultShape) {
- // Result must be evenly divisible by target.
- if (resultElements % targetElements != 0)
+ if (targetShape.size() > resultShape.size())
return false;
- // For contiguous extraction, we need to be able to
- // extract targetElements contiguously from the result shape.
- // This means we can "consume" dimensions from the innermost outward
- // until we have exactly targetElements.
+ while (!targetShape.empty() && targetShape.front() == 1) {
+ targetShape = targetShape.drop_front();
+ }
- int64_t remainingElements = targetElements;
- int targetDimIdx = targetShape.size() - 1;
-
- // Work backwards through result dimensions.
- for (int resultDimIdx = resultShape.size() - 1;
- resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
- --resultDimIdx) {
-
- int64_t resultDimSize = resultShape[resultDimIdx];
- int64_t targetDimSize = targetShape[targetDimIdx];
-
- if (targetDimSize > resultDimSize)
- return false;
-
- if (targetDimSize == resultDimSize) {
- if (remainingElements % targetDimSize != 0)
- return false;
- remainingElements /= targetDimSize;
- --targetDimIdx;
- } else {
- if (remainingElements != targetDimSize)
- return false;
- remainingElements = 1;
- --targetDimIdx;
- }
+ while (!resultShape.empty() && resultShape.front() == 1) {
+ resultShape = resultShape.drop_front();
}
-  // Check that the remaining target dimensions are all 1 and that all
-  // elements were consumed.
- return remainingElements == 1 &&
- (targetDimIdx < 0 || llvm::all_of(
- targetShape.take_front(targetDimIdx + 1),
- [](int64_t d) { return d == 1; }));
+ size_t rankDiff = resultShape.size() - targetShape.size();
+ if (!llvm::equal(targetShape.drop_front(),
+ resultShape.drop_front(rankDiff + 1)))
+ return false;
+
+ int64_t targetElements = ShapedType::getNumElements(targetShape);
+ int64_t resultElements = ShapedType::getNumElements(resultShape);
+ return resultElements % targetElements == 0;
}
-// Calculate the shape to extract from source.
+/// This function determines what shape to use with
+/// `vector.extract_strided_slice` to extract a contiguous block of elements
+/// from a source vector. The extracted block must contain exactly the
+/// specified number of elements. If such an extraction shape cannot be
+/// determined, the function returns std::nullopt.
+/// Examples:
+/// sourceShape = [16], targetElements = 8
+/// Working right-to-left:
+///   - Take min(8, 16) = 8 from the only dim → extractShape = [8],
+/// remaining = 8/8 = 1
+/// Result: [8]
+///
+/// sourceShape = [4, 4], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
+/// remaining = 8/4 = 2
+///   - Take min(2, 4) = 2 from the first dim → extractShape = [2, 4],
+/// remaining = 2/2 = 1
+/// Result: [2, 4]
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
int64_t targetElements) {
@@ -1084,12 +1086,12 @@ calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
// Convert result offsets to source offsets via linear position.
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
- ArrayRef<int64_t> sourceStrides,
- ArrayRef<int64_t> resultStrides) {
+ ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> resultShape) {
// Convert result offsets to linear position.
- int64_t linearIndex = linearize(resultOffsets, resultStrides);
+ int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
// Convert linear position to source offsets.
- return delinearize(linearIndex, sourceStrides);
+ return delinearize(linearIndex, computeStrides(sourceShape));
}
/// This pattern unrolls `vector.shape_cast` operations according to the
@@ -1142,10 +1144,10 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
ArrayRef<int64_t> sourceShape = sourceType.getShape();
ArrayRef<int64_t> resultShape = resultType.getShape();
- if (!isContiguousExtract(*targetShape, resultShape))
- return rewriter.notifyMatchFailure(shapeCastOp,
- "Only supports cases where contiguous "
- "extraction is possible");
+ if (!isContiguous(*targetShape, resultShape))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "Only supports cases where target shape is "
+ "contiguous in result vector shape");
int64_t targetElements = ShapedType::getNumElements(*targetShape);
@@ -1168,13 +1170,11 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
SmallVector<int64_t> extractStrides(extractShape->size(), 1);
SmallVector<int64_t> insertStrides(targetShape->size(), 1);
- SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
- SmallVector<int64_t> resultStrides = computeStrides(resultShape);
for (SmallVector<int64_t> resultOffsets :
StaticTileOffsetRange(resultShape, *targetShape)) {
SmallVector<int64_t> sourceOffsets =
- calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
+ calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
extractStrides);
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index c94a502fa3654..8e2caa39696cb 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -504,12 +504,12 @@ func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
}
// CHECK-LABEL: func @shape_cast_1D
-// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
+// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
-// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
-// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
// CHECK: return %[[I1]] : vector<2x2x4xf32>
@@ -521,12 +521,39 @@ func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
}
// CHECK-LABEL: func @shape_cast_2D
-// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
+// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
-// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
-// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
// CHECK: return %[[I1]] : vector<4x4xf32>
+
+
+// Negative test case to ensure that such shape casts are not unrolled:
+// the targetShape (2x4) is not contiguous in the result vector shape (8x8),
+// since the inner target dim (4) does not match the inner result dim (8).
+func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> {
+ %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32>
+ return %0 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous
+// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32>
+// CHECK: return %[[SC]] : vector<8x8xf32>
+
+
+// Negative test case to ensure that such shape casts are not unrolled:
+// no contiguous extractShape for the targetShape (2x4) can be determined
+// from the source vector (8x3), since the 8 target elements do not divide
+// evenly across the innermost source dim of size 3.
+func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> {
+ %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32>
+ return %0 : vector<6x4xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable
+// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32>
+// CHECK: return %[[SC]] : vector<6x4xf32>
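
As a usage note (not part of this patch): a downstream pass can enable just
this pattern by mirroring the test registration above. A minimal sketch,
assuming a `funcOp` to run on and reusing the test's 2x4 native shape:

  RewritePatternSet patterns(funcOp.getContext());
  // Unroll only vector.shape_cast ops, tiling them into 2x4 chunks.
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShape(ArrayRef<int64_t>{2, 4})
                    .setFilterConstraint([](Operation *op) {
                      return success(isa<vector::ShapeCastOp>(op));
                    }));
  // Apply the patterns greedily to the function.
  (void)applyPatternsGreedily(funcOp, std::move(patterns));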