[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Nov 20 04:15:31 PST 2025
================
@@ -1003,6 +1003,194 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// Checks whether extractShape is contiguous in shape.
+/// For extractShape to be contiguous in shape:
+/// 1) The inner dimensions of extractShape and shape must match exactly.
+/// 2) The total number of elements in shape must be evenly divisible by
+/// the total number of elements in extractShape.
+/// Examples:
+/// isContiguous([4, 4], [8, 4]) == true
+/// isContiguous([2, 4], [8, 4]) == true
+/// isContiguous([2, 2], [8, 4]) == false
+/// Removes leading unit dimensions to handle cases like:
+/// isContiguous([1, 16], [1, 32]) == true
+static bool isContiguous(ArrayRef<int64_t> extractShape,
+ ArrayRef<int64_t> shape) {
+
+ if (extractShape.size() > shape.size())
+ return false;
+
+ while (!extractShape.empty() && extractShape.front() == 1) {
+ extractShape = extractShape.drop_front();
+ }
+
+ while (!shape.empty() && shape.front() == 1) {
+ shape = shape.drop_front();
+ }
+
+ size_t rankDiff = shape.size() - extractShape.size();
+ if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
+ return false;
+
+ int64_t extractElements = ShapedType::getNumElements(extractShape);
+ int64_t shapeElements = ShapedType::getNumElements(shape);
+ return shapeElements % extractElements == 0;
+}
+
+/// This function determines what shape to use with
+/// `vector.extract_strided_slice` to extract a contiguous memory region from a
+/// source vector. The extraction must be contiguous and contain exactly the
+/// specified number of elements. If such an extraction shape cannot be
+/// determined, the function returns std::nullopt.
+/// Examples:
+/// sourceShape = [16], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
+/// remaining = 8/8 = 1
+/// Result: [8]
+///
+/// sourceShape = [4, 4], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
+/// remaining = 8/4 = 2
+/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
+/// remaining = 2/2 = 1
+/// Result: [2, 4]
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+ int64_t targetElements) {
+ SmallVector<int64_t> extractShape;
+ int64_t remainingElements = targetElements;
+
+ // Build extract shape from innermost dimension outward to ensure contiguity.
+ for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+ int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+ extractShape.insert(extractShape.begin(), takeFromDim);
+
+ if (remainingElements % takeFromDim != 0)
+ return std::nullopt; // Not evenly divisible.
+ remainingElements /= takeFromDim;
+ }
+
+ // Fill remaining dimensions with 1.
+ while (extractShape.size() < sourceShape.size())
+ extractShape.insert(extractShape.begin(), 1);
+
+ if (ShapedType::getNumElements(extractShape) != targetElements)
+ return std::nullopt;
+
+ return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position.
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+ ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> resultShape) {
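To make the two helpers above easier to experiment with, here is a minimal standalone sketch that re-implements `isContiguous` and `calculateSourceExtractShape`. It uses `std::vector<int64_t>` in place of `ArrayRef` and a local `numElements` in place of `ShapedType::getNumElements`. Those substitutions, and the helper names `Shape`, `numElements` and `dropLeadingOnes`, are assumptions for illustration and are not part of the PR. The sketch keeps the hunk's logic but adds two guards the hunk appears to lack. First, it repeats the rank check after the leading unit dimensions are stripped: the hunk checks rank only before stripping, so e.g. `isContiguous([2, 2, 4], [1, 1, 8])` would reach `drop_front` with an out-of-range count. Second, it returns early when `extractShape` is all ones, where the hunk would call `drop_front()` on an empty `ArrayRef`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>; // Stand-in for ArrayRef<int64_t>.

// Product of all dims (1 for an empty shape), like ShapedType::getNumElements.
static int64_t numElements(const Shape &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Drop leading unit dims, e.g. [1, 1, 16] -> [16].
static Shape dropLeadingOnes(const Shape &shape) {
  auto it = std::find_if(shape.begin(), shape.end(),
                         [](int64_t d) { return d != 1; });
  return Shape(it, shape.end());
}

static bool isContiguous(Shape extractShape, Shape shape) {
  if (extractShape.size() > shape.size()) // Same check as the hunk.
    return false;
  extractShape = dropLeadingOnes(extractShape);
  shape = dropLeadingOnes(shape);
  // Added guard: stripping can leave extractShape with more dims than shape.
  if (extractShape.size() > shape.size())
    return false;
  // Added guard: an all-ones extractShape is a single element.
  if (extractShape.empty())
    return true;
  // All dims after the outermost extract dim must match shape exactly.
  size_t rankDiff = shape.size() - extractShape.size();
  if (!std::equal(extractShape.begin() + 1, extractShape.end(),
                  shape.begin() + rankDiff + 1, shape.end()))
    return false;
  return numElements(shape) % numElements(extractShape) == 0;
}

static std::optional<Shape>
calculateSourceExtractShape(const Shape &sourceShape, int64_t targetElements) {
  Shape extractShape;
  int64_t remaining = targetElements;
  // Innermost dim outward: take whole dims until fewer elements remain than
  // the current dim holds, then take exactly what is left.
  for (int i = static_cast<int>(sourceShape.size()) - 1;
       i >= 0 && remaining > 1; --i) {
    int64_t take = std::min(remaining, sourceShape[i]);
    if (remaining % take != 0)
      return std::nullopt; // Not evenly divisible.
    extractShape.insert(extractShape.begin(), take);
    remaining /= take;
  }
  // Pad the outer dims with 1 so the rank matches sourceShape.
  while (extractShape.size() < sourceShape.size())
    extractShape.insert(extractShape.begin(), 1);
  // Fails when targetElements exceeds what sourceShape can provide.
  if (numElements(extractShape) != targetElements)
    return std::nullopt;
  return extractShape;
}
```

With these definitions, `calculateSourceExtractShape({4, 4}, 8)` yields `{2, 4}`, matching the doc comment. `calculateSourceExtractShape({4, 4}, 32)` yields `std::nullopt` because the source only holds 16 elements.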
----------------
banach-space wrote:
> I meant if the shapeCast is not maintaining the semantics of NumElements(src) == NumElements(dst) how is it even a valid instruction?
The shapeCast verifier will indeed enforce that, but only for shapeCast ops. How do you make sure that the inputs to this function always come from a shapeCast? Perhaps I am missing something, but what stops anyone (or anything) from calling it with arbitrary arrays that don't come from a shapeCast?
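To make this concern concrete, here is a hypothetical checked variant of `calculateSourceOffsets`. Its body is not shown in the hunk, so the row-major linearize/delinearize approach below is an assumption based only on its comment ("Convert result offsets to source offsets via linear position"). The name `calculateSourceOffsetsChecked` and the `std::optional` return type are likewise illustrative. Instead of relying on the shapeCast verifier, the function re-checks the element-count invariant itself, along with the rank and bounds of the offsets.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical: maps an offset in resultShape to the offset of the same
// row-major linear element in sourceShape. Returns std::nullopt when the
// inputs could not have come from a valid shape_cast.
static std::optional<Shape>
calculateSourceOffsetsChecked(const Shape &resultOffsets,
                              const Shape &sourceShape,
                              const Shape &resultShape) {
  if (resultOffsets.size() != resultShape.size())
    return std::nullopt;

  int64_t sourceElements = 1, resultElements = 1;
  for (int64_t d : sourceShape)
    sourceElements *= d;
  for (int64_t d : resultShape)
    resultElements *= d;
  // The invariant the shapeCast verifier guarantees, re-checked locally.
  if (sourceElements != resultElements)
    return std::nullopt;

  // Linearize the result offsets (row-major), checking bounds per dim.
  int64_t linear = 0;
  for (size_t i = 0; i < resultShape.size(); ++i) {
    if (resultOffsets[i] < 0 || resultOffsets[i] >= resultShape[i])
      return std::nullopt;
    linear = linear * resultShape[i] + resultOffsets[i];
  }

  // Delinearize into source coordinates, innermost dim first.
  Shape sourceOffsets(sourceShape.size(), 0);
  for (size_t i = sourceShape.size(); i-- > 0;) {
    sourceOffsets[i] = linear % sourceShape[i];
    linear /= sourceShape[i];
  }
  return sourceOffsets;
}
```

Returning `std::nullopt` lets a caller bail out of the rewrite. If such inputs are instead considered a programming error, an `assert` stating the precondition would be the lighter-weight alternative.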
https://github.com/llvm/llvm-project/pull/167738