[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Nov 18 02:12:05 PST 2025


================
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Returns true if a tile of shape `targetShape` (the unroll shape) covers a
+/// contiguous run of elements, in row-major order, of a vector of shape
+/// `resultShape` (the shape_cast result), and such tiles evenly partition that
+/// vector.
+///
+/// This is a property of the unroll tile within the result vector, not of the
+/// shape_cast itself. It holds iff, walking dimensions from the innermost
+/// outward (shapes right-aligned), the tile first spans whole result
+/// dimensions, then spans at most one dimension partially (by a size that
+/// evenly divides it), and has size 1 in every dimension further out.
+/// For example, a 2x4 tile is contiguous within a 4x4 vector, whereas a 4x2
+/// tile is not.
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+                                ArrayRef<int64_t> resultShape) {
+  if (targetShape.size() > resultShape.size())
+    return false;
+
+  int64_t targetElements = ShapedType::getNumElements(targetShape);
+  int64_t resultElements = ShapedType::getNumElements(resultShape);
+
+  // Result must be evenly divisible by target.
+  if (resultElements % targetElements != 0)
+    return false;
+
+  // `remainingElements` is the number of tile elements not yet accounted for.
+  // It always equals the product of the target dims not yet visited, so
+  // dividing by a fully-covered dim below is always exact.
+  int64_t remainingElements = targetElements;
+  int targetDimIdx = static_cast<int>(targetShape.size()) - 1;
+
+  // Walk both shapes from the innermost dimension outward.
+  for (int resultDimIdx = static_cast<int>(resultShape.size()) - 1;
+       resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
+       --resultDimIdx, --targetDimIdx) {
+    int64_t resultDimSize = resultShape[resultDimIdx];
+    int64_t targetDimSize = targetShape[targetDimIdx];
+
+    if (targetDimSize == resultDimSize) {
+      // Whole dimension covered; the tile may keep extending outward.
+      remainingElements /= targetDimSize;
+      continue;
+    }
+
+    // Partially covered dimension. It must be the outermost non-unit tile
+    // dim (i.e. it accounts for all remaining tile elements). The tile size
+    // must also divide the result dim, so that consecutive tiles do not
+    // straddle it.
+    if (targetDimSize > resultDimSize || remainingElements != targetDimSize ||
+        resultDimSize % targetDimSize != 0)
+      return false;
+    remainingElements = 1;
+  }
+
+  // All tile elements must be accounted for, and every remaining (outer)
+  // target dimension must be 1.
+  return remainingElements == 1 &&
+         llvm::all_of(targetShape.take_front(targetDimIdx + 1),
+                      [](int64_t d) { return d == 1; });
+}
+
+/// Computes the shape of a slice of a vector with shape `sourceShape` that
+/// covers exactly `targetElements` contiguous elements (row-major order), such
+/// that consecutive slices evenly tile `sourceShape`.
+///
+/// The shape is built from the innermost dimension outward. Inner dimensions
+/// are taken whole while they fit, at most one dimension is then taken
+/// partially, and all outer dimensions get size 1. Returns std::nullopt if no
+/// such slice exists. One example is a partially-taken dimension that is not
+/// a multiple of the slice size along it: 1x2x4 slices cannot tile 2x3x4,
+/// because the second slice would start at [0, 2, 0] and run past dim 1.
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+                            int64_t targetElements) {
+  // Dimensions not visited below (the outermost ones) keep size 1.
+  SmallVector<int64_t> extractShape(sourceShape.size(), 1);
+  int64_t remainingElements = targetElements;
+
+  for (int i = static_cast<int>(sourceShape.size()) - 1;
+       i >= 0 && remainingElements > 1; --i) {
+    int64_t dimSize = sourceShape[i];
+    int64_t takeFromDim = std::min(remainingElements, dimSize);
+    // Taking a whole dimension requires the rest of the slice to be a
+    // multiple of it. Taking part of a dimension requires that part to tile
+    // the dimension exactly.
+    if (remainingElements % takeFromDim != 0 || dimSize % takeFromDim != 0)
+      return std::nullopt;
+    extractShape[i] = takeFromDim;
+    remainingElements /= takeFromDim;
+  }
+
+  // Fails e.g. when targetElements exceeds the number of source elements.
+  if (ShapedType::getNumElements(extractShape) != targetElements)
+    return std::nullopt;
+
+  return extractShape;
+}
+
+/// Maps an offset into the result vector to the matching offset into the
+/// source vector by going through the flat (linear) element index.
+/// `resultOffsets` is linearized with `resultStrides`, and that index is then
+/// delinearized with `sourceStrides`. The strides are presumably the
+/// row-major strides of the respective shapes (TODO: confirm at call site).
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+                       ArrayRef<int64_t> sourceStrides,
+                       ArrayRef<int64_t> resultStrides) {
+  return delinearize(linearize(resultOffsets, resultStrides), sourceStrides);
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (rather than being decomposed into scalars). In
+/// these cases, the unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice.
+///
+/// Example:
+///   Given a shape cast operation:
+///     %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+///   and a target unroll shape of <2x4>, the pattern produces:
+///
+///     %zero = arith.constant dense<0.0> : vector<4x4xf32>
+///     %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///     %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+///     %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///     %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///     %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+///     %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape =
+        getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+
+    if (!isContiguousExtract(*targetShape, resultShape))
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         "Only supports cases where contiguous "
+                                         "extraction is possible");
----------------
banach-space wrote:

I find this check confusing — if I understand correctly, it verifies that the original `vector.shape_cast` is ... contiguous? Isn't `vector.shape_cast` _always_ contiguous? Do you have a counter-example? Perhaps I am confusing `targetShape` and `resultShape`? If so, should we try different names?

https://github.com/llvm/llvm-project/pull/167738


More information about the Mlir-commits mailing list