[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Nov 19 11:50:26 PST 2025


================
@@ -1003,6 +1003,194 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Checks whether extractShape is contiguous in shape.
+/// For extractShape to be contiguous in shape:
+/// 1) The inner dimensions of extractShape and shape must match exactly.
+/// 2) The total number of elements in shape must be evenly divisible by
+///    the total number of elements in extractShape.
+/// Examples:
+///   isContiguous([4, 4], [8, 4]) == true
+///   isContiguous([2, 4], [8, 4]) == true
+///   isContiguous([2, 2], [8, 4]) == false
+/// Removes leading unit dimensions to handle cases like:
+///   isContiguous([1, 16], [1, 32]) == true
+static bool isContiguous(ArrayRef<int64_t> extractShape,
+                         ArrayRef<int64_t> shape) {
+
+  if (extractShape.size() > shape.size())
+    return false;
+
+  while (!extractShape.empty() && extractShape.front() == 1) {
+    extractShape = extractShape.drop_front();
+  }
+
+  while (!shape.empty() && shape.front() == 1) {
+    shape = shape.drop_front();
+  }
+
+  size_t rankDiff = shape.size() - extractShape.size();
+  if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
+    return false;
+
+  int64_t extractElements = ShapedType::getNumElements(extractShape);
+  int64_t shapeElements = ShapedType::getNumElements(shape);
+  return shapeElements % extractElements == 0;
+}
+
+/// This function determines what shape to use with
+/// `vector.extract_strided_slice` to extract a contiguous memory region from a
+/// source vector. The extraction must be contiguous and contain exactly the
+/// specified number of elements. If such an extraction shape cannot be
+/// determined, the function returns std::nullopt.
+/// Examples:
+///   sourceShape = [16], targetElements = 8
+///   Working right-to-left:
+///   - Take min(8, 16) = 8 from only dim → extractShape = [8],
+///     remaining = 8/8 = 1
+///     Result: [8]
+///
+///   sourceShape = [4, 4], targetElements = 8
+///   Working right-to-left:
+///   - Take min(8, 4) = 4 from last dim → extractShape = [4],
+///     remaining = 8/4 = 2
+///   - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
+///     remaining = 2/2 = 1
+///     Result: [2, 4]
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+                            int64_t targetElements) {
+  SmallVector<int64_t> extractShape;
+  int64_t remainingElements = targetElements;
+
+  // Build extract shape from innermost dimension outward to ensure contiguity.
+  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+    extractShape.insert(extractShape.begin(), takeFromDim);
+
+    if (remainingElements % takeFromDim != 0)
+      return std::nullopt; // Not evenly divisible.
+    remainingElements /= takeFromDim;
+  }
+
+  // Fill remaining dimensions with 1.
+  while (extractShape.size() < sourceShape.size())
+    extractShape.insert(extractShape.begin(), 1);
+
+  if (ShapedType::getNumElements(extractShape) != targetElements)
+    return std::nullopt;
+
+  return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position.
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+                       ArrayRef<int64_t> sourceShape,
+                       ArrayRef<int64_t> resultShape) {
----------------
banach-space wrote:

> wouldn't shape_cast op verifier take care of that?

Take care of making sure that no invalid inputs are ever passed to this method? I doubt that ;-)
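
To make the concern concrete, here is a minimal standalone sketch of `calculateSourceExtractShape` from the hunk above, ported to `std::vector` so it compiles without MLIR. The function name `calculateSourceExtractShapeSketch` and the leading guard are assumptions added for illustration and are not part of the PR. The guard matters because a zero-sized dimension would make `takeFromDim` zero and turn the `%` into a division by zero. Whether such a shape can reach the helper depends entirely on its callers.

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Standalone port of the helper in the patch; std::vector stands in for
// ArrayRef/SmallVector. Hypothetical name, for illustration only.
static std::optional<std::vector<int64_t>>
calculateSourceExtractShapeSketch(const std::vector<int64_t> &sourceShape,
                                  int64_t targetElements) {
  // Added guard (not in the PR): reject non-positive sizes up front, so the
  // modulo below can never see a zero divisor.
  if (targetElements <= 0 ||
      std::any_of(sourceShape.begin(), sourceShape.end(),
                  [](int64_t dim) { return dim <= 0; }))
    return std::nullopt;

  std::vector<int64_t> extractShape;
  int64_t remainingElements = targetElements;

  // Build the extract shape from the innermost dimension outward.
  for (int i = static_cast<int>(sourceShape.size()) - 1;
       i >= 0 && remainingElements > 1; --i) {
    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
    extractShape.insert(extractShape.begin(), takeFromDim);
    if (remainingElements % takeFromDim != 0)
      return std::nullopt; // Not evenly divisible.
    remainingElements /= takeFromDim;
  }

  // Pad the outer dimensions with 1 to match the source rank.
  while (extractShape.size() < sourceShape.size())
    extractShape.insert(extractShape.begin(), 1);

  // The shape must cover exactly the requested number of elements.
  int64_t total = std::accumulate(extractShape.begin(), extractShape.end(),
                                  int64_t{1}, std::multiplies<int64_t>());
  if (total != targetElements)
    return std::nullopt;
  return extractShape;
}
```

With this sketch, `({16}, 8)` and `({4, 4}, 8)` reproduce the two examples from the doc comment, while `({4, 4}, 6)`, `({2, 2}, 8)` and `({4, 0}, 8)` all yield `std::nullopt`. Only the last case relies on the added guard.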

https://github.com/llvm/llvm-project/pull/167738