[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Nov 19 02:06:37 PST 2025
================
@@ -1003,6 +1003,194 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// Checks whether extractShape is contiguous in shape.
+/// For extractShape to be contiguous in shape:
+/// 1) The inner dimensions of extractShape and shape must match exactly.
+/// 2) The total number of elements in shape must be evenly divisible by
+/// the total number of elements in extractShape.
+/// Examples:
+/// isContiguous([4, 4], [8, 4]) == true
+/// isContiguous([2, 4], [8, 4]) == true
+/// isContiguous([2, 2], [8, 4]) == false
+/// Removes leading unit dimensions to handle cases like:
+/// isContiguous([1, 16], [1, 32]) == true
+static bool isContiguous(ArrayRef<int64_t> extractShape,
+ ArrayRef<int64_t> shape) {
+
+ if (extractShape.size() > shape.size())
+ return false;
+
+ while (!extractShape.empty() && extractShape.front() == 1) {
+ extractShape = extractShape.drop_front();
+ }
+
+ while (!shape.empty() && shape.front() == 1) {
+ shape = shape.drop_front();
+ }
+
+ size_t rankDiff = shape.size() - extractShape.size();
+ if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
+ return false;
+
+ int64_t extractElements = ShapedType::getNumElements(extractShape);
+ int64_t shapeElements = ShapedType::getNumElements(shape);
+ return shapeElements % extractElements == 0;
+}
+
+/// This function determines what shape to use with
+/// `vector.extract_strided_slice` to extract a contiguous memory region from a
+/// source vector. The extraction must be contiguous and contain exactly the
+/// specified number of elements. If such an extraction shape cannot be
+/// determined, the function returns std::nullopt.
+/// Examples:
+/// sourceShape = [16], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
+/// remaining = 8/8 = 1
+/// Result: [8]
+///
+/// sourceShape = [4, 4], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
+/// remaining = 8/4 = 2
+/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
+/// remaining = 2/2 = 1
+/// Result: [2, 4]
----------------
banach-space wrote:
[nit] Numbering the examples would make them a bit easier to parse.
```suggestion
/// EXAMPLE 1:
/// sourceShape = [16], targetElements = 8
/// Working right-to-left:
/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
/// remaining = 8/8 = 1
/// Result: [8]
///
/// EXAMPLE 2:
/// sourceShape = [4, 4], targetElements = 8
/// Working right-to-left:
/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
/// remaining = 8/4 = 2
/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
/// remaining = 2/2 = 1
/// Result: [2, 4]
```
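
The right-to-left walk traced in the two examples above can be sketched as standalone C++. This is only an illustration built from the doc comment, not the code in the PR: the name `sketchExtractShape`, the use of `std::vector` instead of `ArrayRef`, the choice to keep the result at the same rank as the source, and the rule of returning `std::nullopt` whenever the remaining count is not divisible by the amount taken are all assumptions.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not from the PR): walks sourceShape from the innermost
// dimension outwards, taking min(remaining, dim) elements from each dimension.
// Returns std::nullopt if targetElements cannot be covered exactly.
std::optional<std::vector<int64_t>>
sketchExtractShape(const std::vector<int64_t> &sourceShape,
                   int64_t targetElements) {
  if (targetElements <= 0)
    return std::nullopt;

  // Assumption: the result has the same rank as the source; dimensions that
  // contribute nothing stay at 1.
  std::vector<int64_t> extractShape(sourceShape.size(), 1);
  int64_t remaining = targetElements;

  for (size_t i = sourceShape.size(); i > 0; --i) {
    int64_t dim = sourceShape[i - 1];
    int64_t take = std::min(remaining, dim);
    // Assumption: reject non-positive dims and shapes where the amount taken
    // does not evenly divide what is still needed.
    if (take <= 0 || remaining % take != 0)
      return std::nullopt;
    extractShape[i - 1] = take;
    remaining /= take;
  }

  // Anything left over means the source holds fewer than targetElements.
  if (remaining != 1)
    return std::nullopt;
  return extractShape;
}
```

Under these assumptions, `sketchExtractShape({16}, 8)` follows EXAMPLE 1 and `sketchExtractShape({4, 4}, 8)` follows EXAMPLE 2, while a request such as 6 elements from `{4, 4}` yields `std::nullopt` because 6 is not divisible by the 4 taken from the last dimension.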
https://github.com/llvm/llvm-project/pull/167738