[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Nov 18 02:12:05 PST 2025


================
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+                                ArrayRef<int64_t> resultShape) {
+  if (targetShape.size() > resultShape.size())
+    return false;
+
+  int64_t targetElements = ShapedType::getNumElements(targetShape);
+  int64_t resultElements = ShapedType::getNumElements(resultShape);
+
+  // Result must be evenly divisible by target.
+  if (resultElements % targetElements != 0)
+    return false;
+
+  // For contiguous extraction, we need to be able to
+  // extract targetElements contiguously from the result shape.
+  // This means we can "consume" dimensions from the innermost outward
+  // until we have exactly targetElements.
+
+  int64_t remainingElements = targetElements;
+  int targetDimIdx = targetShape.size() - 1;
+
+  // Work backwards through result dimensions.
+  for (int resultDimIdx = resultShape.size() - 1;
+       resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
+       --resultDimIdx) {
+
+    int64_t resultDimSize = resultShape[resultDimIdx];
+    int64_t targetDimSize = targetShape[targetDimIdx];
+
+    if (targetDimSize > resultDimSize)
+      return false;
+
+    if (targetDimSize == resultDimSize) {
+      if (remainingElements % targetDimSize != 0)
+        return false;
+      remainingElements /= targetDimSize;
+      --targetDimIdx;
+    } else {
+      if (remainingElements != targetDimSize)
+        return false;
+      remainingElements = 1;
+      --targetDimIdx;
+    }
+  }
+
+  // Check remaining target dimensions are all 1 and we consumed all elements
+  return remainingElements == 1 &&
+         (targetDimIdx < 0 || llvm::all_of(
+                                  targetShape.take_front(targetDimIdx + 1),
+                                  [](int64_t d) { return d == 1; }));
+}
+
+// Calculate the shape to extract from source.
----------------
banach-space wrote:

This comment just repeats the name of the function. Could you add a more detailed explanation? How are `sourceShape` and `targetElements` related? What do these mean? How is the shape calculated (a high-level description is enough; the fine details are available in the implementation)?

https://github.com/llvm/llvm-project/pull/167738
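To make the contiguity rule in `isContiguousExtract` easier to review, here is a standalone sketch. It is not part of the patch. It ports the function from the diff, using `std::vector` in place of `llvm::ArrayRef` and a hypothetical `numElements` helper in place of `ShapedType::getNumElements`, so that it compiles without LLVM. Next to the port sits a hypothetical brute-force oracle, `isContiguousBruteForce`. The oracle places a `targetShape` tile at the origin of `resultShape`, right-aligned the same way the backwards walk pairs dimensions, and checks whether the tile's row-major linear indices are consecutive. The oracle reuses the patch's "evenly divisible" requirement rather than deriving it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for ShapedType::getNumElements.
static int64_t numElements(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Port of isContiguousExtract from the diff (std::vector instead of ArrayRef).
bool isContiguousExtract(const std::vector<int64_t> &targetShape,
                         const std::vector<int64_t> &resultShape) {
  if (targetShape.size() > resultShape.size())
    return false;
  int64_t targetElements = numElements(targetShape);
  if (numElements(resultShape) % targetElements != 0)
    return false; // result must split into equally sized chunks
  // Invariant: remaining == product of the target dims not yet visited.
  int64_t remaining = targetElements;
  int targetIdx = static_cast<int>(targetShape.size()) - 1;
  // Pair target and result dims right-aligned, innermost first.
  for (int resultIdx = static_cast<int>(resultShape.size()) - 1;
       resultIdx >= 0 && remaining > 1 && targetIdx >= 0; --resultIdx) {
    int64_t resultDim = resultShape[resultIdx];
    int64_t targetDim = targetShape[targetIdx];
    if (targetDim > resultDim)
      return false;
    if (targetDim == resultDim) {
      // Full dim: consume it and keep walking outward.
      if (remaining % targetDim != 0)
        return false;
      remaining /= targetDim;
    } else {
      // Partial dim: allowed only if every outer target dim is 1.
      if (remaining != targetDim)
        return false;
      remaining = 1;
    }
    --targetIdx;
  }
  // Target dims the walk never reached must all be unit dims.
  for (int i = 0; i <= targetIdx; ++i)
    if (targetShape[i] != 1)
      return false;
  return remaining == 1;
}

// Hypothetical oracle: enumerate the origin tile and inspect linear indices.
bool isContiguousBruteForce(const std::vector<int64_t> &targetShape,
                            const std::vector<int64_t> &resultShape) {
  if (targetShape.size() > resultShape.size())
    return false;
  const size_t rank = resultShape.size();
  // Right-align the tile by padding it with leading unit dims.
  std::vector<int64_t> tile(rank - targetShape.size(), 1);
  tile.insert(tile.end(), targetShape.begin(), targetShape.end());
  for (size_t d = 0; d < rank; ++d)
    if (tile[d] > resultShape[d])
      return false; // tile does not fit inside the result
  int64_t count = numElements(tile);
  if (numElements(resultShape) % count != 0)
    return false; // same divisibility requirement as the patch
  std::vector<int64_t> idx(rank, 0);
  int64_t maxLinear = 0;
  for (int64_t n = 0; n < count; ++n) {
    int64_t linear = 0; // row-major linearization within resultShape
    for (size_t d = 0; d < rank; ++d)
      linear = linear * resultShape[d] + idx[d];
    maxLinear = std::max(maxLinear, linear);
    for (size_t d = rank; d-- > 0;) { // odometer increment, innermost first
      if (++idx[d] < tile[d])
        break;
      idx[d] = 0;
    }
  }
  // `count` distinct indices starting at 0 are consecutive iff max == count-1.
  return maxLinear == count - 1;
}
```

For example, a 2x4 tile at the origin of a 4x4 vector covers linear elements 0 through 7, so `isContiguousExtract({2, 4}, {4, 4})` should return true. A 2x2 tile covers elements 0, 1, 4 and 5, so `isContiguousExtract({2, 2}, {4, 4})` should return false. Checking only the origin tile is enough for the oracle. Row-major linearization is linear in the multi-index, so moving the tile to any other in-bounds offset shifts every linear index by the same constant.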

