[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Nov 18 02:12:05 PST 2025
================
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+ ArrayRef<int64_t> resultShape) {
+ if (targetShape.size() > resultShape.size())
+ return false;
+
+ int64_t targetElements = ShapedType::getNumElements(targetShape);
+ int64_t resultElements = ShapedType::getNumElements(resultShape);
+
+ // Result must be evenly divisible by target.
+ if (resultElements % targetElements != 0)
+ return false;
+
+ // For contiguous extraction, we need to be able to
+ // extract targetElements contiguously from the result shape.
+ // This means we can "consume" dimensions from the innermost outward
+ // until we have exactly targetElements.
+
+ int64_t remainingElements = targetElements;
+ int targetDimIdx = targetShape.size() - 1;
+
+ // Work backwards through result dimensions.
+ for (int resultDimIdx = resultShape.size() - 1;
+ resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
+ --resultDimIdx) {
+
+ int64_t resultDimSize = resultShape[resultDimIdx];
+ int64_t targetDimSize = targetShape[targetDimIdx];
+
+ if (targetDimSize > resultDimSize)
+ return false;
+
+ if (targetDimSize == resultDimSize) {
+ if (remainingElements % targetDimSize != 0)
+ return false;
+ remainingElements /= targetDimSize;
+ --targetDimIdx;
+ } else {
+ if (remainingElements != targetDimSize)
+ return false;
+ remainingElements = 1;
+ --targetDimIdx;
+ }
+ }
+
+ // Check remaining target dimensions are all 1 and we consumed all elements
+ return remainingElements == 1 &&
+ (targetDimIdx < 0 || llvm::all_of(
+ targetShape.take_front(targetDimIdx + 1),
+ [](int64_t d) { return d == 1; }));
+}
+
+// Calculate the shape to extract from source.
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+ int64_t targetElements) {
+ SmallVector<int64_t> extractShape;
+ int64_t remainingElements = targetElements;
+
+ // Build extract shape from innermost dimension outward to ensure contiguity.
+ for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+ int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+ extractShape.insert(extractShape.begin(), takeFromDim);
+
+ if (remainingElements % takeFromDim != 0)
+ return std::nullopt; // Not evenly divisible.
+ remainingElements /= takeFromDim;
+ }
+
+ // Fill remaining dimensions with 1.
+ while (extractShape.size() < sourceShape.size())
+ extractShape.insert(extractShape.begin(), 1);
+
+ if (ShapedType::getNumElements(extractShape) != targetElements)
+ return std::nullopt;
+
+ return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position.
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+ ArrayRef<int64_t> sourceStrides,
+ ArrayRef<int64_t> resultStrides) {
+ // Convert result offsets to linear position.
+ int64_t linearIndex = linearize(resultOffsets, resultStrides);
+ // Convert linear position to source offsets.
+ return delinearize(linearIndex, sourceStrides);
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (rather than being decomposed into scalars). In
+/// these cases, the unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice.
+///
+/// Example:
+/// Given a shape cast operation:
+/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+ UnrollShapeCastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, shapeCastOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType sourceType = shapeCastOp.getSourceVectorType();
+ VectorType resultType = shapeCastOp.getResultVectorType();
+ ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+
+ if (!isContiguousExtract(*targetShape, resultShape))
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ "Only supports cases where contiguous "
+ "extraction is possible");
----------------
banach-space wrote:
I find this check confusing - iiuc, it verifies that the original `vector.shape_cast` is ... contiguous? Isn't `vector.shape_cast` _always_ contiguous? Do you have a counter-example? I am probably confusing `targetShape` and `resultShape`? If yes, should we try different names?
https://github.com/llvm/llvm-project/pull/167738
More information about the Mlir-commits
mailing list