[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)
Nishant Patel
llvmlistbot at llvm.org
Mon Nov 17 13:52:30 PST 2025
================
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+ ArrayRef<int64_t> resultShape) { // Contiguity predicate: targetShape is matched against resultShape right-aligned, innermost dim first.
+ if (targetShape.size() > resultShape.size()) // A target of higher rank than the result is rejected outright.
+ return false;
+ // Element counts of both shapes (presumably the product of the dims; confirm against ShapedType::getNumElements).
+ int64_t targetElements = ShapedType::getNumElements(targetShape);
+ int64_t resultElements = ShapedType::getNumElements(resultShape);
+
+ // The result's element count must be an exact multiple of the target's.
+ if (resultElements % targetElements != 0)
+ return false;
+
+ // For contiguous extraction, targetElements elements must be extractable
+ // contiguously from the result shape: dimensions are "consumed" from the
+ // innermost outward, with the target and result indices stepping in
+ // lockstep, until exactly targetElements have been accounted for.
+
+ int64_t remainingElements = targetElements; // Target elements not yet accounted for by the walk.
+ int targetDimIdx = targetShape.size() - 1; // Starts at the innermost target dim.
+
+ // Walk the result dims backwards; stop once every element is matched or either shape runs out of dims.
+ for (int resultDimIdx = resultShape.size() - 1;
+ resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
+ --resultDimIdx) {
+
+ int64_t resultDimSize = resultShape[resultDimIdx];
+ int64_t targetDimSize = targetShape[targetDimIdx];
+ // A target dim may never be larger than the result dim it is paired with.
+ if (targetDimSize > resultDimSize)
+ return false;
+ // Full dim: the target covers the whole result dim, so divide it out and keep walking outward.
+ if (targetDimSize == resultDimSize) {
+ if (remainingElements % targetDimSize != 0)
+ return false;
+ remainingElements /= targetDimSize;
+ --targetDimIdx;
+ } else { // Partial dim: it must account for all remaining elements, which also ends the walk.
+ if (remainingElements != targetDimSize)
+ return false;
+ remainingElements = 1;
+ --targetDimIdx;
+ }
+ }
+
+ // Succeed only if every element was consumed and any target dims left unvisited (the outermost ones) are all 1.
+ return remainingElements == 1 &&
+ (targetDimIdx < 0 || llvm::all_of(
+ targetShape.take_front(targetDimIdx + 1),
+ [](int64_t d) { return d == 1; }));
+}
+
+// Calculate the shape to extract from source.
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
----------------
nbpatel wrote:
This computes the shape for the extract_strided_slice to extract from the source vector. isContiguousExtract checks whether a particular targetShape is contiguous in the resultShape. Maybe "Extract" in the function name is confusing; I will change the name.
https://github.com/llvm/llvm-project/pull/167738
More information about the Mlir-commits
mailing list