[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Nov 18 02:12:05 PST 2025
================
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+ ArrayRef<int64_t> resultShape) {
+ if (targetShape.size() > resultShape.size()) // Target rank may not exceed result rank.
+ return false;
+
+ int64_t targetElements = ShapedType::getNumElements(targetShape);
+ int64_t resultElements = ShapedType::getNumElements(resultShape);
+
+ // The result element count must be a whole multiple of the target element count.
+ if (resultElements % targetElements != 0)
+ return false;
+
+ // Decide whether `targetShape` can be extracted contiguously from
+ // `resultShape`, i.e. as `targetElements` consecutive elements. Both shapes
+ // are walked from the innermost dimension outward, and each matched target
+ // dimension is divided out of `remainingElements` until it reaches 1.
+
+ int64_t remainingElements = targetElements;
+ int targetDimIdx = targetShape.size() - 1;
+
+ // Pair result and target dimensions, innermost first.
+ for (int resultDimIdx = resultShape.size() - 1;
+ resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
+ --resultDimIdx) {
+
+ int64_t resultDimSize = resultShape[resultDimIdx];
+ int64_t targetDimSize = targetShape[targetDimIdx];
+
+ if (targetDimSize > resultDimSize) // Target dim cannot exceed the paired result dim.
+ return false;
+
+ if (targetDimSize == resultDimSize) { // Full dimension: divide it out and continue.
+ if (remainingElements % targetDimSize != 0)
+ return false;
+ remainingElements /= targetDimSize;
+ --targetDimIdx;
+ } else { // Partial dimension: it must account for every element still remaining.
+ if (remainingElements != targetDimSize)
+ return false;
+ remainingElements = 1;
+ --targetDimIdx;
+ }
+ }
+
+ // Succeed only if all elements were consumed and every target dimension
+ // not yet visited (indices 0..targetDimIdx) has size 1.
+ return remainingElements == 1 &&
+ (targetDimIdx < 0 || llvm::all_of(
+ targetShape.take_front(targetDimIdx + 1),
+ [](int64_t d) { return d == 1; }));
+}
+
+// Calculate the shape to extract from source.
----------------
banach-space wrote:
This comment is just repeating the name of the function. Could you add a more detailed explanation? How are `sourceShape` and `targetElements` related? What do these mean? How is the shape calculated (high-level description, fine details are available within the implementation)?
https://github.com/llvm/llvm-project/pull/167738
More information about the Mlir-commits
mailing list